Merge branch 'master' into FP8-Fast-Fix

This commit is contained in:
Soulreaver90 2026-02-25 07:19:59 -05:00 committed by GitHub
commit 56e9de4f9f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
188 changed files with 6882 additions and 720 deletions

127
.coderabbit.yaml Normal file
View File

@ -0,0 +1,127 @@
# yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json
language: "en-US"
early_access: false
tone_instructions: "Only comment on issues introduced by this PR's changes. Do not flag pre-existing problems in moved, re-indented, or reformatted code."
reviews:
profile: "chill"
request_changes_workflow: false
high_level_summary: false
poem: false
review_status: false
review_details: false
commit_status: true
collapse_walkthrough: true
changed_files_summary: false
sequence_diagrams: false
estimate_code_review_effort: false
assess_linked_issues: false
related_issues: false
related_prs: false
suggested_labels: false
auto_apply_labels: false
suggested_reviewers: false
auto_assign_reviewers: false
in_progress_fortune: false
enable_prompt_for_ai_agents: true
path_filters:
- "!comfy_api_nodes/apis/**"
- "!**/generated/*.pyi"
- "!.ci/**"
- "!script_examples/**"
- "!**/__pycache__/**"
- "!**/*.ipynb"
- "!**/*.png"
- "!**/*.bat"
path_instructions:
- path: "**"
instructions: |
IMPORTANT: Only comment on issues directly introduced by this PR's code changes.
Do NOT flag pre-existing issues in code that was merely moved, re-indented,
de-indented, or reformatted without logic changes. If code appears in the diff
only due to whitespace or structural reformatting (e.g., removing a `with:` block),
treat it as unchanged. Contributors should not feel obligated to address
pre-existing issues outside the scope of their contribution.
- path: "comfy/**"
instructions: |
Core ML/diffusion engine. Focus on:
- Backward compatibility (breaking changes affect all custom nodes)
- Memory management and GPU resource handling
- Performance implications in hot paths
- Thread safety for concurrent execution
- path: "comfy_api_nodes/**"
instructions: |
Third-party API integration nodes. Focus on:
- No hardcoded API keys or secrets
- Proper error handling for API failures (timeouts, rate limits, auth errors)
- Correct Pydantic model usage
- Security of user data passed to external APIs
- path: "comfy_extras/**"
instructions: |
Community-contributed extra nodes. Focus on:
- Consistency with node patterns (INPUT_TYPES, RETURN_TYPES, FUNCTION, CATEGORY)
- No breaking changes to existing node interfaces
- path: "comfy_execution/**"
instructions: |
Execution engine (graph execution, caching, jobs). Focus on:
- Caching correctness
- Concurrent execution safety
- Graph validation edge cases
- path: "nodes.py"
instructions: |
Core node definitions (2500+ lines). Focus on:
- Backward compatibility of NODE_CLASS_MAPPINGS
- Consistency of INPUT_TYPES return format
- path: "alembic_db/**"
instructions: |
Database migrations. Focus on:
- Migration safety and rollback support
- Data preservation during schema changes
auto_review:
enabled: true
auto_incremental_review: true
drafts: false
ignore_title_keywords:
- "WIP"
- "DO NOT REVIEW"
- "DO NOT MERGE"
finishing_touches:
docstrings:
enabled: false
unit_tests:
enabled: false
tools:
ruff:
enabled: false
pylint:
enabled: false
flake8:
enabled: false
gitleaks:
enabled: true
shellcheck:
enabled: false
markdownlint:
enabled: false
yamllint:
enabled: false
languagetool:
enabled: false
github-checks:
enabled: true
timeout_ms: 90000
ast-grep:
essential_rules: true
chat:
auto_reply: true
knowledge_base:
opt_out: false
learnings:
scope: "auto"

2
.gitignore vendored
View File

@ -11,7 +11,7 @@ extra_model_paths.yaml
/.vs /.vs
.vscode/ .vscode/
.idea/ .idea/
venv/ venv*/
.venv/ .venv/
/web/extensions/* /web/extensions/*
!/web/extensions/logging.js.example !/web/extensions/logging.js.example

View File

@ -227,11 +227,11 @@ Put your VAE in: models/vae
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version: AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4``` ```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm7.1```
This is the command to install the nightly with ROCm 7.1 which might have some performance improvements: This is the command to install the nightly with ROCm 7.2 which might have some performance improvements:
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.1``` ```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.2```
### AMD GPUs (Experimental: Windows and Linux), RDNA 3, 3.5 and 4 only. ### AMD GPUs (Experimental: Windows and Linux), RDNA 3, 3.5 and 4 only.

107
app/node_replace_manager.py Normal file
View File

@ -0,0 +1,107 @@
from __future__ import annotations
from aiohttp import web
from typing import TYPE_CHECKING, TypedDict
if TYPE_CHECKING:
from comfy_api.latest._io_public import NodeReplace
from comfy_execution.graph_utils import is_link
import nodes
class NodeStruct(TypedDict):
inputs: dict[str, str | int | float | bool | tuple[str, int]]
class_type: str
_meta: dict[str, str]
def copy_node_struct(node_struct: NodeStruct, empty_inputs: bool = False) -> NodeStruct:
new_node_struct = node_struct.copy()
if empty_inputs:
new_node_struct["inputs"] = {}
else:
new_node_struct["inputs"] = node_struct["inputs"].copy()
new_node_struct["_meta"] = node_struct["_meta"].copy()
return new_node_struct
class NodeReplaceManager:
"""Manages node replacement registrations."""
def __init__(self):
self._replacements: dict[str, list[NodeReplace]] = {}
def register(self, node_replace: NodeReplace):
"""Register a node replacement mapping."""
self._replacements.setdefault(node_replace.old_node_id, []).append(node_replace)
def get_replacement(self, old_node_id: str) -> list[NodeReplace] | None:
"""Get replacements for an old node ID."""
return self._replacements.get(old_node_id)
def has_replacement(self, old_node_id: str) -> bool:
"""Check if a replacement exists for an old node ID."""
return old_node_id in self._replacements
def apply_replacements(self, prompt: dict[str, NodeStruct]):
connections: dict[str, list[tuple[str, str, int]]] = {}
need_replacement: set[str] = set()
for node_number, node_struct in prompt.items():
if "class_type" not in node_struct or "inputs" not in node_struct:
continue
class_type = node_struct["class_type"]
# need replacement if not in NODE_CLASS_MAPPINGS and has replacement
if class_type not in nodes.NODE_CLASS_MAPPINGS.keys() and self.has_replacement(class_type):
need_replacement.add(node_number)
# keep track of connections
for input_id, input_value in node_struct["inputs"].items():
if is_link(input_value):
conn_number = input_value[0]
connections.setdefault(conn_number, []).append((node_number, input_id, input_value[1]))
for node_number in need_replacement:
node_struct = prompt[node_number]
class_type = node_struct["class_type"]
replacements = self.get_replacement(class_type)
if replacements is None:
continue
# just use the first replacement
replacement = replacements[0]
new_node_id = replacement.new_node_id
# if replacement is not a valid node, skip trying to replace it as will only cause confusion
if new_node_id not in nodes.NODE_CLASS_MAPPINGS.keys():
continue
# first, replace node id (class_type)
new_node_struct = copy_node_struct(node_struct, empty_inputs=True)
new_node_struct["class_type"] = new_node_id
# TODO: consider replacing display_name in _meta as well for error reporting purposes; would need to query node schema
# second, replace inputs
if replacement.input_mapping is not None:
for input_map in replacement.input_mapping:
if "set_value" in input_map:
new_node_struct["inputs"][input_map["new_id"]] = input_map["set_value"]
elif "old_id" in input_map:
new_node_struct["inputs"][input_map["new_id"]] = node_struct["inputs"][input_map["old_id"]]
# finalize input replacement
prompt[node_number] = new_node_struct
# third, replace outputs
if replacement.output_mapping is not None:
# re-mapping outputs requires changing the input values of nodes that receive connections from this one
if node_number in connections:
for conns in connections[node_number]:
conn_node_number, conn_input_id, old_output_idx = conns
for output_map in replacement.output_mapping:
if output_map["old_idx"] == old_output_idx:
new_output_idx = output_map["new_idx"]
previous_input = prompt[conn_node_number]["inputs"][conn_input_id]
previous_input[1] = new_output_idx
def as_dict(self):
"""Serialize all replacements to dict."""
return {
k: [v.as_dict() for v in v_list]
for k, v_list in self._replacements.items()
}
def add_routes(self, routes):
@routes.get("/node_replacements")
async def get_node_replacements(request):
return web.json_response(self.as_dict())

View File

@ -53,7 +53,7 @@ class SubgraphManager:
return entry_id, entry return entry_id, entry
async def load_entry_data(self, entry: SubgraphEntry): async def load_entry_data(self, entry: SubgraphEntry):
with open(entry['path'], 'r') as f: with open(entry['path'], 'r', encoding='utf-8') as f:
entry['data'] = f.read() entry['data'] = f.read()
return entry return entry

View File

@ -0,0 +1,44 @@
#version 300 es
precision highp float;
uniform sampler2D u_image0;
uniform float u_float0; // Brightness slider -100..100
uniform float u_float1; // Contrast slider -100..100
in vec2 v_texCoord;
out vec4 fragColor;
const float MID_GRAY = 0.18; // 18% reflectance
// sRGB gamma 2.2 approximation
vec3 srgbToLinear(vec3 c) {
return pow(max(c, 0.0), vec3(2.2));
}
vec3 linearToSrgb(vec3 c) {
return pow(max(c, 0.0), vec3(1.0/2.2));
}
float mapBrightness(float b) {
return clamp(b / 100.0, -1.0, 1.0);
}
float mapContrast(float c) {
return clamp(c / 100.0 + 1.0, 0.0, 2.0);
}
void main() {
vec4 orig = texture(u_image0, v_texCoord);
float brightness = mapBrightness(u_float0);
float contrast = mapContrast(u_float1);
vec3 lin = srgbToLinear(orig.rgb);
lin = (lin - MID_GRAY) * contrast + brightness + MID_GRAY;
// Convert back to sRGB
vec3 result = linearToSrgb(clamp(lin, 0.0, 1.0));
fragColor = vec4(result, orig.a);
}

View File

@ -0,0 +1,72 @@
#version 300 es
precision highp float;
uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform int u_int0; // Mode
uniform float u_float0; // Amount (0 to 100)
in vec2 v_texCoord;
out vec4 fragColor;
const int MODE_LINEAR = 0;
const int MODE_RADIAL = 1;
const int MODE_BARREL = 2;
const int MODE_SWIRL = 3;
const int MODE_DIAGONAL = 4;
const float AMOUNT_SCALE = 0.0005;
const float RADIAL_MULT = 4.0;
const float BARREL_MULT = 8.0;
const float INV_SQRT2 = 0.70710678118;
void main() {
vec2 uv = v_texCoord;
vec4 original = texture(u_image0, uv);
float amount = u_float0 * AMOUNT_SCALE;
if (amount < 0.000001) {
fragColor = original;
return;
}
// Aspect-corrected coordinates for circular effects
float aspect = u_resolution.x / u_resolution.y;
vec2 centered = uv - 0.5;
vec2 corrected = vec2(centered.x * aspect, centered.y);
float r = length(corrected);
vec2 dir = r > 0.0001 ? corrected / r : vec2(0.0);
vec2 offset = vec2(0.0);
if (u_int0 == MODE_LINEAR) {
// Horizontal shift (no aspect correction needed)
offset = vec2(amount, 0.0);
}
else if (u_int0 == MODE_RADIAL) {
// Outward from center, stronger at edges
offset = dir * r * amount * RADIAL_MULT;
offset.x /= aspect; // Convert back to UV space
}
else if (u_int0 == MODE_BARREL) {
// Lens distortion simulation (r² falloff)
offset = dir * r * r * amount * BARREL_MULT;
offset.x /= aspect; // Convert back to UV space
}
else if (u_int0 == MODE_SWIRL) {
// Perpendicular to radial (rotational aberration)
vec2 perp = vec2(-dir.y, dir.x);
offset = perp * r * amount * RADIAL_MULT;
offset.x /= aspect; // Convert back to UV space
}
else if (u_int0 == MODE_DIAGONAL) {
// 45° offset (no aspect correction needed)
offset = vec2(amount, amount) * INV_SQRT2;
}
float red = texture(u_image0, uv + offset).r;
float green = original.g;
float blue = texture(u_image0, uv - offset).b;
fragColor = vec4(red, green, blue, original.a);
}

View File

@ -0,0 +1,78 @@
#version 300 es
precision highp float;
uniform sampler2D u_image0;
uniform float u_float0; // temperature (-100 to 100)
uniform float u_float1; // tint (-100 to 100)
uniform float u_float2; // vibrance (-100 to 100)
uniform float u_float3; // saturation (-100 to 100)
in vec2 v_texCoord;
out vec4 fragColor;
const float INPUT_SCALE = 0.01;
const float TEMP_TINT_PRIMARY = 0.3;
const float TEMP_TINT_SECONDARY = 0.15;
const float VIBRANCE_BOOST = 2.0;
const float SATURATION_BOOST = 2.0;
const float SKIN_PROTECTION = 0.5;
const float EPSILON = 0.001;
const vec3 LUMA_WEIGHTS = vec3(0.299, 0.587, 0.114);
void main() {
vec4 tex = texture(u_image0, v_texCoord);
vec3 color = tex.rgb;
// Scale inputs: -100/100 → -1/1
float temperature = u_float0 * INPUT_SCALE;
float tint = u_float1 * INPUT_SCALE;
float vibrance = u_float2 * INPUT_SCALE;
float saturation = u_float3 * INPUT_SCALE;
// Temperature (warm/cool): positive = warm, negative = cool
color.r += temperature * TEMP_TINT_PRIMARY;
color.b -= temperature * TEMP_TINT_PRIMARY;
// Tint (green/magenta): positive = green, negative = magenta
color.g += tint * TEMP_TINT_PRIMARY;
color.r -= tint * TEMP_TINT_SECONDARY;
color.b -= tint * TEMP_TINT_SECONDARY;
// Single clamp after temperature/tint
color = clamp(color, 0.0, 1.0);
// Vibrance with skin protection
if (vibrance != 0.0) {
float maxC = max(color.r, max(color.g, color.b));
float minC = min(color.r, min(color.g, color.b));
float sat = maxC - minC;
float gray = dot(color, LUMA_WEIGHTS);
if (vibrance < 0.0) {
// Desaturate: -100 → gray
color = mix(vec3(gray), color, 1.0 + vibrance);
} else {
// Boost less saturated colors more
float vibranceAmt = vibrance * (1.0 - sat);
// Branchless skin tone protection
float isWarmTone = step(color.b, color.g) * step(color.g, color.r);
float warmth = (color.r - color.b) / max(maxC, EPSILON);
float skinTone = isWarmTone * warmth * sat * (1.0 - sat);
vibranceAmt *= (1.0 - skinTone * SKIN_PROTECTION);
color = mix(vec3(gray), color, 1.0 + vibranceAmt * VIBRANCE_BOOST);
}
}
// Saturation
if (saturation != 0.0) {
float gray = dot(color, LUMA_WEIGHTS);
float satMix = saturation < 0.0
? 1.0 + saturation // -100 → gray
: 1.0 + saturation * SATURATION_BOOST; // +100 → 3x boost
color = mix(vec3(gray), color, satMix);
}
fragColor = vec4(clamp(color, 0.0, 1.0), tex.a);
}

View File

@ -0,0 +1,94 @@
#version 300 es
precision highp float;
uniform sampler2D u_image0;
uniform float u_float0; // Blur radius (020, default ~5)
uniform float u_float1; // Edge threshold (0100, default ~30)
uniform int u_int0; // Step size (0/1 = every pixel, 2+ = skip pixels)
in vec2 v_texCoord;
out vec4 fragColor;
const int MAX_RADIUS = 20;
const float EPSILON = 0.0001;
// Perceptual luminance
float getLuminance(vec3 rgb) {
return dot(rgb, vec3(0.299, 0.587, 0.114));
}
vec4 bilateralFilter(vec2 uv, vec2 texelSize, int radius,
float sigmaSpatial, float sigmaColor)
{
vec4 center = texture(u_image0, uv);
vec3 centerRGB = center.rgb;
float invSpatial2 = -0.5 / (sigmaSpatial * sigmaSpatial);
float invColor2 = -0.5 / (sigmaColor * sigmaColor + EPSILON);
vec3 sumRGB = vec3(0.0);
float sumWeight = 0.0;
int step = max(u_int0, 1);
float radius2 = float(radius * radius);
for (int dy = -MAX_RADIUS; dy <= MAX_RADIUS; dy++) {
if (dy < -radius || dy > radius) continue;
if (abs(dy) % step != 0) continue;
for (int dx = -MAX_RADIUS; dx <= MAX_RADIUS; dx++) {
if (dx < -radius || dx > radius) continue;
if (abs(dx) % step != 0) continue;
vec2 offset = vec2(float(dx), float(dy));
float dist2 = dot(offset, offset);
if (dist2 > radius2) continue;
vec3 sampleRGB = texture(u_image0, uv + offset * texelSize).rgb;
// Spatial Gaussian
float spatialWeight = exp(dist2 * invSpatial2);
// Perceptual color distance (weighted RGB)
vec3 diff = sampleRGB - centerRGB;
float colorDist = dot(diff * diff, vec3(0.299, 0.587, 0.114));
float colorWeight = exp(colorDist * invColor2);
float w = spatialWeight * colorWeight;
sumRGB += sampleRGB * w;
sumWeight += w;
}
}
vec3 resultRGB = sumRGB / max(sumWeight, EPSILON);
return vec4(resultRGB, center.a); // preserve center alpha
}
void main() {
vec2 texelSize = 1.0 / vec2(textureSize(u_image0, 0));
float radiusF = clamp(u_float0, 0.0, float(MAX_RADIUS));
int radius = int(radiusF + 0.5);
if (radius == 0) {
fragColor = texture(u_image0, v_texCoord);
return;
}
// Edge threshold → color sigma
// Squared curve for better low-end control
float t = clamp(u_float1, 0.0, 100.0) / 100.0;
t *= t;
float sigmaColor = mix(0.01, 0.5, t);
// Spatial sigma tied to radius
float sigmaSpatial = max(radiusF * 0.75, 0.5);
fragColor = bilateralFilter(
v_texCoord,
texelSize,
radius,
sigmaSpatial,
sigmaColor
);
}

View File

@ -0,0 +1,124 @@
#version 300 es
precision highp float;
uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform float u_float0; // grain amount [0.0 1.0] typical: 0.20.8
uniform float u_float1; // grain size [0.3 3.0] lower = finer grain
uniform float u_float2; // color amount [0.0 1.0] 0 = monochrome, 1 = RGB grain
uniform float u_float3; // luminance bias [0.0 1.0] 0 = uniform, 1 = shadows only
uniform int u_int0; // noise mode [0 or 1] 0 = smooth, 1 = grainy
in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;
// High-quality integer hash (pcg-like)
uint pcg(uint v) {
uint state = v * 747796405u + 2891336453u;
uint word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u;
return (word >> 22u) ^ word;
}
// 2D -> 1D hash input
uint hash2d(uvec2 p) {
return pcg(p.x + pcg(p.y));
}
// Hash to float [0, 1]
float hashf(uvec2 p) {
return float(hash2d(p)) / float(0xffffffffu);
}
// Hash to float with offset (for RGB channels)
float hashf(uvec2 p, uint offset) {
return float(pcg(hash2d(p) + offset)) / float(0xffffffffu);
}
// Convert uniform [0,1] to roughly Gaussian distribution
// Using simple approximation: average of multiple samples
float toGaussian(uvec2 p) {
float sum = hashf(p, 0u) + hashf(p, 1u) + hashf(p, 2u) + hashf(p, 3u);
return (sum - 2.0) * 0.7; // Centered, scaled
}
float toGaussian(uvec2 p, uint offset) {
float sum = hashf(p, offset) + hashf(p, offset + 1u)
+ hashf(p, offset + 2u) + hashf(p, offset + 3u);
return (sum - 2.0) * 0.7;
}
// Smooth noise with better interpolation
float smoothNoise(vec2 p) {
vec2 i = floor(p);
vec2 f = fract(p);
// Quintic interpolation (less banding than cubic)
f = f * f * f * (f * (f * 6.0 - 15.0) + 10.0);
uvec2 ui = uvec2(i);
float a = toGaussian(ui);
float b = toGaussian(ui + uvec2(1u, 0u));
float c = toGaussian(ui + uvec2(0u, 1u));
float d = toGaussian(ui + uvec2(1u, 1u));
return mix(mix(a, b, f.x), mix(c, d, f.x), f.y);
}
float smoothNoise(vec2 p, uint offset) {
vec2 i = floor(p);
vec2 f = fract(p);
f = f * f * f * (f * (f * 6.0 - 15.0) + 10.0);
uvec2 ui = uvec2(i);
float a = toGaussian(ui, offset);
float b = toGaussian(ui + uvec2(1u, 0u), offset);
float c = toGaussian(ui + uvec2(0u, 1u), offset);
float d = toGaussian(ui + uvec2(1u, 1u), offset);
return mix(mix(a, b, f.x), mix(c, d, f.x), f.y);
}
void main() {
vec4 color = texture(u_image0, v_texCoord);
// Luminance (Rec.709)
float luma = dot(color.rgb, vec3(0.2126, 0.7152, 0.0722));
// Grain UV (resolution-independent)
vec2 grainUV = v_texCoord * u_resolution / max(u_float1, 0.01);
uvec2 grainPixel = uvec2(grainUV);
float g;
vec3 grainRGB;
if (u_int0 == 1) {
// Grainy mode: pure hash noise (no interpolation = no banding)
g = toGaussian(grainPixel);
grainRGB = vec3(
toGaussian(grainPixel, 100u),
toGaussian(grainPixel, 200u),
toGaussian(grainPixel, 300u)
);
} else {
// Smooth mode: interpolated with quintic curve
g = smoothNoise(grainUV);
grainRGB = vec3(
smoothNoise(grainUV, 100u),
smoothNoise(grainUV, 200u),
smoothNoise(grainUV, 300u)
);
}
// Luminance weighting (less grain in highlights)
float lumWeight = mix(1.0, 1.0 - luma, clamp(u_float3, 0.0, 1.0));
// Strength
float strength = u_float0 * 0.15;
// Color vs monochrome grain
vec3 grainColor = mix(vec3(g), grainRGB, clamp(u_float2, 0.0, 1.0));
color.rgb += grainColor * strength * lumWeight;
fragColor0 = vec4(clamp(color.rgb, 0.0, 1.0), color.a);
}

View File

@ -0,0 +1,133 @@
#version 300 es
precision mediump float;
uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform int u_int0; // Blend mode
uniform int u_int1; // Color tint
uniform float u_float0; // Intensity
uniform float u_float1; // Radius
uniform float u_float2; // Threshold
in vec2 v_texCoord;
out vec4 fragColor;
const int BLEND_ADD = 0;
const int BLEND_SCREEN = 1;
const int BLEND_SOFT = 2;
const int BLEND_OVERLAY = 3;
const int BLEND_LIGHTEN = 4;
const float GOLDEN_ANGLE = 2.39996323;
const int MAX_SAMPLES = 48;
const vec3 LUMA = vec3(0.299, 0.587, 0.114);
float hash(vec2 p) {
p = fract(p * vec2(123.34, 456.21));
p += dot(p, p + 45.32);
return fract(p.x * p.y);
}
vec3 hexToRgb(int h) {
return vec3(
float((h >> 16) & 255),
float((h >> 8) & 255),
float(h & 255)
) * (1.0 / 255.0);
}
vec3 blend(vec3 base, vec3 glow, int mode) {
if (mode == BLEND_SCREEN) {
return 1.0 - (1.0 - base) * (1.0 - glow);
}
if (mode == BLEND_SOFT) {
return mix(
base - (1.0 - 2.0 * glow) * base * (1.0 - base),
base + (2.0 * glow - 1.0) * (sqrt(base) - base),
step(0.5, glow)
);
}
if (mode == BLEND_OVERLAY) {
return mix(
2.0 * base * glow,
1.0 - 2.0 * (1.0 - base) * (1.0 - glow),
step(0.5, base)
);
}
if (mode == BLEND_LIGHTEN) {
return max(base, glow);
}
return base + glow;
}
void main() {
vec4 original = texture(u_image0, v_texCoord);
float intensity = u_float0 * 0.05;
float radius = u_float1 * u_float1 * 0.012;
if (intensity < 0.001 || radius < 0.1) {
fragColor = original;
return;
}
float threshold = 1.0 - u_float2 * 0.01;
float t0 = threshold - 0.15;
float t1 = threshold + 0.15;
vec2 texelSize = 1.0 / u_resolution;
float radius2 = radius * radius;
float sampleScale = clamp(radius * 0.75, 0.35, 1.0);
int samples = int(float(MAX_SAMPLES) * sampleScale);
float noise = hash(gl_FragCoord.xy);
float angleOffset = noise * GOLDEN_ANGLE;
float radiusJitter = 0.85 + noise * 0.3;
float ca = cos(GOLDEN_ANGLE);
float sa = sin(GOLDEN_ANGLE);
vec2 dir = vec2(cos(angleOffset), sin(angleOffset));
vec3 glow = vec3(0.0);
float totalWeight = 0.0;
// Center tap
float centerMask = smoothstep(t0, t1, dot(original.rgb, LUMA));
glow += original.rgb * centerMask * 2.0;
totalWeight += 2.0;
for (int i = 1; i < MAX_SAMPLES; i++) {
if (i >= samples) break;
float fi = float(i);
float dist = sqrt(fi / float(samples)) * radius * radiusJitter;
vec2 offset = dir * dist * texelSize;
vec3 c = texture(u_image0, v_texCoord + offset).rgb;
float mask = smoothstep(t0, t1, dot(c, LUMA));
float w = 1.0 - (dist * dist) / (radius2 * 1.5);
w = max(w, 0.0);
w *= w;
glow += c * mask * w;
totalWeight += w;
dir = vec2(
dir.x * ca - dir.y * sa,
dir.x * sa + dir.y * ca
);
}
glow *= intensity / max(totalWeight, 0.001);
if (u_int1 > 0) {
glow *= hexToRgb(u_int1);
}
vec3 result = blend(original.rgb, glow, u_int0);
result += (noise - 0.5) * (1.0 / 255.0);
fragColor = vec4(clamp(result, 0.0, 1.0), original.a);
}

View File

@ -0,0 +1,222 @@
#version 300 es
precision highp float;
uniform sampler2D u_image0;
uniform int u_int0; // Mode: 0=Master, 1=Reds, 2=Yellows, 3=Greens, 4=Cyans, 5=Blues, 6=Magentas, 7=Colorize
uniform int u_int1; // Color Space: 0=HSL, 1=HSB/HSV
uniform float u_float0; // Hue (-180 to 180)
uniform float u_float1; // Saturation (-100 to 100)
uniform float u_float2; // Lightness/Brightness (-100 to 100)
uniform float u_float3; // Overlap (0 to 100) - feathering between adjacent color ranges
in vec2 v_texCoord;
out vec4 fragColor;
// Color range modes
const int MODE_MASTER = 0;
const int MODE_RED = 1;
const int MODE_YELLOW = 2;
const int MODE_GREEN = 3;
const int MODE_CYAN = 4;
const int MODE_BLUE = 5;
const int MODE_MAGENTA = 6;
const int MODE_COLORIZE = 7;
// Color space modes
const int COLORSPACE_HSL = 0;
const int COLORSPACE_HSB = 1;
const float EPSILON = 0.0001;
//=============================================================================
// RGB <-> HSL Conversions
//=============================================================================
vec3 rgb2hsl(vec3 c) {
float maxC = max(max(c.r, c.g), c.b);
float minC = min(min(c.r, c.g), c.b);
float delta = maxC - minC;
float h = 0.0;
float s = 0.0;
float l = (maxC + minC) * 0.5;
if (delta > EPSILON) {
s = l < 0.5
? delta / (maxC + minC)
: delta / (2.0 - maxC - minC);
if (maxC == c.r) {
h = (c.g - c.b) / delta + (c.g < c.b ? 6.0 : 0.0);
} else if (maxC == c.g) {
h = (c.b - c.r) / delta + 2.0;
} else {
h = (c.r - c.g) / delta + 4.0;
}
h /= 6.0;
}
return vec3(h, s, l);
}
float hue2rgb(float p, float q, float t) {
t = fract(t);
if (t < 1.0/6.0) return p + (q - p) * 6.0 * t;
if (t < 0.5) return q;
if (t < 2.0/3.0) return p + (q - p) * (2.0/3.0 - t) * 6.0;
return p;
}
vec3 hsl2rgb(vec3 hsl) {
if (hsl.y < EPSILON) return vec3(hsl.z);
float q = hsl.z < 0.5
? hsl.z * (1.0 + hsl.y)
: hsl.z + hsl.y - hsl.z * hsl.y;
float p = 2.0 * hsl.z - q;
return vec3(
hue2rgb(p, q, hsl.x + 1.0/3.0),
hue2rgb(p, q, hsl.x),
hue2rgb(p, q, hsl.x - 1.0/3.0)
);
}
vec3 rgb2hsb(vec3 c) {
float maxC = max(max(c.r, c.g), c.b);
float minC = min(min(c.r, c.g), c.b);
float delta = maxC - minC;
float h = 0.0;
float s = (maxC > EPSILON) ? delta / maxC : 0.0;
float b = maxC;
if (delta > EPSILON) {
if (maxC == c.r) {
h = (c.g - c.b) / delta + (c.g < c.b ? 6.0 : 0.0);
} else if (maxC == c.g) {
h = (c.b - c.r) / delta + 2.0;
} else {
h = (c.r - c.g) / delta + 4.0;
}
h /= 6.0;
}
return vec3(h, s, b);
}
vec3 hsb2rgb(vec3 hsb) {
vec3 rgb = clamp(abs(mod(hsb.x * 6.0 + vec3(0.0, 4.0, 2.0), 6.0) - 3.0) - 1.0, 0.0, 1.0);
return hsb.z * mix(vec3(1.0), rgb, hsb.y);
}
//=============================================================================
// Color Range Weight Calculation
//=============================================================================
float hueDistance(float a, float b) {
float d = abs(a - b);
return min(d, 1.0 - d);
}
float getHueWeight(float hue, float center, float overlap) {
float baseWidth = 1.0 / 6.0;
float feather = baseWidth * overlap;
float d = hueDistance(hue, center);
float inner = baseWidth * 0.5;
float outer = inner + feather;
return 1.0 - smoothstep(inner, outer, d);
}
float getModeWeight(float hue, int mode, float overlap) {
if (mode == MODE_MASTER || mode == MODE_COLORIZE) return 1.0;
if (mode == MODE_RED) {
return max(
getHueWeight(hue, 0.0, overlap),
getHueWeight(hue, 1.0, overlap)
);
}
float center = float(mode - 1) / 6.0;
return getHueWeight(hue, center, overlap);
}
//=============================================================================
// Adjustment Functions
//=============================================================================
float adjustLightness(float l, float amount) {
return amount > 0.0
? l + (1.0 - l) * amount
: l + l * amount;
}
float adjustBrightness(float b, float amount) {
return clamp(b + amount, 0.0, 1.0);
}
float adjustSaturation(float s, float amount) {
return amount > 0.0
? s + (1.0 - s) * amount
: s + s * amount;
}
vec3 colorize(vec3 rgb, float hue, float sat, float light) {
float lum = dot(rgb, vec3(0.299, 0.587, 0.114));
float l = adjustLightness(lum, light);
vec3 hsl = vec3(fract(hue), clamp(sat, 0.0, 1.0), clamp(l, 0.0, 1.0));
return hsl2rgb(hsl);
}
//=============================================================================
// Main
//=============================================================================
void main() {
vec4 original = texture(u_image0, v_texCoord);
float hueShift = u_float0 / 360.0; // -180..180 -> -0.5..0.5
float satAmount = u_float1 / 100.0; // -100..100 -> -1..1
float lightAmount= u_float2 / 100.0; // -100..100 -> -1..1
float overlap = u_float3 / 100.0; // 0..100 -> 0..1
vec3 result;
if (u_int0 == MODE_COLORIZE) {
result = colorize(original.rgb, hueShift, satAmount, lightAmount);
fragColor = vec4(result, original.a);
return;
}
vec3 hsx = (u_int1 == COLORSPACE_HSL)
? rgb2hsl(original.rgb)
: rgb2hsb(original.rgb);
float weight = getModeWeight(hsx.x, u_int0, overlap);
if (u_int0 != MODE_MASTER && hsx.y < EPSILON) {
weight = 0.0;
}
if (weight > EPSILON) {
float h = fract(hsx.x + hueShift * weight);
float s = clamp(adjustSaturation(hsx.y, satAmount * weight), 0.0, 1.0);
float v = (u_int1 == COLORSPACE_HSL)
? clamp(adjustLightness(hsx.z, lightAmount * weight), 0.0, 1.0)
: clamp(adjustBrightness(hsx.z, lightAmount * weight), 0.0, 1.0);
vec3 adjusted = vec3(h, s, v);
result = (u_int1 == COLORSPACE_HSL)
? hsl2rgb(adjusted)
: hsb2rgb(adjusted);
} else {
result = original.rgb;
}
fragColor = vec4(result, original.a);
}

View File

@ -0,0 +1,111 @@
#version 300 es
#pragma passes 2
precision highp float;
// Blur type constants
const int BLUR_GAUSSIAN = 0;
const int BLUR_BOX = 1;
const int BLUR_RADIAL = 2;
// Radial blur config
const int RADIAL_SAMPLES = 12;
const float RADIAL_STRENGTH = 0.0003;
uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform int u_int0; // Blur type (BLUR_GAUSSIAN, BLUR_BOX, BLUR_RADIAL)
uniform float u_float0; // Blur radius/amount
uniform int u_pass; // Pass index (0 = horizontal, 1 = vertical)
in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;
float gaussian(float x, float sigma) {
return exp(-(x * x) / (2.0 * sigma * sigma));
}
void main() {
vec2 texelSize = 1.0 / u_resolution;
float radius = max(u_float0, 0.0);
// Radial (angular) blur - single pass, doesn't use separable
if (u_int0 == BLUR_RADIAL) {
// Only execute on first pass
if (u_pass > 0) {
fragColor0 = texture(u_image0, v_texCoord);
return;
}
vec2 center = vec2(0.5);
vec2 dir = v_texCoord - center;
float dist = length(dir);
if (dist < 1e-4) {
fragColor0 = texture(u_image0, v_texCoord);
return;
}
vec4 sum = vec4(0.0);
float totalWeight = 0.0;
float angleStep = radius * RADIAL_STRENGTH;
dir /= dist;
float cosStep = cos(angleStep);
float sinStep = sin(angleStep);
float negAngle = -float(RADIAL_SAMPLES) * angleStep;
vec2 rotDir = vec2(
dir.x * cos(negAngle) - dir.y * sin(negAngle),
dir.x * sin(negAngle) + dir.y * cos(negAngle)
);
for (int i = -RADIAL_SAMPLES; i <= RADIAL_SAMPLES; i++) {
vec2 uv = center + rotDir * dist;
float w = 1.0 - abs(float(i)) / float(RADIAL_SAMPLES);
sum += texture(u_image0, uv) * w;
totalWeight += w;
rotDir = vec2(
rotDir.x * cosStep - rotDir.y * sinStep,
rotDir.x * sinStep + rotDir.y * cosStep
);
}
fragColor0 = sum / max(totalWeight, 0.001);
return;
}
// Separable Gaussian / Box blur
int samples = int(ceil(radius));
if (samples == 0) {
fragColor0 = texture(u_image0, v_texCoord);
return;
}
// Direction: pass 0 = horizontal, pass 1 = vertical
vec2 dir = (u_pass == 0) ? vec2(1.0, 0.0) : vec2(0.0, 1.0);
vec4 color = vec4(0.0);
float totalWeight = 0.0;
float sigma = radius / 2.0;
for (int i = -samples; i <= samples; i++) {
vec2 offset = dir * float(i) * texelSize;
vec4 sample_color = texture(u_image0, v_texCoord + offset);
float weight;
if (u_int0 == BLUR_GAUSSIAN) {
weight = gaussian(float(i), sigma);
} else {
// BLUR_BOX
weight = 1.0;
}
color += sample_color * weight;
totalWeight += weight;
}
fragColor0 = color / totalWeight;
}

View File

@ -0,0 +1,19 @@
#version 300 es
precision highp float;
uniform sampler2D u_image0;
in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;
layout(location = 1) out vec4 fragColor1;
layout(location = 2) out vec4 fragColor2;
layout(location = 3) out vec4 fragColor3;
void main() {
vec4 color = texture(u_image0, v_texCoord);
// Output each channel as grayscale to separate render targets
fragColor0 = vec4(vec3(color.r), 1.0); // Red channel
fragColor1 = vec4(vec3(color.g), 1.0); // Green channel
fragColor2 = vec4(vec3(color.b), 1.0); // Blue channel
fragColor3 = vec4(vec3(color.a), 1.0); // Alpha channel
}

View File

@ -0,0 +1,71 @@
#version 300 es
precision highp float;
// Levels Adjustment
// u_int0: channel (0=RGB, 1=R, 2=G, 3=B) default: 0
// u_float0: input black (0-255) default: 0
// u_float1: input white (0-255) default: 255
// u_float2: gamma (0.01-9.99) default: 1.0
// u_float3: output black (0-255) default: 0
// u_float4: output white (0-255) default: 255
uniform sampler2D u_image0;
uniform int u_int0;
uniform float u_float0;
uniform float u_float1;
uniform float u_float2;
uniform float u_float3;
uniform float u_float4;
in vec2 v_texCoord;
out vec4 fragColor;
vec3 applyLevels(vec3 color, float inBlack, float inWhite, float gamma, float outBlack, float outWhite) {
float inRange = max(inWhite - inBlack, 0.0001);
vec3 result = clamp((color - inBlack) / inRange, 0.0, 1.0);
result = pow(result, vec3(1.0 / gamma));
result = mix(vec3(outBlack), vec3(outWhite), result);
return result;
}
float applySingleChannel(float value, float inBlack, float inWhite, float gamma, float outBlack, float outWhite) {
float inRange = max(inWhite - inBlack, 0.0001);
float result = clamp((value - inBlack) / inRange, 0.0, 1.0);
result = pow(result, 1.0 / gamma);
result = mix(outBlack, outWhite, result);
return result;
}
void main() {
vec4 texColor = texture(u_image0, v_texCoord);
vec3 color = texColor.rgb;
float inBlack = u_float0 / 255.0;
float inWhite = u_float1 / 255.0;
float gamma = u_float2;
float outBlack = u_float3 / 255.0;
float outWhite = u_float4 / 255.0;
vec3 result;
if (u_int0 == 0) {
result = applyLevels(color, inBlack, inWhite, gamma, outBlack, outWhite);
}
else if (u_int0 == 1) {
result = color;
result.r = applySingleChannel(color.r, inBlack, inWhite, gamma, outBlack, outWhite);
}
else if (u_int0 == 2) {
result = color;
result.g = applySingleChannel(color.g, inBlack, inWhite, gamma, outBlack, outWhite);
}
else if (u_int0 == 3) {
result = color;
result.b = applySingleChannel(color.b, inBlack, inWhite, gamma, outBlack, outWhite);
}
else {
result = color;
}
fragColor = vec4(result, texColor.a);
}

View File

@ -0,0 +1,28 @@
# GLSL Shader Sources
This folder contains the GLSL fragment shaders extracted from blueprint JSON files for easier editing and version control.
## File Naming Convention
`{Blueprint_Name}_{node_id}.frag`
- **Blueprint_Name**: The JSON filename with spaces/special chars replaced by underscores
- **node_id**: The GLSLShader node ID within the subgraph
## Usage
```bash
# Extract shaders from blueprint JSONs to this folder
python update_blueprints.py extract
# Patch edited shaders back into blueprint JSONs
python update_blueprints.py patch
```
## Workflow
1. Run `extract` to pull current shaders from JSONs
2. Edit `.frag` files
3. Run `patch` to update the blueprint JSONs
4. Test
5. Commit both `.frag` files and updated JSONs

View File

@ -0,0 +1,28 @@
#version 300 es
precision highp float;
uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform float u_float0; // strength [0.0 2.0] typical: 0.31.0
in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;
void main() {
vec2 texel = 1.0 / u_resolution;
// Sample center and neighbors
vec4 center = texture(u_image0, v_texCoord);
vec4 top = texture(u_image0, v_texCoord + vec2( 0.0, -texel.y));
vec4 bottom = texture(u_image0, v_texCoord + vec2( 0.0, texel.y));
vec4 left = texture(u_image0, v_texCoord + vec2(-texel.x, 0.0));
vec4 right = texture(u_image0, v_texCoord + vec2( texel.x, 0.0));
// Edge enhancement (Laplacian)
vec4 edges = center * 4.0 - top - bottom - left - right;
// Add edges back scaled by strength
vec4 sharpened = center + edges * u_float0;
fragColor0 = vec4(clamp(sharpened.rgb, 0.0, 1.0), center.a);
}

View File

@ -0,0 +1,61 @@
#version 300 es
precision highp float;
uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform float u_float0; // amount [0.0 - 3.0] typical: 0.5-1.5
uniform float u_float1; // radius [0.5 - 10.0] blur radius in pixels
uniform float u_float2; // threshold [0.0 - 0.1] min difference to sharpen
in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;
float gaussian(float x, float sigma) {
return exp(-(x * x) / (2.0 * sigma * sigma));
}
float getLuminance(vec3 color) {
return dot(color, vec3(0.2126, 0.7152, 0.0722));
}
void main() {
vec2 texel = 1.0 / u_resolution;
float radius = max(u_float1, 0.5);
float amount = u_float0;
float threshold = u_float2;
vec4 original = texture(u_image0, v_texCoord);
// Gaussian blur for the "unsharp" mask
int samples = int(ceil(radius));
float sigma = radius / 2.0;
vec4 blurred = vec4(0.0);
float totalWeight = 0.0;
for (int x = -samples; x <= samples; x++) {
for (int y = -samples; y <= samples; y++) {
vec2 offset = vec2(float(x), float(y)) * texel;
vec4 sample_color = texture(u_image0, v_texCoord + offset);
float dist = length(vec2(float(x), float(y)));
float weight = gaussian(dist, sigma);
blurred += sample_color * weight;
totalWeight += weight;
}
}
blurred /= totalWeight;
// Unsharp mask = original - blurred
vec3 mask = original.rgb - blurred.rgb;
// Luminance-based threshold with smooth falloff
float lumaDelta = abs(getLuminance(original.rgb) - getLuminance(blurred.rgb));
float thresholdScale = smoothstep(0.0, threshold, lumaDelta);
mask *= thresholdScale;
// Sharpen: original + mask * amount
vec3 sharpened = original.rgb + mask * amount;
fragColor0 = vec4(clamp(sharpened, 0.0, 1.0), original.a);
}

View File

@ -0,0 +1,159 @@
#!/usr/bin/env python3
"""
Shader Blueprint Updater
Syncs GLSL shader files between this folder and blueprint JSON files.
File naming convention:
{Blueprint Name}_{node_id}.frag
Usage:
python update_blueprints.py extract # Extract shaders from JSONs to here
python update_blueprints.py patch # Patch shaders back into JSONs
python update_blueprints.py # Same as patch (default)
"""
import json
import logging
import sys
import re
from pathlib import Path
logging.basicConfig(level=logging.INFO, format='%(message)s')
logger = logging.getLogger(__name__)
GLSL_DIR = Path(__file__).parent
BLUEPRINTS_DIR = GLSL_DIR.parent
def get_blueprint_files():
"""Get all blueprint JSON files."""
return sorted(BLUEPRINTS_DIR.glob("*.json"))
def sanitize_filename(name):
"""Convert blueprint name to safe filename."""
return re.sub(r'[^\w\-]', '_', name)
def extract_shaders():
"""Extract all shaders from blueprint JSONs to this folder."""
extracted = 0
for json_path in get_blueprint_files():
blueprint_name = json_path.stem
try:
with open(json_path, 'r') as f:
data = json.load(f)
except (json.JSONDecodeError, IOError) as e:
logger.warning("Skipping %s: %s", json_path.name, e)
continue
# Find GLSLShader nodes in subgraphs
for subgraph in data.get('definitions', {}).get('subgraphs', []):
for node in subgraph.get('nodes', []):
if node.get('type') == 'GLSLShader':
node_id = node.get('id')
widgets = node.get('widgets_values', [])
# Find shader code (first string that looks like GLSL)
for widget in widgets:
if isinstance(widget, str) and widget.startswith('#version'):
safe_name = sanitize_filename(blueprint_name)
frag_name = f"{safe_name}_{node_id}.frag"
frag_path = GLSL_DIR / frag_name
with open(frag_path, 'w') as f:
f.write(widget)
logger.info(" Extracted: %s", frag_name)
extracted += 1
break
logger.info("\nExtracted %d shader(s)", extracted)
def patch_shaders():
"""Patch shaders from this folder back into blueprint JSONs."""
# Build lookup: blueprint_name -> [(node_id, shader_code), ...]
shader_updates = {}
for frag_path in sorted(GLSL_DIR.glob("*.frag")):
# Parse filename: {blueprint_name}_{node_id}.frag
parts = frag_path.stem.rsplit('_', 1)
if len(parts) != 2:
logger.warning("Skipping %s: invalid filename format", frag_path.name)
continue
blueprint_name, node_id_str = parts
try:
node_id = int(node_id_str)
except ValueError:
logger.warning("Skipping %s: invalid node_id", frag_path.name)
continue
with open(frag_path, 'r') as f:
shader_code = f.read()
if blueprint_name not in shader_updates:
shader_updates[blueprint_name] = []
shader_updates[blueprint_name].append((node_id, shader_code))
# Apply updates to JSON files
patched = 0
for json_path in get_blueprint_files():
blueprint_name = sanitize_filename(json_path.stem)
if blueprint_name not in shader_updates:
continue
try:
with open(json_path, 'r') as f:
data = json.load(f)
except (json.JSONDecodeError, IOError) as e:
logger.error("Error reading %s: %s", json_path.name, e)
continue
modified = False
for node_id, shader_code in shader_updates[blueprint_name]:
# Find the node and update
for subgraph in data.get('definitions', {}).get('subgraphs', []):
for node in subgraph.get('nodes', []):
if node.get('id') == node_id and node.get('type') == 'GLSLShader':
widgets = node.get('widgets_values', [])
if len(widgets) > 0 and widgets[0] != shader_code:
widgets[0] = shader_code
modified = True
logger.info(" Patched: %s (node %d)", json_path.name, node_id)
patched += 1
if modified:
with open(json_path, 'w') as f:
json.dump(data, f)
if patched == 0:
logger.info("No changes to apply.")
else:
logger.info("\nPatched %d shader(s)", patched)
def main():
if len(sys.argv) < 2:
command = "patch"
else:
command = sys.argv[1].lower()
if command == "extract":
logger.info("Extracting shaders from blueprints...")
extract_shaders()
elif command in ("patch", "update", "apply"):
logger.info("Patching shaders into blueprints...")
patch_shaders()
else:
logger.info(__doc__)
sys.exit(1)
if __name__ == "__main__":
main()

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

1
blueprints/Glow.json Normal file

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1 @@
{"revision": 0, "last_node_id": 29, "last_link_id": 0, "nodes": [{"id": 29, "type": "4c9d6ea4-b912-40e5-8766-6793a9758c53", "pos": [1970, -230], "size": [180, 86], "flags": {}, "order": 5, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": null}], "outputs": [{"label": "R", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": []}, {"label": "G", "localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": []}, {"label": "B", "localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": []}, {"label": "A", "localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": []}], "title": "Image Channels", "properties": {"proxyWidgets": []}, "widgets_values": []}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "4c9d6ea4-b912-40e5-8766-6793a9758c53", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 28, "lastLinkId": 39, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Image Channels", "inputNode": {"id": -10, "bounding": [1820, -185, 120, 60]}, "outputNode": {"id": -20, "bounding": [2460, -215, 120, 120]}, "inputs": [{"id": "3522932b-2d86-4a1f-a02a-cb29f3a9d7fe", "name": "images.image0", "type": "IMAGE", "linkIds": [39], "localized_name": "images.image0", "label": "image", "pos": [1920, -165]}], "outputs": [{"id": "605cb9c3-b065-4d9b-81d2-3ec331889b2b", "name": "IMAGE0", "type": "IMAGE", "linkIds": [26], "localized_name": "IMAGE0", "label": "R", "pos": [2480, -195]}, {"id": "fb44a77e-0522-43e9-9527-82e7465b3596", "name": "IMAGE1", "type": "IMAGE", "linkIds": [27], "localized_name": "IMAGE1", "label": "G", "pos": [2480, -175]}, {"id": "81460ee6-0131-402a-874f-6bf3001fc4ff", "name": "IMAGE2", "type": "IMAGE", "linkIds": [28], "localized_name": "IMAGE2", "label": "B", "pos": [2480, -155]}, {"id": "ae690246-80d4-4951-b1d9-9306d8a77417", "name": "IMAGE3", "type": "IMAGE", "linkIds": [29], "localized_name": "IMAGE3", "label": "A", "pos": [2480, -135]}], "widgets": [], "nodes": [{"id": 23, "type": "GLSLShader", "pos": [2000, -330], "size": [400, 172], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": 39}, {"localized_name": "fragment_shader", "name": "fragment_shader", "type": "STRING", "widget": {"name": "fragment_shader"}, "link": null}, {"localized_name": "size_mode", "name": "size_mode", "type": "COMFY_DYNAMICCOMBO_V3", "widget": {"name": "size_mode"}, "link": null}, {"label": "image1", "localized_name": "images.image1", "name": "images.image1", "shape": 7, "type": "IMAGE", "link": null}], "outputs": [{"label": "R", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": [26]}, {"label": "G", "localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": [27]}, {"label": "B", "localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": [28]}, {"label": "A", "localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": [29]}], "properties": {"Node name for S&R": "GLSLShader"}, "widgets_values": ["#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\nlayout(location = 1) out vec4 fragColor1;\nlayout(location = 2) out vec4 fragColor2;\nlayout(location = 3) out vec4 fragColor3;\n\nvoid main() {\n vec4 color = texture(u_image0, v_texCoord);\n // Output each channel as grayscale to separate render targets\n fragColor0 = vec4(vec3(color.r), 1.0); // Red channel\n fragColor1 = vec4(vec3(color.g), 1.0); // Green channel\n fragColor2 = vec4(vec3(color.b), 1.0); // Blue channel\n fragColor3 = vec4(vec3(color.a), 1.0); // Alpha channel\n}\n", "from_input"]}], "groups": [], "links": [{"id": 39, "origin_id": -10, "origin_slot": 0, "target_id": 23, "target_slot": 0, "type": "IMAGE"}, {"id": 26, "origin_id": 23, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}, {"id": 27, "origin_id": 23, "origin_slot": 1, "target_id": -20, "target_slot": 1, "type": "IMAGE"}, {"id": 28, "origin_id": 23, "origin_slot": 2, "target_id": -20, "target_slot": 2, "type": "IMAGE"}, {"id": 29, "origin_id": 23, "origin_slot": 3, "target_id": -20, "target_slot": 3, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Image Tools/Color adjust"}]}}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1 @@
{"revision": 0, "last_node_id": 15, "last_link_id": 0, "nodes": [{"id": 15, "type": "24d8bbfd-39d4-4774-bff0-3de40cc7a471", "pos": [-1490, 2040], "size": [400, 260], "flags": {}, "order": 0, "mode": 0, "inputs": [{"name": "prompt", "type": "STRING", "widget": {"name": "prompt"}, "link": null}, {"label": "reference images", "name": "images", "type": "IMAGE", "link": null}], "outputs": [{"name": "STRING", "type": "STRING", "links": null}], "title": "Prompt Enhance", "properties": {"proxyWidgets": [["-1", "prompt"]], "cnr_id": "comfy-core", "ver": "0.14.1"}, "widgets_values": [""]}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "24d8bbfd-39d4-4774-bff0-3de40cc7a471", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 15, "lastLinkId": 14, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Prompt Enhance", "inputNode": {"id": -10, "bounding": [-2170, 2110, 138.876953125, 80]}, "outputNode": {"id": -20, "bounding": [-640, 2110, 120, 60]}, "inputs": [{"id": "aeab7216-00e0-4528-a09b-bba50845c5a6", "name": "prompt", "type": "STRING", "linkIds": [11], "pos": [-2051.123046875, 2130]}, {"id": "7b73fd36-aa31-4771-9066-f6c83879994b", "name": "images", "type": "IMAGE", "linkIds": [14], "label": "reference images", "pos": [-2051.123046875, 2150]}], "outputs": [{"id": "c7b0d930-68a1-48d1-b496-0519e5837064", "name": "STRING", "type": "STRING", "linkIds": [13], "pos": [-620, 2130]}], "widgets": [], "nodes": [{"id": 11, "type": "GeminiNode", "pos": [-1560, 1990], "size": [470, 470], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "images", "name": "images", "shape": 7, "type": "IMAGE", "link": 14}, {"localized_name": "audio", "name": "audio", "shape": 7, "type": "AUDIO", "link": null}, {"localized_name": "video", "name": "video", "shape": 7, "type": "VIDEO", "link": null}, {"localized_name": "files", "name": "files", "shape": 7, "type": "GEMINI_INPUT_FILES", "link": null}, {"localized_name": "prompt", "name": "prompt", "type": "STRING", "widget": {"name": "prompt"}, "link": 11}, {"localized_name": "model", "name": "model", "type": "COMBO", "widget": {"name": "model"}, "link": null}, {"localized_name": "seed", "name": "seed", "type": "INT", "widget": {"name": "seed"}, "link": null}, {"localized_name": "system_prompt", "name": "system_prompt", "shape": 7, "type": "STRING", "widget": {"name": "system_prompt"}, "link": null}], "outputs": [{"localized_name": "STRING", "name": "STRING", "type": "STRING", "links": [13]}], "properties": {"cnr_id": "comfy-core", "ver": "0.14.1", "Node name for S&R": "GeminiNode"}, "widgets_values": ["", "gemini-3-pro-preview", 42, "randomize", "You are an expert in prompt writing.\nBased on the input, rewrite the user's input into a detailed prompt.\nincluding camera settings, lighting, composition, and style.\nReturn the prompt only"], "color": "#432", "bgcolor": "#653"}], "groups": [], "links": [{"id": 11, "origin_id": -10, "origin_slot": 0, "target_id": 11, "target_slot": 4, "type": "STRING"}, {"id": 13, "origin_id": 11, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "STRING"}, {"id": 14, "origin_id": -10, "origin_slot": 1, "target_id": 11, "target_slot": 0, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Text generation/Prompt enhance"}]}, "extra": {}}

1
blueprints/Sharpen.json Normal file
View File

@ -0,0 +1 @@
{"revision": 0, "last_node_id": 25, "last_link_id": 0, "nodes": [{"id": 25, "type": "621ba4e2-22a8-482d-a369-023753198b7b", "pos": [4610, -790], "size": [230, 58], "flags": {}, "order": 4, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": null}], "outputs": [{"label": "IMAGE", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": []}], "title": "Sharpen", "properties": {"proxyWidgets": [["24", "value"]]}, "widgets_values": []}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "621ba4e2-22a8-482d-a369-023753198b7b", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 24, "lastLinkId": 36, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Sharpen", "inputNode": {"id": -10, "bounding": [4090, -825, 120, 60]}, "outputNode": {"id": -20, "bounding": [5150, -825, 120, 60]}, "inputs": [{"id": "37011fb7-14b7-4e0e-b1a0-6a02e8da1fd7", "name": "images.image0", "type": "IMAGE", "linkIds": [34], "localized_name": "images.image0", "label": "image", "pos": [4190, -805]}], "outputs": [{"id": "e9182b3f-635c-4cd4-a152-4b4be17ae4b9", "name": "IMAGE0", "type": "IMAGE", "linkIds": [35], "localized_name": "IMAGE0", "label": "IMAGE", "pos": [5170, -805]}], "widgets": [], "nodes": [{"id": 24, "type": "PrimitiveFloat", "pos": [4280, -1240], "size": [270, 58], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "strength", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [36]}], "properties": {"Node name for S&R": "PrimitiveFloat", "min": 0, "max": 3, "precision": 2, "step": 0.05}, "widgets_values": [0.5]}, {"id": 23, "type": "GLSLShader", "pos": [4570, -1240], "size": [370, 192], "flags": {}, "order": 1, "mode": 0, "inputs": [{"label": "image0", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": 34}, {"label": "image1", "localized_name": "images.image1", "name": "images.image1", "shape": 7, "type": "IMAGE", "link": null}, {"label": "u_float0", "localized_name": "floats.u_float0", "name": "floats.u_float0", "shape": 7, "type": "FLOAT", "link": 36}, {"label": "u_float1", "localized_name": "floats.u_float1", "name": "floats.u_float1", "shape": 7, "type": "FLOAT", "link": null}, {"label": "u_int0", "localized_name": "ints.u_int0", "name": "ints.u_int0", "shape": 7, "type": "INT", "link": null}, {"localized_name": "fragment_shader", "name": "fragment_shader", "type": "STRING", "widget": {"name": "fragment_shader"}, "link": null}, {"localized_name": "size_mode", "name": "size_mode", "type": "COMFY_DYNAMICCOMBO_V3", "widget": {"name": "size_mode"}, "link": null}], "outputs": [{"localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": [35]}, {"localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": null}], "properties": {"Node name for S&R": "GLSLShader"}, "widgets_values": ["#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform vec2 u_resolution;\nuniform float u_float0; // strength [0.0 2.0] typical: 0.31.0\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nvoid main() {\n vec2 texel = 1.0 / u_resolution;\n \n // Sample center and neighbors\n vec4 center = texture(u_image0, v_texCoord);\n vec4 top = texture(u_image0, v_texCoord + vec2( 0.0, -texel.y));\n vec4 bottom = texture(u_image0, v_texCoord + vec2( 0.0, texel.y));\n vec4 left = texture(u_image0, v_texCoord + vec2(-texel.x, 0.0));\n vec4 right = texture(u_image0, v_texCoord + vec2( texel.x, 0.0));\n \n // Edge enhancement (Laplacian)\n vec4 edges = center * 4.0 - top - bottom - left - right;\n \n // Add edges back scaled by strength\n vec4 sharpened = center + edges * u_float0;\n \n fragColor0 = vec4(clamp(sharpened.rgb, 0.0, 1.0), center.a);\n}", "from_input"]}], "groups": [], "links": [{"id": 36, "origin_id": 24, "origin_slot": 0, "target_id": 23, "target_slot": 2, "type": "FLOAT"}, {"id": 34, "origin_id": -10, "origin_slot": 0, "target_id": 23, "target_slot": 0, "type": "IMAGE"}, {"id": 35, "origin_id": 23, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Image Tools/Sharpen"}]}}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1 @@
{"revision": 0, "last_node_id": 13, "last_link_id": 0, "nodes": [{"id": 13, "type": "cf95b747-3e17-46cb-8097-cac60ff9b2e1", "pos": [1120, 330], "size": [240, 58], "flags": {}, "order": 3, "mode": 0, "inputs": [{"localized_name": "video", "name": "video", "type": "VIDEO", "link": null}, {"name": "model_name", "type": "COMBO", "widget": {"name": "model_name"}, "link": null}], "outputs": [{"localized_name": "VIDEO", "name": "VIDEO", "type": "VIDEO", "links": []}], "title": "Video Upscale(GAN x4)", "properties": {"proxyWidgets": [["-1", "model_name"]], "cnr_id": "comfy-core", "ver": "0.14.1"}, "widgets_values": ["RealESRGAN_x4plus.safetensors"]}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "cf95b747-3e17-46cb-8097-cac60ff9b2e1", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 13, "lastLinkId": 19, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Video Upscale(GAN x4)", "inputNode": {"id": -10, "bounding": [550, 460, 120, 80]}, "outputNode": {"id": -20, "bounding": [1490, 460, 120, 60]}, "inputs": [{"id": "666d633e-93e7-42dc-8d11-2b7b99b0f2a6", "name": "video", "type": "VIDEO", "linkIds": [10], "localized_name": "video", "pos": [650, 480]}, {"id": "2e23a087-caa8-4d65-99e6-662761aa905a", "name": "model_name", "type": "COMBO", "linkIds": [19], "pos": [650, 500]}], "outputs": [{"id": "0c1768ea-3ec2-412f-9af6-8e0fa36dae70", "name": "VIDEO", "type": "VIDEO", "linkIds": [15], "localized_name": "VIDEO", "pos": [1510, 480]}], "widgets": [], "nodes": [{"id": 2, "type": "ImageUpscaleWithModel", "pos": [1110, 450], "size": [320, 46], "flags": {}, "order": 1, "mode": 0, "inputs": [{"localized_name": "upscale_model", "name": "upscale_model", "type": "UPSCALE_MODEL", "link": 1}, {"localized_name": "image", "name": "image", "type": "IMAGE", "link": 14}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": [13]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "ImageUpscaleWithModel"}}, {"id": 11, "type": "CreateVideo", "pos": [1110, 550], "size": [320, 78], "flags": {}, "order": 3, "mode": 0, "inputs": [{"localized_name": "images", "name": "images", "type": "IMAGE", "link": 13}, {"localized_name": "audio", "name": "audio", "shape": 7, "type": "AUDIO", "link": 16}, {"localized_name": "fps", "name": "fps", "type": "FLOAT", "widget": {"name": "fps"}, "link": 12}], "outputs": [{"localized_name": "VIDEO", "name": "VIDEO", "type": "VIDEO", "links": [15]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "CreateVideo"}, "widgets_values": [30]}, {"id": 10, "type": "GetVideoComponents", "pos": [1110, 330], "size": [320, 70], "flags": {}, "order": 2, "mode": 0, "inputs": [{"localized_name": "video", "name": "video", "type": "VIDEO", "link": 10}], "outputs": [{"localized_name": "images", "name": "images", "type": "IMAGE", "links": [14]}, {"localized_name": "audio", "name": "audio", "type": "AUDIO", "links": [16]}, {"localized_name": "fps", "name": "fps", "type": "FLOAT", "links": [12]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "GetVideoComponents"}}, {"id": 1, "type": "UpscaleModelLoader", "pos": [750, 450], "size": [280, 60], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "model_name", "name": "model_name", "type": "COMBO", "widget": {"name": "model_name"}, "link": 19}], "outputs": [{"localized_name": "UPSCALE_MODEL", "name": "UPSCALE_MODEL", "type": "UPSCALE_MODEL", "links": [1]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "UpscaleModelLoader", "models": [{"name": "RealESRGAN_x4plus.safetensors", "url": "https://huggingface.co/Comfy-Org/Real-ESRGAN_repackaged/resolve/main/RealESRGAN_x4plus.safetensors", "directory": "upscale_models"}]}, "widgets_values": ["RealESRGAN_x4plus.safetensors"]}], "groups": [], "links": [{"id": 1, "origin_id": 1, "origin_slot": 0, "target_id": 2, "target_slot": 0, "type": "UPSCALE_MODEL"}, {"id": 14, "origin_id": 10, "origin_slot": 0, "target_id": 2, "target_slot": 1, "type": "IMAGE"}, {"id": 13, "origin_id": 2, "origin_slot": 0, "target_id": 11, "target_slot": 0, "type": "IMAGE"}, {"id": 16, "origin_id": 10, "origin_slot": 1, "target_id": 11, "target_slot": 1, "type": "AUDIO"}, {"id": 12, "origin_id": 10, "origin_slot": 2, "target_id": 11, "target_slot": 2, "type": "FLOAT"}, {"id": 10, "origin_id": -10, "origin_slot": 0, "target_id": 10, "target_slot": 0, "type": "VIDEO"}, {"id": 15, "origin_id": 11, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "VIDEO"}, {"id": 19, "origin_id": -10, "origin_slot": 1, "target_id": 1, "target_slot": 0, "type": "COMBO"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Video generation and editing/Enhance video"}]}, "extra": {}}

View File

@ -1,13 +0,0 @@
import pickle
load = pickle.load
class Empty:
pass
class Unpickler(pickle.Unpickler):
def find_class(self, module, name):
#TODO: safe unpickle
if module.startswith("pytorch_lightning"):
return Empty
return super().find_class(module, name)

View File

@ -176,6 +176,8 @@ class InputTypeOptions(TypedDict):
"""COMBO type only. Specifies the configuration for a multi-select widget. """COMBO type only. Specifies the configuration for a multi-select widget.
Available after ComfyUI frontend v1.13.4 Available after ComfyUI frontend v1.13.4
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987""" https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987"""
gradient_stops: NotRequired[list[list[float]]]
"""Gradient color stops for gradientslider display mode. Each stop is [offset, r, g, b] (``FLOAT``)."""
class HiddenInputTypeDict(TypedDict): class HiddenInputTypeDict(TypedDict):

View File

@ -297,6 +297,30 @@ class ControlNet(ControlBase):
self.model_sampling_current = None self.model_sampling_current = None
super().cleanup() super().cleanup()
class QwenFunControlNet(ControlNet):
def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
# Fun checkpoints are more sensitive to high strengths in the generic
# ControlNet merge path. Use a soft response curve so strength=1.0 stays
# unchanged while >1 grows more gently.
original_strength = self.strength
self.strength = math.sqrt(max(self.strength, 0.0))
try:
return super().get_control(x_noisy, t, cond, batched_number, transformer_options)
finally:
self.strength = original_strength
def pre_run(self, model, percent_to_timestep_function):
super().pre_run(model, percent_to_timestep_function)
self.set_extra_arg("base_model", model.diffusion_model)
def copy(self):
c = QwenFunControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
c.control_model = self.control_model
c.control_model_wrapped = self.control_model_wrapped
self.copy_to(c)
return c
class ControlLoraOps: class ControlLoraOps:
class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp): class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
def __init__(self, in_features: int, out_features: int, bias: bool = True, def __init__(self, in_features: int, out_features: int, bias: bool = True,
@ -560,6 +584,7 @@ def load_controlnet_hunyuandit(controlnet_data, model_options={}):
def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}): def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}):
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options) model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config) control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
sd = model_config.process_unet_state_dict(sd)
control_model = controlnet_load_state_dict(control_model, sd) control_model = controlnet_load_state_dict(control_model, sd)
extra_conds = ['y', 'guidance'] extra_conds = ['y', 'guidance']
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds) control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
@ -605,6 +630,53 @@ def load_controlnet_qwen_instantx(sd, model_options={}):
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds) control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
return control return control
def load_controlnet_qwen_fun(sd, model_options={}):
load_device = comfy.model_management.get_torch_device()
weight_dtype = comfy.utils.weight_dtype(sd)
unet_dtype = model_options.get("dtype", weight_dtype)
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
operations = model_options.get("custom_operations", None)
if operations is None:
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
in_features = sd["control_img_in.weight"].shape[1]
inner_dim = sd["control_img_in.weight"].shape[0]
block_weight = sd["control_blocks.0.attn.to_q.weight"]
attention_head_dim = sd["control_blocks.0.attn.norm_q.weight"].shape[0]
num_attention_heads = max(1, block_weight.shape[0] // max(1, attention_head_dim))
model = comfy.ldm.qwen_image.controlnet.QwenImageFunControlNetModel(
control_in_features=in_features,
inner_dim=inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
num_control_blocks=5,
main_model_double=60,
injection_layers=(0, 12, 24, 36, 48),
operations=operations,
device=comfy.model_management.unet_offload_device(),
dtype=unet_dtype,
)
model = controlnet_load_state_dict(model, sd)
latent_format = comfy.latent_formats.Wan21()
control = QwenFunControlNet(
model,
compression_ratio=1,
latent_format=latent_format,
# Fun checkpoints already expect their own 33-channel context handling.
# Enabling generic concat_mask injects an extra mask channel at apply-time
# and breaks the intended fallback packing path.
concat_mask=False,
load_device=load_device,
manual_cast_dtype=manual_cast_dtype,
extra_conds=[],
)
return control
def convert_mistoline(sd): def convert_mistoline(sd):
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."}) return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
@ -682,6 +754,8 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
return load_controlnet_qwen_instantx(controlnet_data, model_options=model_options) return load_controlnet_qwen_instantx(controlnet_data, model_options=model_options)
elif "controlnet_x_embedder.weight" in controlnet_data: elif "controlnet_x_embedder.weight" in controlnet_data:
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options) return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
elif "control_blocks.0.after_proj.weight" in controlnet_data and "control_img_in.weight" in controlnet_data:
return load_controlnet_qwen_fun(controlnet_data, model_options=model_options)
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options) return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)

View File

@ -179,8 +179,8 @@ class LLMAdapter(nn.Module):
if source_attention_mask.ndim == 2: if source_attention_mask.ndim == 2:
source_attention_mask = source_attention_mask.unsqueeze(1).unsqueeze(1) source_attention_mask = source_attention_mask.unsqueeze(1).unsqueeze(1)
x = self.in_proj(self.embed(target_input_ids))
context = source_hidden_states context = source_hidden_states
x = self.in_proj(self.embed(target_input_ids, out_dtype=context.dtype))
position_ids = torch.arange(x.shape[1], device=x.device).unsqueeze(0) position_ids = torch.arange(x.shape[1], device=x.device).unsqueeze(0)
position_ids_context = torch.arange(context.shape[1], device=x.device).unsqueeze(0) position_ids_context = torch.arange(context.shape[1], device=x.device).unsqueeze(0)
position_embeddings = self.rotary_emb(x, position_ids) position_embeddings = self.rotary_emb(x, position_ids)

View File

@ -3,7 +3,6 @@ from torch import Tensor, nn
from comfy.ldm.flux.layers import ( from comfy.ldm.flux.layers import (
MLPEmbedder, MLPEmbedder,
RMSNorm,
ModulationOut, ModulationOut,
) )
@ -29,7 +28,7 @@ class Approximator(nn.Module):
super().__init__() super().__init__()
self.in_proj = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device) self.in_proj = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)]) self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
self.norms = nn.ModuleList([RMSNorm(hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)]) self.norms = nn.ModuleList([operations.RMSNorm(hidden_dim, dtype=dtype, device=device) for x in range( n_layers)])
self.out_proj = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device) self.out_proj = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)
@property @property

View File

@ -152,6 +152,7 @@ class Chroma(nn.Module):
transformer_options={}, transformer_options={},
attn_mask: Tensor = None, attn_mask: Tensor = None,
) -> Tensor: ) -> Tensor:
transformer_options = transformer_options.copy()
patches_replace = transformer_options.get("patches_replace", {}) patches_replace = transformer_options.get("patches_replace", {})
# running on sequences img # running on sequences img
@ -228,6 +229,7 @@ class Chroma(nn.Module):
transformer_options["total_blocks"] = len(self.single_blocks) transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single" transformer_options["block_type"] = "single"
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
for i, block in enumerate(self.single_blocks): for i, block in enumerate(self.single_blocks):
transformer_options["block_index"] = i transformer_options["block_index"] = i
if i not in self.skip_dit: if i not in self.skip_dit:

View File

@ -4,8 +4,6 @@ from functools import lru_cache
import torch import torch
from torch import nn from torch import nn
from comfy.ldm.flux.layers import RMSNorm
class NerfEmbedder(nn.Module): class NerfEmbedder(nn.Module):
""" """
@ -145,7 +143,7 @@ class NerfGLUBlock(nn.Module):
# We now need to generate parameters for 3 matrices. # We now need to generate parameters for 3 matrices.
total_params = 3 * hidden_size_x**2 * mlp_ratio total_params = 3 * hidden_size_x**2 * mlp_ratio
self.param_generator = operations.Linear(hidden_size_s, total_params, dtype=dtype, device=device) self.param_generator = operations.Linear(hidden_size_s, total_params, dtype=dtype, device=device)
self.norm = RMSNorm(hidden_size_x, dtype=dtype, device=device, operations=operations) self.norm = operations.RMSNorm(hidden_size_x, dtype=dtype, device=device)
self.mlp_ratio = mlp_ratio self.mlp_ratio = mlp_ratio
@ -178,7 +176,7 @@ class NerfGLUBlock(nn.Module):
class NerfFinalLayer(nn.Module): class NerfFinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None): def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None):
super().__init__() super().__init__()
self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations) self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device)
self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device) self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
@ -190,7 +188,7 @@ class NerfFinalLayer(nn.Module):
class NerfFinalLayerConv(nn.Module): class NerfFinalLayerConv(nn.Module):
def __init__(self, hidden_size: int, out_channels: int, dtype=None, device=None, operations=None): def __init__(self, hidden_size: int, out_channels: int, dtype=None, device=None, operations=None):
super().__init__() super().__init__()
self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations) self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device)
self.conv = operations.Conv2d( self.conv = operations.Conv2d(
in_channels=hidden_size, in_channels=hidden_size,
out_channels=out_channels, out_channels=out_channels,

View File

@ -5,9 +5,9 @@ import torch
from torch import Tensor, nn from torch import Tensor, nn
from .math import attention, rope from .math import attention, rope
import comfy.ops
import comfy.ldm.common_dit
# Fix import for some custom nodes, TODO: delete eventually.
RMSNorm = None
class EmbedND(nn.Module): class EmbedND(nn.Module):
def __init__(self, dim: int, theta: int, axes_dim: list): def __init__(self, dim: int, theta: int, axes_dim: list):
@ -87,20 +87,12 @@ def build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=False, yak_mlp=False, dt
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
) )
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, dtype=None, device=None, operations=None):
super().__init__()
self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
def forward(self, x: Tensor):
return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6)
class QKNorm(torch.nn.Module): class QKNorm(torch.nn.Module):
def __init__(self, dim: int, dtype=None, device=None, operations=None): def __init__(self, dim: int, dtype=None, device=None, operations=None):
super().__init__() super().__init__()
self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations) self.query_norm = operations.RMSNorm(dim, dtype=dtype, device=device)
self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations) self.key_norm = operations.RMSNorm(dim, dtype=dtype, device=device)
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple: def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
q = self.query_norm(q) q = self.query_norm(q)
@ -169,7 +161,7 @@ class SiLUActivation(nn.Module):
class DoubleStreamBlock(nn.Module): class DoubleStreamBlock(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None): def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
super().__init__() super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio) mlp_hidden_dim = int(hidden_size * mlp_ratio)
@ -197,8 +189,6 @@ class DoubleStreamBlock(nn.Module):
self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations) self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
self.flipped_img_txt = flipped_img_txt
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}): def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
if self.modulation: if self.modulation:
img_mod1, img_mod2 = self.img_mod(vec) img_mod1, img_mod2 = self.img_mod(vec)
@ -206,6 +196,9 @@ class DoubleStreamBlock(nn.Module):
else: else:
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
transformer_patches = transformer_options.get("patches", {})
extra_options = transformer_options.copy()
# prepare image for attention # prepare image for attention
img_modulated = self.img_norm1(img) img_modulated = self.img_norm1(img)
img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img) img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
@ -224,32 +217,23 @@ class DoubleStreamBlock(nn.Module):
del txt_qkv del txt_qkv
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
if self.flipped_img_txt: q = torch.cat((txt_q, img_q), dim=2)
q = torch.cat((img_q, txt_q), dim=2) del txt_q, img_q
del img_q, txt_q k = torch.cat((txt_k, img_k), dim=2)
k = torch.cat((img_k, txt_k), dim=2) del txt_k, img_k
del img_k, txt_k v = torch.cat((txt_v, img_v), dim=2)
v = torch.cat((img_v, txt_v), dim=2) del txt_v, img_v
del img_v, txt_v # run actual attention
# run actual attention attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
attn = attention(q, k, v, del q, k, v
pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:] if "attn1_output_patch" in transformer_patches:
else: extra_options["img_slice"] = [txt.shape[1], attn.shape[1]]
q = torch.cat((txt_q, img_q), dim=2) patch = transformer_patches["attn1_output_patch"]
del txt_q, img_q for p in patch:
k = torch.cat((txt_k, img_k), dim=2) attn = p(attn, extra_options)
del txt_k, img_k
v = torch.cat((txt_v, img_v), dim=2)
del txt_v, img_v
# run actual attention
attn = attention(q, k, v,
pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:] txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
# calculate the img bloks # calculate the img bloks
img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img) img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
@ -328,6 +312,9 @@ class SingleStreamBlock(nn.Module):
else: else:
mod = vec mod = vec
transformer_patches = transformer_options.get("patches", {})
extra_options = transformer_options.copy()
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim_first], dim=-1) qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim_first], dim=-1)
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
@ -337,6 +324,12 @@ class SingleStreamBlock(nn.Module):
# compute attention # compute attention
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options) attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v del q, k, v
if "attn1_output_patch" in transformer_patches:
patch = transformer_patches["attn1_output_patch"]
for p in patch:
attn = p(attn, extra_options)
# compute activation in mlp stream, cat again and run second linear layer # compute activation in mlp stream, cat again and run second linear layer
if self.yak_mlp: if self.yak_mlp:
mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2] mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2]

View File

@ -16,7 +16,6 @@ from .layers import (
SingleStreamBlock, SingleStreamBlock,
timestep_embedding, timestep_embedding,
Modulation, Modulation,
RMSNorm
) )
@dataclass @dataclass
@ -81,7 +80,7 @@ class Flux(nn.Module):
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device) self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
if params.txt_norm: if params.txt_norm:
self.txt_norm = RMSNorm(params.context_in_dim, dtype=dtype, device=device, operations=operations) self.txt_norm = operations.RMSNorm(params.context_in_dim, dtype=dtype, device=device)
else: else:
self.txt_norm = None self.txt_norm = None
@ -143,6 +142,7 @@ class Flux(nn.Module):
attn_mask: Tensor = None, attn_mask: Tensor = None,
) -> Tensor: ) -> Tensor:
transformer_options = transformer_options.copy()
patches = transformer_options.get("patches", {}) patches = transformer_options.get("patches", {})
patches_replace = transformer_options.get("patches_replace", {}) patches_replace = transformer_options.get("patches_replace", {})
if img.ndim != 3 or txt.ndim != 3: if img.ndim != 3 or txt.ndim != 3:
@ -232,6 +232,7 @@ class Flux(nn.Module):
transformer_options["total_blocks"] = len(self.single_blocks) transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single" transformer_options["block_type"] = "single"
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
for i, block in enumerate(self.single_blocks): for i, block in enumerate(self.single_blocks):
transformer_options["block_index"] = i transformer_options["block_index"] = i
if ("single_block", i) in blocks_replace: if ("single_block", i) in blocks_replace:

View File

@ -241,7 +241,6 @@ class HunyuanVideo(nn.Module):
self.num_heads, self.num_heads,
mlp_ratio=params.mlp_ratio, mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias, qkv_bias=params.qkv_bias,
flipped_img_txt=True,
dtype=dtype, device=device, operations=operations dtype=dtype, device=device, operations=operations
) )
for _ in range(params.depth) for _ in range(params.depth)
@ -305,6 +304,7 @@ class HunyuanVideo(nn.Module):
control=None, control=None,
transformer_options={}, transformer_options={},
) -> Tensor: ) -> Tensor:
transformer_options = transformer_options.copy()
patches_replace = transformer_options.get("patches_replace", {}) patches_replace = transformer_options.get("patches_replace", {})
initial_shape = list(img.shape) initial_shape = list(img.shape)
@ -378,14 +378,14 @@ class HunyuanVideo(nn.Module):
extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype) extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
txt_ids = torch.cat((txt_ids, extra_txt_ids), dim=1) txt_ids = torch.cat((txt_ids, extra_txt_ids), dim=1)
ids = torch.cat((img_ids, txt_ids), dim=1) ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids) pe = self.pe_embedder(ids)
img_len = img.shape[1] img_len = img.shape[1]
if txt_mask is not None: if txt_mask is not None:
attn_mask_len = img_len + txt.shape[1] attn_mask_len = img_len + txt.shape[1]
attn_mask = torch.zeros((1, 1, attn_mask_len), dtype=img.dtype, device=img.device) attn_mask = torch.zeros((1, 1, attn_mask_len), dtype=img.dtype, device=img.device)
attn_mask[:, 0, img_len:] = txt_mask attn_mask[:, 0, :txt.shape[1]] = txt_mask
else: else:
attn_mask = None attn_mask = None
@ -413,10 +413,11 @@ class HunyuanVideo(nn.Module):
if add is not None: if add is not None:
img += add img += add
img = torch.cat((img, txt), 1) img = torch.cat((txt, img), 1)
transformer_options["total_blocks"] = len(self.single_blocks) transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single" transformer_options["block_type"] = "single"
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
for i, block in enumerate(self.single_blocks): for i, block in enumerate(self.single_blocks):
transformer_options["block_index"] = i transformer_options["block_index"] = i
if ("single_block", i) in blocks_replace: if ("single_block", i) in blocks_replace:
@ -435,9 +436,9 @@ class HunyuanVideo(nn.Module):
if i < len(control_o): if i < len(control_o):
add = control_o[i] add = control_o[i]
if add is not None: if add is not None:
img[:, : img_len] += add img[:, txt.shape[1]: img_len + txt.shape[1]] += add
img = img[:, : img_len] img = img[:, txt.shape[1]: img_len + txt.shape[1]]
if ref_latent is not None: if ref_latent is not None:
img = img[:, ref_latent.shape[1]:] img = img[:, ref_latent.shape[1]:]

View File

@ -9,6 +9,7 @@ from comfy.ldm.lightricks.model import (
LTXVModel, LTXVModel,
) )
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
import comfy.ldm.common_dit import comfy.ldm.common_dit
class CompressedTimestep: class CompressedTimestep:
@ -450,6 +451,29 @@ class LTXAVModel(LTXVModel):
operations=self.operations, operations=self.operations,
) )
self.audio_embeddings_connector = Embeddings1DConnector(
split_rope=True,
double_precision_rope=True,
dtype=dtype,
device=device,
operations=self.operations,
)
self.video_embeddings_connector = Embeddings1DConnector(
split_rope=True,
double_precision_rope=True,
dtype=dtype,
device=device,
operations=self.operations,
)
def preprocess_text_embeds(self, context):
if context.shape[-1] == self.caption_channels * 2:
return context
out_vid = self.video_embeddings_connector(context)[0]
out_audio = self.audio_embeddings_connector(context)[0]
return torch.concat((out_vid, out_audio), dim=-1)
def _init_transformer_blocks(self, device, dtype, **kwargs): def _init_transformer_blocks(self, device, dtype, **kwargs):
"""Initialize transformer blocks for LTXAV.""" """Initialize transformer blocks for LTXAV."""
self.transformer_blocks = nn.ModuleList( self.transformer_blocks = nn.ModuleList(

View File

@ -157,11 +157,9 @@ class Embeddings1DConnector(nn.Module):
self.num_learnable_registers = num_learnable_registers self.num_learnable_registers = num_learnable_registers
if self.num_learnable_registers: if self.num_learnable_registers:
self.learnable_registers = nn.Parameter( self.learnable_registers = nn.Parameter(
torch.rand( torch.empty(
self.num_learnable_registers, inner_dim, dtype=dtype, device=device self.num_learnable_registers, inner_dim, dtype=dtype, device=device
) )
* 2.0
- 1.0
) )
def get_fractional_positions(self, indices_grid): def get_fractional_positions(self, indices_grid):
@ -234,7 +232,7 @@ class Embeddings1DConnector(nn.Module):
return indices return indices
def precompute_freqs_cis(self, indices_grid, spacing="exp"): def precompute_freqs_cis(self, indices_grid, spacing="exp", out_dtype=None):
dim = self.inner_dim dim = self.inner_dim
n_elem = 2 # 2 because of cos and sin n_elem = 2 # 2 because of cos and sin
freqs = self.precompute_freqs(indices_grid, spacing) freqs = self.precompute_freqs(indices_grid, spacing)
@ -247,7 +245,7 @@ class Embeddings1DConnector(nn.Module):
) )
else: else:
cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem) cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
return cos_freq.to(self.dtype), sin_freq.to(self.dtype), self.split_rope return cos_freq.to(dtype=out_dtype), sin_freq.to(dtype=out_dtype), self.split_rope
def forward( def forward(
self, self,
@ -288,7 +286,7 @@ class Embeddings1DConnector(nn.Module):
hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device
) )
indices_grid = indices_grid[None, None, :] indices_grid = indices_grid[None, None, :]
freqs_cis = self.precompute_freqs_cis(indices_grid) freqs_cis = self.precompute_freqs_cis(indices_grid, out_dtype=hidden_states.dtype)
# 2. Blocks # 2. Blocks
for block_idx, block in enumerate(self.transformer_1d_blocks): for block_idx, block in enumerate(self.transformer_1d_blocks):

View File

@ -102,19 +102,7 @@ class VideoConv3d(nn.Module):
return self.conv(x) return self.conv(x)
def interpolate_up(x, scale_factor): def interpolate_up(x, scale_factor):
try: return torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="nearest")
return torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="nearest")
except: #operation not implemented for bf16
orig_shape = list(x.shape)
out_shape = orig_shape[:2]
for i in range(len(orig_shape) - 2):
out_shape.append(round(orig_shape[i + 2] * scale_factor[i]))
out = torch.empty(out_shape, dtype=x.dtype, layout=x.layout, device=x.device)
split = 8
l = out.shape[1] // split
for i in range(0, out.shape[1], l):
out[:,i:i+l] = torch.nn.functional.interpolate(x[:,i:i+l].to(torch.float32), scale_factor=scale_factor, mode="nearest").to(x.dtype)
return out
class Upsample(nn.Module): class Upsample(nn.Module):
def __init__(self, in_channels, with_conv, conv_op=ops.Conv2d, scale_factor=2.0): def __init__(self, in_channels, with_conv, conv_op=ops.Conv2d, scale_factor=2.0):

View File

@ -2,6 +2,196 @@ import torch
import math import math
from .model import QwenImageTransformer2DModel from .model import QwenImageTransformer2DModel
from .model import QwenImageTransformerBlock
class QwenImageFunControlBlock(QwenImageTransformerBlock):
def __init__(self, dim, num_attention_heads, attention_head_dim, has_before_proj=False, dtype=None, device=None, operations=None):
super().__init__(
dim=dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
dtype=dtype,
device=device,
operations=operations,
)
self.has_before_proj = has_before_proj
if has_before_proj:
self.before_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
self.after_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
class QwenImageFunControlNetModel(torch.nn.Module):
def __init__(
self,
control_in_features=132,
inner_dim=3072,
num_attention_heads=24,
attention_head_dim=128,
num_control_blocks=5,
main_model_double=60,
injection_layers=(0, 12, 24, 36, 48),
dtype=None,
device=None,
operations=None,
):
super().__init__()
self.dtype = dtype
self.main_model_double = main_model_double
self.injection_layers = tuple(injection_layers)
# Keep base hint scaling at 1.0 so user-facing strength behaves similarly
# to the reference Gen2/VideoX implementation around strength=1.
self.hint_scale = 1.0
self.control_img_in = operations.Linear(control_in_features, inner_dim, device=device, dtype=dtype)
self.control_blocks = torch.nn.ModuleList([])
for i in range(num_control_blocks):
self.control_blocks.append(
QwenImageFunControlBlock(
dim=inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
has_before_proj=(i == 0),
dtype=dtype,
device=device,
operations=operations,
)
)
def _process_hint_tokens(self, hint):
if hint is None:
return None
if hint.ndim == 4:
hint = hint.unsqueeze(2)
# Fun checkpoints are trained with 33 latent channels before 2x2 packing:
# [control_latent(16), mask(1), inpaint_latent(16)] -> 132 features.
# Default behavior (no inpaint input in stock Apply ControlNet) should use
# zeros for mask/inpaint branches, matching VideoX fallback semantics.
expected_c = self.control_img_in.weight.shape[1] // 4
if hint.shape[1] == 16 and expected_c == 33:
zeros_mask = torch.zeros_like(hint[:, :1])
zeros_inpaint = torch.zeros_like(hint)
hint = torch.cat([hint, zeros_mask, zeros_inpaint], dim=1)
bs, c, t, h, w = hint.shape
hidden_states = torch.nn.functional.pad(hint, (0, w % 2, 0, h % 2))
orig_shape = hidden_states.shape
hidden_states = hidden_states.view(
orig_shape[0],
orig_shape[1],
orig_shape[-3],
orig_shape[-2] // 2,
2,
orig_shape[-1] // 2,
2,
)
hidden_states = hidden_states.permute(0, 2, 3, 5, 1, 4, 6)
hidden_states = hidden_states.reshape(
bs,
t * ((h + 1) // 2) * ((w + 1) // 2),
c * 4,
)
expected_in = self.control_img_in.weight.shape[1]
cur_in = hidden_states.shape[-1]
if cur_in < expected_in:
pad = torch.zeros(
(hidden_states.shape[0], hidden_states.shape[1], expected_in - cur_in),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
hidden_states = torch.cat([hidden_states, pad], dim=-1)
elif cur_in > expected_in:
hidden_states = hidden_states[:, :, :expected_in]
return hidden_states
def forward(
self,
x,
timesteps,
context,
attention_mask=None,
guidance: torch.Tensor = None,
hint=None,
transformer_options={},
base_model=None,
**kwargs,
):
if base_model is None:
raise RuntimeError("Qwen Fun ControlNet requires a QwenImage base model at runtime.")
encoder_hidden_states_mask = attention_mask
# Keep attention mask disabled inside Fun control blocks to mirror
# VideoX behavior (they rely on seq lengths for RoPE, not masked attention).
encoder_hidden_states_mask = None
hidden_states, img_ids, _ = base_model.process_img(x)
hint_tokens = self._process_hint_tokens(hint)
if hint_tokens is None:
raise RuntimeError("Qwen Fun ControlNet requires a control hint image.")
if hint_tokens.shape[1] != hidden_states.shape[1]:
max_tokens = min(hint_tokens.shape[1], hidden_states.shape[1])
hint_tokens = hint_tokens[:, :max_tokens]
hidden_states = hidden_states[:, :max_tokens]
img_ids = img_ids[:, :max_tokens]
txt_start = round(
max(
((x.shape[-1] + (base_model.patch_size // 2)) // base_model.patch_size) // 2,
((x.shape[-2] + (base_model.patch_size // 2)) // base_model.patch_size) // 2,
)
)
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
ids = torch.cat((txt_ids, img_ids), dim=1)
image_rotary_emb = base_model.pe_embedder(ids).to(x.dtype).contiguous()
hidden_states = base_model.img_in(hidden_states)
encoder_hidden_states = base_model.txt_norm(context)
encoder_hidden_states = base_model.txt_in(encoder_hidden_states)
if guidance is not None:
guidance = guidance * 1000
temb = (
base_model.time_text_embed(timesteps, hidden_states)
if guidance is None
else base_model.time_text_embed(timesteps, guidance, hidden_states)
)
c = self.control_img_in(hint_tokens)
for i, block in enumerate(self.control_blocks):
if i == 0:
c_in = block.before_proj(c) + hidden_states
all_c = []
else:
all_c = list(torch.unbind(c, dim=0))
c_in = all_c.pop(-1)
encoder_hidden_states, c_out = block(
hidden_states=c_in,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_states_mask=encoder_hidden_states_mask,
temb=temb,
image_rotary_emb=image_rotary_emb,
transformer_options=transformer_options,
)
c_skip = block.after_proj(c_out) * self.hint_scale
all_c += [c_skip, c_out]
c = torch.stack(all_c, dim=0)
hints = torch.unbind(c, dim=0)[:-1]
controlnet_block_samples = [None] * self.main_model_double
for local_idx, base_idx in enumerate(self.injection_layers):
if local_idx < len(hints) and base_idx < len(controlnet_block_samples):
controlnet_block_samples[base_idx] = hints[local_idx]
return {"input": controlnet_block_samples}
class QwenImageControlNetModel(QwenImageTransformer2DModel): class QwenImageControlNetModel(QwenImageTransformer2DModel):

View File

@ -374,6 +374,31 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten
return padded_tensor return padded_tensor
def calculate_shape(patches, weight, key, original_weights=None):
current_shape = weight.shape
for p in patches:
v = p[1]
offset = p[3]
# Offsets restore the old shape; lists force a diff without metadata
if offset is not None or isinstance(v, list):
continue
if isinstance(v, weight_adapter.WeightAdapterBase):
adapter_shape = v.calculate_shape(key)
if adapter_shape is not None:
current_shape = adapter_shape
continue
# Standard diff logic with padding
if len(v) == 2:
patch_type, patch_data = v[0], v[1]
if patch_type == "diff" and len(patch_data) > 1 and patch_data[1]['pad_weight']:
current_shape = patch_data[0].shape
return current_shape
def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, original_weights=None): def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, original_weights=None):
for p in patches: for p in patches:
strength = p[0] strength = p[0]

View File

@ -5,7 +5,7 @@ import comfy.utils
def convert_lora_bfl_control(sd): #BFL loras for Flux def convert_lora_bfl_control(sd): #BFL loras for Flux
sd_out = {} sd_out = {}
for k in sd: for k in sd:
k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.scale.set_weight")) k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.set_weight"))
sd_out[k_to] = sd[k] sd_out[k_to] = sd[k]
sd_out["diffusion_model.img_in.reshape_weight"] = torch.tensor([sd["img_in.lora_B.weight"].shape[0], sd["img_in.lora_A.weight"].shape[1]]) sd_out["diffusion_model.img_in.reshape_weight"] = torch.tensor([sd["img_in.lora_B.weight"].shape[0], sd["img_in.lora_A.weight"].shape[1]])

View File

@ -78,4 +78,4 @@ def interpret_gathered_like(tensors, gathered):
return dest_views return dest_views
aimdo_allocator = None aimdo_enabled = False

View File

@ -178,10 +178,7 @@ class BaseModel(torch.nn.Module):
xc = torch.cat([xc] + [comfy.model_management.cast_to_device(c_concat, xc.device, xc.dtype)], dim=1) xc = torch.cat([xc] + [comfy.model_management.cast_to_device(c_concat, xc.device, xc.dtype)], dim=1)
context = c_crossattn context = c_crossattn
dtype = self.get_dtype() dtype = self.get_dtype_inference()
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
xc = xc.to(dtype) xc = xc.to(dtype)
device = xc.device device = xc.device
@ -218,6 +215,13 @@ class BaseModel(torch.nn.Module):
def get_dtype(self): def get_dtype(self):
return self.diffusion_model.dtype return self.diffusion_model.dtype
def get_dtype_inference(self):
dtype = self.get_dtype()
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
return dtype
def encode_adm(self, **kwargs): def encode_adm(self, **kwargs):
return None return None
@ -372,9 +376,7 @@ class BaseModel(torch.nn.Module):
input_shapes += shape input_shapes += shape
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention(): if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
dtype = self.get_dtype() dtype = self.get_dtype_inference()
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
#TODO: this needs to be tweaked #TODO: this needs to be tweaked
area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes)) area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes))
return (area * comfy.model_management.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (1024 * 1024) return (area * comfy.model_management.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (1024 * 1024)
@ -986,10 +988,14 @@ class LTXAV(BaseModel):
def extra_conds(self, **kwargs): def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs) out = super().extra_conds(**kwargs)
attention_mask = kwargs.get("attention_mask", None) attention_mask = kwargs.get("attention_mask", None)
device = kwargs["device"]
if attention_mask is not None: if attention_mask is not None:
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
cross_attn = kwargs.get("cross_attn", None) cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None: if cross_attn is not None:
if hasattr(self.diffusion_model, "preprocess_text_embeds"):
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()))
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25)) out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
@ -1165,7 +1171,7 @@ class Anima(BaseModel):
t5xxl_ids = t5xxl_ids.unsqueeze(0) t5xxl_ids = t5xxl_ids.unsqueeze(0)
if torch.is_inference_mode_enabled(): # if not we are training if torch.is_inference_mode_enabled(): # if not we are training
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype())) cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype_inference()))
else: else:
out['t5xxl_ids'] = comfy.conds.CONDRegular(t5xxl_ids) out['t5xxl_ids'] = comfy.conds.CONDRegular(t5xxl_ids)
out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights) out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights)

View File

@ -19,6 +19,12 @@ def count_blocks(state_dict_keys, prefix_string):
count += 1 count += 1
return count return count
def any_suffix_in(keys, prefix, main, suffix_list=[]):
for x in suffix_list:
if "{}{}{}".format(prefix, main, x) in keys:
return True
return False
def calculate_transformer_depth(prefix, state_dict_keys, state_dict): def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
context_dim = None context_dim = None
use_linear_in_transformer = False use_linear_in_transformer = False
@ -186,7 +192,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["meanflow_sum"] = False dit_config["meanflow_sum"] = False
return dit_config return dit_config
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight) if any_suffix_in(state_dict_keys, key_prefix, 'double_blocks.0.img_attn.norm.key_norm.', ["weight", "scale"]) and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.norms.0.', ["weight", "scale"])): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
dit_config = {} dit_config = {}
if '{}double_stream_modulation_img.lin.weight'.format(key_prefix) in state_dict_keys: if '{}double_stream_modulation_img.lin.weight'.format(key_prefix) in state_dict_keys:
dit_config["image_model"] = "flux2" dit_config["image_model"] = "flux2"
@ -241,7 +247,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.') dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.') dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma
if any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.0.norms.0.', ["weight", "scale"]) or any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.norms.0.', ["weight", "scale"]): #Chroma
dit_config["image_model"] = "chroma" dit_config["image_model"] = "chroma"
dit_config["in_channels"] = 64 dit_config["in_channels"] = 64
dit_config["out_channels"] = 64 dit_config["out_channels"] = 64
@ -249,7 +256,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["out_dim"] = 3072 dit_config["out_dim"] = 3072
dit_config["hidden_dim"] = 5120 dit_config["hidden_dim"] = 5120
dit_config["n_layers"] = 5 dit_config["n_layers"] = 5
if f"{key_prefix}nerf_blocks.0.norm.scale" in state_dict_keys: #Chroma Radiance
if any_suffix_in(state_dict_keys, key_prefix, 'nerf_blocks.0.norm.', ["weight", "scale"]): #Chroma Radiance
dit_config["image_model"] = "chroma_radiance" dit_config["image_model"] = "chroma_radiance"
dit_config["in_channels"] = 3 dit_config["in_channels"] = 3
dit_config["out_channels"] = 3 dit_config["out_channels"] = 3
@ -259,7 +267,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["nerf_depth"] = 4 dit_config["nerf_depth"] = 4
dit_config["nerf_max_freqs"] = 8 dit_config["nerf_max_freqs"] = 8
dit_config["nerf_tile_size"] = 512 dit_config["nerf_tile_size"] = 512
dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear" dit_config["nerf_final_head_type"] = "conv" if any_suffix_in(state_dict_keys, key_prefix, 'nerf_final_layer_conv.norm.', ["weight", "scale"]) else "linear"
dit_config["nerf_embedder_dtype"] = torch.float32 dit_config["nerf_embedder_dtype"] = torch.float32
if "{}__x0__".format(key_prefix) in state_dict_keys: # x0 pred if "{}__x0__".format(key_prefix) in state_dict_keys: # x0 pred
dit_config["use_x0"] = True dit_config["use_x0"] = True
@ -268,7 +276,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
else: else:
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
dit_config["txt_norm"] = "{}txt_norm.scale".format(key_prefix) in state_dict_keys dit_config["txt_norm"] = any_suffix_in(state_dict_keys, key_prefix, 'txt_norm.', ["weight", "scale"])
if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model
dit_config["txt_ids_dims"] = [1, 2] dit_config["txt_ids_dims"] = [1, 2]

View File

@ -836,7 +836,7 @@ def unet_inital_load_device(parameters, dtype):
mem_dev = get_free_memory(torch_dev) mem_dev = get_free_memory(torch_dev)
mem_cpu = get_free_memory(cpu_dev) mem_cpu = get_free_memory(cpu_dev)
if mem_dev > mem_cpu and model_size < mem_dev and comfy.memory_management.aimdo_allocator is None: if mem_dev > mem_cpu and model_size < mem_dev and comfy.memory_management.aimdo_enabled:
return torch_dev return torch_dev
else: else:
return cpu_dev return cpu_dev
@ -1121,7 +1121,6 @@ def get_cast_buffer(offload_stream, device, size, ref):
synchronize() synchronize()
del STREAM_CAST_BUFFERS[offload_stream] del STREAM_CAST_BUFFERS[offload_stream]
del cast_buffer del cast_buffer
#FIXME: This doesn't work in Aimdo because mempool cant clear cache
soft_empty_cache() soft_empty_cache()
with wf_context: with wf_context:
cast_buffer = torch.empty((size), dtype=torch.int8, device=device) cast_buffer = torch.empty((size), dtype=torch.int8, device=device)

View File

@ -276,6 +276,7 @@ class ModelPatcher:
self.is_clip = False self.is_clip = False
self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed
self.cached_patcher_init: tuple[Callable, tuple] | None = None
if not hasattr(self.model, 'model_loaded_weight_memory'): if not hasattr(self.model, 'model_loaded_weight_memory'):
self.model.model_loaded_weight_memory = 0 self.model.model_loaded_weight_memory = 0
@ -312,8 +313,15 @@ class ModelPatcher:
def get_free_memory(self, device): def get_free_memory(self, device):
return comfy.model_management.get_free_memory(device) return comfy.model_management.get_free_memory(device)
def clone(self): def clone(self, disable_dynamic=False):
n = self.__class__(self.model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update) class_ = self.__class__
model = self.model
if self.is_dynamic() and disable_dynamic:
class_ = ModelPatcher
temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True)
model = temp_model_patcher.model
n = class_(model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
n.patches = {} n.patches = {}
for k in self.patches: for k in self.patches:
n.patches[k] = self.patches[k][:] n.patches[k] = self.patches[k][:]
@ -367,6 +375,8 @@ class ModelPatcher:
n.is_clip = self.is_clip n.is_clip = self.is_clip
n.hook_mode = self.hook_mode n.hook_mode = self.hook_mode
n.cached_patcher_init = self.cached_patcher_init
for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE): for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
callback(self, n) callback(self, n)
return n return n
@ -411,13 +421,16 @@ class ModelPatcher:
def memory_required(self, input_shape): def memory_required(self, input_shape):
return self.model.memory_required(input_shape=input_shape) return self.model.memory_required(input_shape=input_shape)
def disable_model_cfg1_optimization(self):
self.model_options["disable_cfg1_optimization"] = True
def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False): def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
if len(inspect.signature(sampler_cfg_function).parameters) == 3: if len(inspect.signature(sampler_cfg_function).parameters) == 3:
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
else: else:
self.model_options["sampler_cfg_function"] = sampler_cfg_function self.model_options["sampler_cfg_function"] = sampler_cfg_function
if disable_cfg1_optimization: if disable_cfg1_optimization:
self.model_options["disable_cfg1_optimization"] = True self.disable_model_cfg1_optimization()
def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False): def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
self.model_options = set_model_options_post_cfg_function(self.model_options, post_cfg_function, disable_cfg1_optimization) self.model_options = set_model_options_post_cfg_function(self.model_options, post_cfg_function, disable_cfg1_optimization)
@ -684,18 +697,19 @@ class ModelPatcher:
for key in list(self.pinned): for key in list(self.pinned):
self.unpin_weight(key) self.unpin_weight(key)
def _load_list(self, prio_comfy_cast_weights=False): def _load_list(self, prio_comfy_cast_weights=False, default_device=None):
loading = [] loading = []
for n, m in self.model.named_modules(): for n, m in self.model.named_modules():
params = [] default = False
skip = False params = { name: param for name, param in m.named_parameters(recurse=False) }
for name, param in m.named_parameters(recurse=False):
params.append(name)
for name, param in m.named_parameters(recurse=True): for name, param in m.named_parameters(recurse=True):
if name not in params: if name not in params:
skip = True # skip random weights in non leaf modules default = True # default random weights in non leaf modules
break break
if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0): if default and default_device is not None:
for param in params.values():
param.data = param.data.to(device=default_device)
if not default and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
module_mem = comfy.model_management.module_size(m) module_mem = comfy.model_management.module_size(m)
module_offload_mem = module_mem module_offload_mem = module_mem
if hasattr(m, "comfy_cast_weights"): if hasattr(m, "comfy_cast_weights"):
@ -1500,7 +1514,7 @@ class ModelPatcherDynamic(ModelPatcher):
#with pin and unpin syncrhonization which can be expensive for small weights #with pin and unpin syncrhonization which can be expensive for small weights
#with a high layer rate (e.g. autoregressive LLMs). #with a high layer rate (e.g. autoregressive LLMs).
#prioritize the non-comfy weights (note the order reverse). #prioritize the non-comfy weights (note the order reverse).
loading = self._load_list(prio_comfy_cast_weights=True) loading = self._load_list(prio_comfy_cast_weights=True, default_device=device_to)
loading.sort(reverse=True) loading.sort(reverse=True)
for x in loading: for x in loading:
@ -1518,8 +1532,10 @@ class ModelPatcherDynamic(ModelPatcher):
weight, _, _ = get_key_weight(self.model, key) weight, _, _ = get_key_weight(self.model, key)
if weight is None: if weight is None:
return 0 return (False, 0)
if key in self.patches: if key in self.patches:
if comfy.lora.calculate_shape(self.patches[key], weight, key) != weight.shape:
return (True, 0)
setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches)) setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches))
num_patches += 1 num_patches += 1
else: else:
@ -1533,7 +1549,13 @@ class ModelPatcherDynamic(ModelPatcher):
model_dtype = getattr(m, param_key + "_comfy_model_dtype", None) or weight.dtype model_dtype = getattr(m, param_key + "_comfy_model_dtype", None) or weight.dtype
weight._model_dtype = model_dtype weight._model_dtype = model_dtype
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype) geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
return comfy.memory_management.vram_aligned_size(geometry) return (False, comfy.memory_management.vram_aligned_size(geometry))
def force_load_param(self, param_key, device_to):
key = key_param_name_to_key(n, param_key)
if key in self.backup:
comfy.utils.set_attr_param(self.model, key, self.backup[key].weight)
self.patch_weight_to_device(key, device_to=device_to)
if hasattr(m, "comfy_cast_weights"): if hasattr(m, "comfy_cast_weights"):
m.comfy_cast_weights = True m.comfy_cast_weights = True
@ -1541,13 +1563,19 @@ class ModelPatcherDynamic(ModelPatcher):
m.seed_key = n m.seed_key = n
set_dirty(m, dirty) set_dirty(m, dirty)
v_weight_size = 0 force_load, v_weight_size = setup_param(self, m, n, "weight")
v_weight_size += setup_param(self, m, n, "weight") force_load_bias, v_weight_bias = setup_param(self, m, n, "bias")
v_weight_size += setup_param(self, m, n, "bias") force_load = force_load or force_load_bias
v_weight_size += v_weight_bias
if vbar is not None and not hasattr(m, "_v"): if force_load:
m._v = vbar.alloc(v_weight_size) logging.info(f"Module {n} has resizing Lora - force loading")
allocated_size += v_weight_size force_load_param(self, "weight", device_to)
force_load_param(self, "bias", device_to)
else:
if vbar is not None and not hasattr(m, "_v"):
m._v = vbar.alloc(v_weight_size)
allocated_size += v_weight_size
else: else:
for param in params: for param in params:
@ -1565,6 +1593,8 @@ class ModelPatcherDynamic(ModelPatcher):
allocated_size += weight_size allocated_size += weight_size
vbar.set_watermark_limit(allocated_size) vbar.set_watermark_limit(allocated_size)
move_weight_functions(m, device_to)
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.") logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.")
self.model.device = device_to self.model.device = device_to
@ -1584,7 +1614,7 @@ class ModelPatcherDynamic(ModelPatcher):
return 0 if vbar is None else vbar.free_memory(memory_to_free) return 0 if vbar is None else vbar.free_memory(memory_to_free)
def partially_unload_ram(self, ram_to_unload): def partially_unload_ram(self, ram_to_unload):
loading = self._load_list(prio_comfy_cast_weights=True) loading = self._load_list(prio_comfy_cast_weights=True, default_device=self.offload_device)
for x in loading: for x in loading:
_, _, _, _, m, _ = x _, _, _, _, m, _ = x
ram_to_unload -= comfy.pinned_memory.unpin_memory(m) ram_to_unload -= comfy.pinned_memory.unpin_memory(m)
@ -1605,6 +1635,13 @@ class ModelPatcherDynamic(ModelPatcher):
if unpatch_weights: if unpatch_weights:
self.partially_unload_ram(1e32) self.partially_unload_ram(1e32)
self.partially_unload(None, 1e32) self.partially_unload(None, 1e32)
for m in self.model.modules():
move_weight_functions(m, device_to)
keys = list(self.backup.keys())
for k in keys:
bk = self.backup[k]
comfy.utils.set_attr_param(self.model, k, bk.weight)
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False): def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
assert not force_patch_weights #See above assert not force_patch_weights #See above

View File

@ -21,7 +21,6 @@ import logging
import comfy.model_management import comfy.model_management
from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
import comfy.float import comfy.float
import comfy.rmsnorm
import json import json
import comfy.memory_management import comfy.memory_management
import comfy.pinned_memory import comfy.pinned_memory
@ -80,7 +79,7 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype): def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
offload_stream = None offload_stream = None
xfer_dest = None xfer_dest = None
@ -171,10 +170,10 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
#FIXME: this is not accurate, we need to be sensitive to the compute dtype #FIXME: this is not accurate, we need to be sensitive to the compute dtype
x = lowvram_fn(x) x = lowvram_fn(x)
if (isinstance(orig, QuantizedTensor) and if (isinstance(orig, QuantizedTensor) and
(orig.dtype == dtype and len(fns) == 0 or update_weight)): (want_requant and len(fns) == 0 or update_weight)):
seed = comfy.utils.string_to_seed(s.seed_key) seed = comfy.utils.string_to_seed(s.seed_key)
y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed) y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
if orig.dtype == dtype and len(fns) == 0: if want_requant and len(fns) == 0:
#The layer actually wants our freshly saved QT #The layer actually wants our freshly saved QT
x = y x = y
elif update_weight: elif update_weight:
@ -195,7 +194,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
return weight, bias, (offload_stream, device if signature is not None else None, None) return weight, bias, (offload_stream, device if signature is not None else None, None)
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None): def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None, want_requant=False):
# NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass # NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass
# offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This # offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
# will add async-offload support to your cast and improve performance. # will add async-offload support to your cast and improve performance.
@ -213,7 +212,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
non_blocking = comfy.model_management.device_supports_non_blocking(device) non_blocking = comfy.model_management.device_supports_non_blocking(device)
if hasattr(s, "_v"): if hasattr(s, "_v"):
return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype) return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant)
if offloadable and (device != s.weight.device or if offloadable and (device != s.weight.device or
(s.bias is not None and device != s.bias.device)): (s.bias is not None and device != s.bias.device)):
@ -463,7 +462,7 @@ class disable_weight_init:
else: else:
return super().forward(*args, **kwargs) return super().forward(*args, **kwargs)
class RMSNorm(comfy.rmsnorm.RMSNorm, CastWeightBiasOp): class RMSNorm(torch.nn.RMSNorm, CastWeightBiasOp):
def reset_parameters(self): def reset_parameters(self):
self.bias = None self.bias = None
return None return None
@ -475,8 +474,7 @@ class disable_weight_init:
weight = None weight = None
bias = None bias = None
offload_stream = None offload_stream = None
x = comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
# x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
uncast_bias_weight(self, weight, bias, offload_stream) uncast_bias_weight(self, weight, bias, offload_stream)
return x return x
@ -829,6 +827,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
else: else:
sd = {} sd = {}
if not hasattr(self, 'weight'):
logging.warning("Warning: state dict on uninitialized op {}".format(prefix))
return sd
if self.bias is not None: if self.bias is not None:
sd["{}bias".format(prefix)] = self.bias sd["{}bias".format(prefix)] = self.bias
@ -852,8 +854,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def _forward(self, input, weight, bias): def _forward(self, input, weight, bias):
return torch.nn.functional.linear(input, weight, bias) return torch.nn.functional.linear(input, weight, bias)
def forward_comfy_cast_weights(self, input, compute_dtype=None): def forward_comfy_cast_weights(self, input, compute_dtype=None, want_requant=False):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype) weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype, want_requant=want_requant)
x = self._forward(input, weight, bias) x = self._forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream) uncast_bias_weight(self, weight, bias, offload_stream)
return x return x
@ -883,8 +885,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
scale = comfy.model_management.cast_to_device(scale, input.device, None) scale = comfy.model_management.cast_to_device(scale, input.device, None)
input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale) input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)
output = self.forward_comfy_cast_weights(input, compute_dtype, want_requant=isinstance(input, QuantizedTensor))
output = self.forward_comfy_cast_weights(input, compute_dtype)
# Reshape output back to 3D if input was 3D # Reshape output back to 3D if input was 3D
if reshaped_3d: if reshaped_3d:

View File

@ -1,57 +1,10 @@
import torch import torch
import comfy.model_management import comfy.model_management
import numbers
import logging
RMSNorm = None
try:
rms_norm_torch = torch.nn.functional.rms_norm
RMSNorm = torch.nn.RMSNorm
except:
rms_norm_torch = None
logging.warning("Please update pytorch to use native RMSNorm")
RMSNorm = torch.nn.RMSNorm
def rms_norm(x, weight=None, eps=1e-6): def rms_norm(x, weight=None, eps=1e-6):
if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()): if weight is None:
if weight is None: return torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=eps)
return rms_norm_torch(x, (x.shape[-1],), eps=eps)
else:
return rms_norm_torch(x, weight.shape, weight=comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
else: else:
r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps) return torch.nn.functional.rms_norm(x, weight.shape, weight=comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
if weight is None:
return r
else:
return r * comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device)
if RMSNorm is None:
class RMSNorm(torch.nn.Module):
def __init__(
self,
normalized_shape,
eps=1e-6,
elementwise_affine=True,
device=None,
dtype=None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
if isinstance(normalized_shape, numbers.Integral):
# mypy error: incompatible types in assignment
normalized_shape = (normalized_shape,) # type: ignore[assignment]
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
self.weight = torch.nn.Parameter(
torch.empty(self.normalized_shape, **factory_kwargs)
)
else:
self.register_parameter("weight", None)
self.bias = None
def forward(self, x):
return rms_norm(x, self.weight, self.eps)

View File

@ -423,6 +423,17 @@ class CLIP:
def get_key_patches(self): def get_key_patches(self):
return self.patcher.get_key_patches() return self.patcher.get_key_patches()
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
self.cond_stage_model.reset_clip_options()
self.load_model()
self.cond_stage_model.set_clip_options({"layer": None})
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)
def decode(self, token_ids, skip_special_tokens=True):
return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
class VAE: class VAE:
def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None): def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
@ -1182,6 +1193,7 @@ class TEModel(Enum):
JINA_CLIP_2 = 19 JINA_CLIP_2 = 19
QWEN3_8B = 20 QWEN3_8B = 20
QWEN3_06B = 21 QWEN3_06B = 21
GEMMA_3_4B_VISION = 22
def detect_te_model(sd): def detect_te_model(sd):
@ -1210,7 +1222,10 @@ def detect_te_model(sd):
if 'model.layers.47.self_attn.q_norm.weight' in sd: if 'model.layers.47.self_attn.q_norm.weight' in sd:
return TEModel.GEMMA_3_12B return TEModel.GEMMA_3_12B
if 'model.layers.0.self_attn.q_norm.weight' in sd: if 'model.layers.0.self_attn.q_norm.weight' in sd:
return TEModel.GEMMA_3_4B if 'vision_model.embeddings.patch_embedding.weight' in sd:
return TEModel.GEMMA_3_4B_VISION
else:
return TEModel.GEMMA_3_4B
return TEModel.GEMMA_2_2B return TEModel.GEMMA_2_2B
if 'model.layers.0.self_attn.k_proj.bias' in sd: if 'model.layers.0.self_attn.k_proj.bias' in sd:
weight = sd['model.layers.0.self_attn.k_proj.bias'] weight = sd['model.layers.0.self_attn.k_proj.bias']
@ -1270,6 +1285,8 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
else: else:
if "text_projection" in clip_data[i]: if "text_projection" in clip_data[i]:
clip_data[i]["text_projection.weight"] = clip_data[i]["text_projection"].transpose(0, 1) #old models saved with the CLIPSave node clip_data[i]["text_projection.weight"] = clip_data[i]["text_projection"].transpose(0, 1) #old models saved with the CLIPSave node
if "lm_head.weight" in clip_data[i]:
clip_data[i]["model.lm_head.weight"] = clip_data[i].pop("lm_head.weight") # prefix missing in some models
tokenizer_data = {} tokenizer_data = {}
clip_target = EmptyClass() clip_target = EmptyClass()
@ -1335,6 +1352,14 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b") clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b")
clip_target.tokenizer = comfy.text_encoders.lumina2.NTokenizer clip_target.tokenizer = comfy.text_encoders.lumina2.NTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif te_model == TEModel.GEMMA_3_4B_VISION:
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b_vision")
clip_target.tokenizer = comfy.text_encoders.lumina2.NTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif te_model == TEModel.GEMMA_3_12B:
clip_target.clip = comfy.text_encoders.lt.gemma3_te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lt.Gemma3_12BTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif te_model == TEModel.LLAMA3_8: elif te_model == TEModel.LLAMA3_8:
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data), clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None) clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None)
@ -1505,14 +1530,24 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
return (model, clip, vae) return (model, clip, vae)
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}): def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, disable_dynamic=False):
sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True) sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata) out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic)
if out is None: if out is None:
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd))) raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
if output_model:
out[0].cached_patcher_init = (load_checkpoint_guess_config_model_only, (ckpt_path, embedding_directory, model_options, te_model_options))
return out return out
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None): def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
model, *_ = load_checkpoint_guess_config(ckpt_path, False, False, False,
embedding_directory=embedding_directory,
model_options=model_options,
te_model_options=te_model_options,
disable_dynamic=disable_dynamic)
return model
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None, disable_dynamic=False):
clip = None clip = None
clipvision = None clipvision = None
vae = None vae = None
@ -1561,7 +1596,8 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
if output_model: if output_model:
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype) inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device) model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device()) ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher
model_patcher = ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic()) model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic())
if output_vae: if output_vae:
@ -1612,7 +1648,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
return (model_patcher, clip, vae, clipvision) return (model_patcher, clip, vae, clipvision)
def load_diffusion_model_state_dict(sd, model_options={}, metadata=None): def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable_dynamic=False):
""" """
Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats. Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.
@ -1696,7 +1732,8 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
model_config.optimizations["fp8"] = True model_config.optimizations["fp8"] = True
model = model_config.get_model(new_sd, "") model = model_config.get_model(new_sd, "")
model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=load_device, offload_device=offload_device) ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher
model_patcher = ModelPatcher(model, load_device=load_device, offload_device=offload_device)
if not model_management.is_device_cpu(offload_device): if not model_management.is_device_cpu(offload_device):
model.to(offload_device) model.to(offload_device)
model.load_model_weights(new_sd, "", assign=model_patcher.is_dynamic()) model.load_model_weights(new_sd, "", assign=model_patcher.is_dynamic())
@ -1705,12 +1742,13 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
logging.info("left over keys in diffusion model: {}".format(left_over)) logging.info("left over keys in diffusion model: {}".format(left_over))
return model_patcher return model_patcher
def load_diffusion_model(unet_path, model_options={}): def load_diffusion_model(unet_path, model_options={}, disable_dynamic=False):
sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True) sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True)
model = load_diffusion_model_state_dict(sd, model_options=model_options, metadata=metadata) model = load_diffusion_model_state_dict(sd, model_options=model_options, metadata=metadata, disable_dynamic=disable_dynamic)
if model is None: if model is None:
logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path)) logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd))) raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd)))
model.cached_patcher_init = (load_diffusion_model, (unet_path, model_options))
return model return model
def load_unet(unet_path, dtype=None): def load_unet(unet_path, dtype=None):

View File

@ -171,8 +171,9 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def process_tokens(self, tokens, device): def process_tokens(self, tokens, device):
end_token = self.special_tokens.get("end", None) end_token = self.special_tokens.get("end", None)
pad_token = self.special_tokens.get("pad", -1)
if end_token is None: if end_token is None:
cmp_token = self.special_tokens.get("pad", -1) cmp_token = pad_token
else: else:
cmp_token = end_token cmp_token = end_token
@ -186,15 +187,21 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
other_embeds = [] other_embeds = []
eos = False eos = False
index = 0 index = 0
left_pad = False
for y in x: for y in x:
if isinstance(y, numbers.Integral): if isinstance(y, numbers.Integral):
if eos: token = int(y)
if index == 0 and token == pad_token:
left_pad = True
if eos or (left_pad and token == pad_token):
attention_mask.append(0) attention_mask.append(0)
else: else:
attention_mask.append(1) attention_mask.append(1)
token = int(y) left_pad = False
tokens_temp += [token] tokens_temp += [token]
if not eos and token == cmp_token: if not eos and token == cmp_token and not left_pad:
if end_token is None: if end_token is None:
attention_mask[-1] = 0 attention_mask[-1] = 0
eos = True eos = True
@ -301,6 +308,15 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def load_sd(self, sd): def load_sd(self, sd):
return self.transformer.load_state_dict(sd, strict=False, assign=getattr(self, "can_assign_sd", False)) return self.transformer.load_state_dict(sd, strict=False, assign=getattr(self, "can_assign_sd", False))
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
if isinstance(tokens, dict):
tokens_only = next(iter(tokens.values())) # todo: get this better?
else:
tokens_only = tokens
tokens_only = [[t[0] for t in b] for b in tokens_only]
embeds = self.process_tokens(tokens_only, device=self.execution_device)[0]
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed)
def parse_parentheses(string): def parse_parentheses(string):
result = [] result = []
current_item = "" current_item = ""
@ -557,6 +573,8 @@ class SDTokenizer:
min_length = tokenizer_options.get("{}_min_length".format(self.embedding_key), self.min_length) min_length = tokenizer_options.get("{}_min_length".format(self.embedding_key), self.min_length)
min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding) min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding)
min_length = kwargs.get("min_length", min_length)
text = escape_important(text) text = escape_important(text)
if kwargs.get("disable_weights", self.disable_weights): if kwargs.get("disable_weights", self.disable_weights):
parsed_weights = [(text, 1.0)] parsed_weights = [(text, 1.0)]
@ -656,6 +674,9 @@ class SDTokenizer:
def state_dict(self): def state_dict(self):
return {} return {}
def decode(self, token_ids, skip_special_tokens=True):
return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
class SD1Tokenizer: class SD1Tokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer, name=None): def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer, name=None):
if name is not None: if name is not None:
@ -679,6 +700,9 @@ class SD1Tokenizer:
def state_dict(self): def state_dict(self):
return getattr(self, self.clip).state_dict() return getattr(self, self.clip).state_dict()
def decode(self, token_ids, skip_special_tokens=True):
return getattr(self, self.clip).decode(token_ids, skip_special_tokens=skip_special_tokens)
class SD1CheckpointClipModel(SDClipModel): class SD1CheckpointClipModel(SDClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}): def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, return_projected_pooled=False, dtype=dtype, model_options=model_options) super().__init__(device=device, return_projected_pooled=False, dtype=dtype, model_options=model_options)
@ -715,3 +739,6 @@ class SD1ClipModel(torch.nn.Module):
def load_sd(self, sd): def load_sd(self, sd):
return getattr(self, self.clip).load_sd(sd) return getattr(self, self.clip).load_sd(sd)
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
return getattr(self, self.clip).generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)

View File

@ -710,6 +710,15 @@ class Flux(supported_models_base.BASE):
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
def process_unet_state_dict(self, state_dict):
out_sd = {}
for k in list(state_dict.keys()):
key_out = k
if key_out.endswith("_norm.scale"):
key_out = "{}.weight".format(key_out[:-len(".scale")])
out_sd[key_out] = state_dict[k]
return out_sd
vae_key_prefix = ["vae."] vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."] text_encoder_key_prefix = ["text_encoders."]
@ -898,11 +907,13 @@ class HunyuanVideo(supported_models_base.BASE):
key_out = key_out.replace("txt_in.c_embedder.linear_1.", "txt_in.c_embedder.in_layer.").replace("txt_in.c_embedder.linear_2.", "txt_in.c_embedder.out_layer.") key_out = key_out.replace("txt_in.c_embedder.linear_1.", "txt_in.c_embedder.in_layer.").replace("txt_in.c_embedder.linear_2.", "txt_in.c_embedder.out_layer.")
key_out = key_out.replace("_mod.linear.", "_mod.lin.").replace("_attn_qkv.", "_attn.qkv.") key_out = key_out.replace("_mod.linear.", "_mod.lin.").replace("_attn_qkv.", "_attn.qkv.")
key_out = key_out.replace("mlp.fc1.", "mlp.0.").replace("mlp.fc2.", "mlp.2.") key_out = key_out.replace("mlp.fc1.", "mlp.0.").replace("mlp.fc2.", "mlp.2.")
key_out = key_out.replace("_attn_q_norm.weight", "_attn.norm.query_norm.scale").replace("_attn_k_norm.weight", "_attn.norm.key_norm.scale") key_out = key_out.replace("_attn_q_norm.weight", "_attn.norm.query_norm.weight").replace("_attn_k_norm.weight", "_attn.norm.key_norm.weight")
key_out = key_out.replace(".q_norm.weight", ".norm.query_norm.scale").replace(".k_norm.weight", ".norm.key_norm.scale") key_out = key_out.replace(".q_norm.weight", ".norm.query_norm.weight").replace(".k_norm.weight", ".norm.key_norm.weight")
key_out = key_out.replace("_attn_proj.", "_attn.proj.") key_out = key_out.replace("_attn_proj.", "_attn.proj.")
key_out = key_out.replace(".modulation.linear.", ".modulation.lin.") key_out = key_out.replace(".modulation.linear.", ".modulation.lin.")
key_out = key_out.replace("_in.mlp.2.", "_in.out_layer.").replace("_in.mlp.0.", "_in.in_layer.") key_out = key_out.replace("_in.mlp.2.", "_in.out_layer.").replace("_in.mlp.0.", "_in.in_layer.")
if key_out.endswith(".scale"):
key_out = "{}.weight".format(key_out[:-len(".scale")])
out_sd[key_out] = state_dict[k] out_sd[key_out] = state_dict[k]
return out_sd return out_sd
@ -1264,6 +1275,15 @@ class Hunyuan3Dv2(supported_models_base.BASE):
latent_format = latent_formats.Hunyuan3Dv2 latent_format = latent_formats.Hunyuan3Dv2
def process_unet_state_dict(self, state_dict):
out_sd = {}
for k in list(state_dict.keys()):
key_out = k
if key_out.endswith(".scale"):
key_out = "{}.weight".format(key_out[:-len(".scale")])
out_sd[key_out] = state_dict[k]
return out_sd
def process_unet_state_dict_for_saving(self, state_dict): def process_unet_state_dict_for_saving(self, state_dict):
replace_prefix = {"": "model."} replace_prefix = {"": "model."}
return utils.state_dict_prefix_replace(state_dict, replace_prefix) return utils.state_dict_prefix_replace(state_dict, replace_prefix)
@ -1341,6 +1361,14 @@ class Chroma(supported_models_base.BASE):
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
def process_unet_state_dict(self, state_dict):
out_sd = {}
for k in list(state_dict.keys()):
key_out = k
if key_out.endswith(".scale"):
key_out = "{}.weight".format(key_out[:-len(".scale")])
out_sd[key_out] = state_dict[k]
return out_sd
def get_model(self, state_dict, prefix="", device=None): def get_model(self, state_dict, prefix="", device=None):
out = model_base.Chroma(self, device=device) out = model_base.Chroma(self, device=device)

View File

@ -10,7 +10,6 @@ import comfy.utils
def sample_manual_loop_no_classes( def sample_manual_loop_no_classes(
model, model,
ids=None, ids=None,
paddings=[],
execution_dtype=None, execution_dtype=None,
cfg_scale: float = 2.0, cfg_scale: float = 2.0,
temperature: float = 0.85, temperature: float = 0.85,
@ -36,9 +35,6 @@ def sample_manual_loop_no_classes(
embeds, attention_mask, num_tokens, embeds_info = model.process_tokens(ids, device) embeds, attention_mask, num_tokens, embeds_info = model.process_tokens(ids, device)
embeds_batch = embeds.shape[0] embeds_batch = embeds.shape[0]
for i, t in enumerate(paddings):
attention_mask[i, :t] = 0
attention_mask[i, t:] = 1
output_audio_codes = [] output_audio_codes = []
past_key_values = [] past_key_values = []
@ -135,13 +131,11 @@ def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=102
pos_pad = (len(negative) - len(positive)) pos_pad = (len(negative) - len(positive))
positive = [model.special_tokens["pad"]] * pos_pad + positive positive = [model.special_tokens["pad"]] * pos_pad + positive
paddings = [pos_pad, neg_pad]
ids = [positive, negative] ids = [positive, negative]
else: else:
paddings = []
ids = [positive] ids = [positive]
return sample_manual_loop_no_classes(model, ids, paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens) return sample_manual_loop_no_classes(model, ids, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
class ACE15Tokenizer(sd1_clip.SD1Tokenizer): class ACE15Tokenizer(sd1_clip.SD1Tokenizer):

View File

@ -33,6 +33,8 @@ class AnimaTokenizer:
def state_dict(self): def state_dict(self):
return {} return {}
def decode(self, token_ids, **kwargs):
return self.qwen3_06b.decode(token_ids, **kwargs)
class Qwen3_06BModel(sd1_clip.SDClipModel): class Qwen3_06BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}): def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):

View File

@ -3,6 +3,8 @@ import torch.nn as nn
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Any, Tuple from typing import Optional, Any, Tuple
import math import math
from tqdm import tqdm
import comfy.utils
from comfy.ldm.modules.attention import optimized_attention_for_device from comfy.ldm.modules.attention import optimized_attention_for_device
import comfy.model_management import comfy.model_management
@ -103,6 +105,7 @@ class Qwen3_06BConfig:
rope_scale = None rope_scale = None
final_norm: bool = True final_norm: bool = True
lm_head: bool = False lm_head: bool = False
stop_tokens = [151643, 151645]
@dataclass @dataclass
class Qwen3_06B_ACE15_Config: class Qwen3_06B_ACE15_Config:
@ -126,6 +129,7 @@ class Qwen3_06B_ACE15_Config:
rope_scale = None rope_scale = None
final_norm: bool = True final_norm: bool = True
lm_head: bool = False lm_head: bool = False
stop_tokens = [151643, 151645]
@dataclass @dataclass
class Qwen3_2B_ACE15_lm_Config: class Qwen3_2B_ACE15_lm_Config:
@ -149,6 +153,7 @@ class Qwen3_2B_ACE15_lm_Config:
rope_scale = None rope_scale = None
final_norm: bool = True final_norm: bool = True
lm_head: bool = False lm_head: bool = False
stop_tokens = [151643, 151645]
@dataclass @dataclass
class Qwen3_4B_ACE15_lm_Config: class Qwen3_4B_ACE15_lm_Config:
@ -172,6 +177,7 @@ class Qwen3_4B_ACE15_lm_Config:
rope_scale = None rope_scale = None
final_norm: bool = True final_norm: bool = True
lm_head: bool = False lm_head: bool = False
stop_tokens = [151643, 151645]
@dataclass @dataclass
class Qwen3_4BConfig: class Qwen3_4BConfig:
@ -195,6 +201,7 @@ class Qwen3_4BConfig:
rope_scale = None rope_scale = None
final_norm: bool = True final_norm: bool = True
lm_head: bool = False lm_head: bool = False
stop_tokens = [151643, 151645]
@dataclass @dataclass
class Qwen3_8BConfig: class Qwen3_8BConfig:
@ -218,6 +225,7 @@ class Qwen3_8BConfig:
rope_scale = None rope_scale = None
final_norm: bool = True final_norm: bool = True
lm_head: bool = False lm_head: bool = False
stop_tokens = [151643, 151645]
@dataclass @dataclass
class Ovis25_2BConfig: class Ovis25_2BConfig:
@ -288,6 +296,7 @@ class Gemma2_2B_Config:
rope_scale = None rope_scale = None
final_norm: bool = True final_norm: bool = True
lm_head: bool = False lm_head: bool = False
stop_tokens = [1]
@dataclass @dataclass
class Gemma3_4B_Config: class Gemma3_4B_Config:
@ -312,6 +321,14 @@ class Gemma3_4B_Config:
rope_scale = [8.0, 1.0] rope_scale = [8.0, 1.0]
final_norm: bool = True final_norm: bool = True
lm_head: bool = False lm_head: bool = False
stop_tokens = [1, 106]
GEMMA3_VISION_CONFIG = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14}
@dataclass
class Gemma3_4B_Vision_Config(Gemma3_4B_Config):
vision_config = GEMMA3_VISION_CONFIG
mm_tokens_per_image = 256
@dataclass @dataclass
class Gemma3_12B_Config: class Gemma3_12B_Config:
@ -336,8 +353,9 @@ class Gemma3_12B_Config:
rope_scale = [8.0, 1.0] rope_scale = [8.0, 1.0]
final_norm: bool = True final_norm: bool = True
lm_head: bool = False lm_head: bool = False
vision_config = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14} vision_config = GEMMA3_VISION_CONFIG
mm_tokens_per_image = 256 mm_tokens_per_image = 256
stop_tokens = [1, 106]
class RMSNorm(nn.Module): class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None): def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
@ -355,13 +373,6 @@ class RMSNorm(nn.Module):
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None): def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None):
if not isinstance(theta, list): if not isinstance(theta, list):
theta = [theta] theta = [theta]
@ -390,20 +401,30 @@ def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_di
else: else:
cos = cos.unsqueeze(1) cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1) sin = sin.unsqueeze(1)
out.append((cos, sin)) sin_split = sin.shape[-1] // 2
out.append((cos, sin[..., : sin_split], -sin[..., sin_split :]))
if len(out) == 1: if len(out) == 1:
return out[0] return out[0]
return out return out
def apply_rope(xq, xk, freqs_cis): def apply_rope(xq, xk, freqs_cis):
org_dtype = xq.dtype org_dtype = xq.dtype
cos = freqs_cis[0] cos = freqs_cis[0]
sin = freqs_cis[1] sin = freqs_cis[1]
q_embed = (xq * cos) + (rotate_half(xq) * sin) nsin = freqs_cis[2]
k_embed = (xk * cos) + (rotate_half(xk) * sin)
q_embed = (xq * cos)
q_split = q_embed.shape[-1] // 2
q_embed[..., : q_split].addcmul_(xq[..., q_split :], nsin)
q_embed[..., q_split :].addcmul_(xq[..., : q_split], sin)
k_embed = (xk * cos)
k_split = k_embed.shape[-1] // 2
k_embed[..., : k_split].addcmul_(xk[..., k_split :], nsin)
k_embed[..., k_split :].addcmul_(xk[..., : k_split], sin)
return q_embed.to(org_dtype), k_embed.to(org_dtype) return q_embed.to(org_dtype), k_embed.to(org_dtype)
@ -438,8 +459,10 @@ class Attention(nn.Module):
freqs_cis: Optional[torch.Tensor] = None, freqs_cis: Optional[torch.Tensor] = None,
optimized_attention=None, optimized_attention=None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
sliding_window: Optional[int] = None,
): ):
batch_size, seq_length, _ = hidden_states.shape batch_size, seq_length, _ = hidden_states.shape
xq = self.q_proj(hidden_states) xq = self.q_proj(hidden_states)
xk = self.k_proj(hidden_states) xk = self.k_proj(hidden_states)
xv = self.v_proj(hidden_states) xv = self.v_proj(hidden_states)
@ -474,6 +497,11 @@ class Attention(nn.Module):
else: else:
present_key_value = (xk, xv, index + num_tokens) present_key_value = (xk, xv, index + num_tokens)
if sliding_window is not None and xk.shape[2] > sliding_window:
xk = xk[:, :, -sliding_window:]
xv = xv[:, :, -sliding_window:]
attention_mask = attention_mask[..., -sliding_window:] if attention_mask is not None else None
xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
@ -556,10 +584,12 @@ class TransformerBlockGemma2(nn.Module):
optimized_attention=None, optimized_attention=None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
): ):
sliding_window = None
if self.transformer_type == 'gemma3': if self.transformer_type == 'gemma3':
if self.sliding_attention: if self.sliding_attention:
sliding_window = self.sliding_attention
if x.shape[1] > self.sliding_attention: if x.shape[1] > self.sliding_attention:
sliding_mask = torch.full((x.shape[1], x.shape[1]), float("-inf"), device=x.device, dtype=x.dtype) sliding_mask = torch.full((x.shape[1], x.shape[1]), torch.finfo(x.dtype).min, device=x.device, dtype=x.dtype)
sliding_mask.tril_(diagonal=-self.sliding_attention) sliding_mask.tril_(diagonal=-self.sliding_attention)
if attention_mask is not None: if attention_mask is not None:
attention_mask = attention_mask + sliding_mask attention_mask = attention_mask + sliding_mask
@ -578,6 +608,7 @@ class TransformerBlockGemma2(nn.Module):
freqs_cis=freqs_cis, freqs_cis=freqs_cis,
optimized_attention=optimized_attention, optimized_attention=optimized_attention,
past_key_value=past_key_value, past_key_value=past_key_value,
sliding_window=sliding_window,
) )
x = self.post_attention_layernorm(x) x = self.post_attention_layernorm(x)
@ -762,6 +793,107 @@ class BaseLlama:
def forward(self, input_ids, *args, **kwargs): def forward(self, input_ids, *args, **kwargs):
return self.model(input_ids, *args, **kwargs) return self.model(input_ids, *args, **kwargs)
class BaseGenerate:
def logits(self, x):
input = x[:, -1:]
if hasattr(self.model, "lm_head"):
module = self.model.lm_head
else:
module = self.model.embed_tokens
offload_stream = None
if module.comfy_cast_weights:
weight, _, offload_stream = comfy.ops.cast_bias_weight(module, input, offloadable=True)
else:
weight = self.model.embed_tokens.weight.to(x)
x = torch.nn.functional.linear(input, weight, None)
comfy.ops.uncast_bias_weight(module, weight, None, offload_stream)
return x
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0):
device = embeds.device
model_config = self.model.config
if stop_tokens is None:
stop_tokens = self.model.config.stop_tokens
if execution_dtype is None:
if comfy.model_management.should_use_bf16(device):
execution_dtype = torch.bfloat16
else:
execution_dtype = torch.float32
embeds = embeds.to(execution_dtype)
if embeds.ndim == 2:
embeds = embeds.unsqueeze(0)
past_key_values = [] #kv_cache init
max_cache_len = embeds.shape[1] + max_length
for x in range(model_config.num_hidden_layers):
past_key_values.append((torch.empty([embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
torch.empty([embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
generator = torch.Generator(device=device).manual_seed(seed) if do_sample else None
generated_token_ids = []
pbar = comfy.utils.ProgressBar(max_length)
# Generation loop
for step in tqdm(range(max_length), desc="Generating tokens"):
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values)
logits = self.logits(x)[:, -1]
next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample)
token_id = next_token[0].item()
generated_token_ids.append(token_id)
embeds = self.model.embed_tokens(next_token).to(execution_dtype)
pbar.update(1)
if token_id in stop_tokens:
break
return generated_token_ids
def sample_token(self, logits, temperature, top_k, top_p, min_p, repetition_penalty, token_history, generator, do_sample=True):
if not do_sample or temperature == 0.0:
return torch.argmax(logits, dim=-1, keepdim=True)
# Sampling mode
if repetition_penalty != 1.0:
for i in range(logits.shape[0]):
for token_id in set(token_history):
logits[i, token_id] *= repetition_penalty if logits[i, token_id] < 0 else 1/repetition_penalty
if temperature != 1.0:
logits = logits / temperature
if top_k > 0:
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = torch.finfo(logits.dtype).min
if min_p > 0.0:
probs_before_filter = torch.nn.functional.softmax(logits, dim=-1)
top_probs, _ = probs_before_filter.max(dim=-1, keepdim=True)
min_threshold = min_p * top_probs
indices_to_remove = probs_before_filter < min_threshold
logits[indices_to_remove] = torch.finfo(logits.dtype).min
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 0] = False
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
indices_to_remove.scatter_(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = torch.finfo(logits.dtype).min
probs = torch.nn.functional.softmax(logits, dim=-1)
return torch.multinomial(probs, num_samples=1, generator=generator)
class BaseQwen3: class BaseQwen3:
def logits(self, x): def logits(self, x):
input = x[:, -1:] input = x[:, -1:]
@ -805,7 +937,7 @@ class Qwen25_3B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype self.dtype = dtype
class Qwen3_06B(BaseLlama, BaseQwen3, torch.nn.Module): class Qwen3_06B(BaseLlama, BaseQwen3, BaseGenerate, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations): def __init__(self, config_dict, dtype, device, operations):
super().__init__() super().__init__()
config = Qwen3_06BConfig(**config_dict) config = Qwen3_06BConfig(**config_dict)
@ -832,7 +964,7 @@ class Qwen3_2B_ACE15_lm(BaseLlama, BaseQwen3, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype self.dtype = dtype
class Qwen3_4B(BaseLlama, BaseQwen3, torch.nn.Module): class Qwen3_4B(BaseLlama, BaseQwen3, BaseGenerate, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations): def __init__(self, config_dict, dtype, device, operations):
super().__init__() super().__init__()
config = Qwen3_4BConfig(**config_dict) config = Qwen3_4BConfig(**config_dict)
@ -850,7 +982,7 @@ class Qwen3_4B_ACE15_lm(BaseLlama, BaseQwen3, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype self.dtype = dtype
class Qwen3_8B(BaseLlama, BaseQwen3, torch.nn.Module): class Qwen3_8B(BaseLlama, BaseQwen3, BaseGenerate, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations): def __init__(self, config_dict, dtype, device, operations):
super().__init__() super().__init__()
config = Qwen3_8BConfig(**config_dict) config = Qwen3_8BConfig(**config_dict)
@ -868,7 +1000,7 @@ class Ovis25_2B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype self.dtype = dtype
class Qwen25_7BVLI(BaseLlama, torch.nn.Module): class Qwen25_7BVLI(BaseLlama, BaseGenerate, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations): def __init__(self, config_dict, dtype, device, operations):
super().__init__() super().__init__()
config = Qwen25_7BVLI_Config(**config_dict) config = Qwen25_7BVLI_Config(**config_dict)
@ -878,6 +1010,9 @@ class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
self.visual = qwen_vl.Qwen2VLVisionTransformer(hidden_size=1280, output_hidden_size=config.hidden_size, device=device, dtype=dtype, ops=operations) self.visual = qwen_vl.Qwen2VLVisionTransformer(hidden_size=1280, output_hidden_size=config.hidden_size, device=device, dtype=dtype, ops=operations)
self.dtype = dtype self.dtype = dtype
# todo: should this be tied or not?
#self.lm_head = operations.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
def preprocess_embed(self, embed, device): def preprocess_embed(self, embed, device):
if embed["type"] == "image": if embed["type"] == "image":
image, grid = qwen_vl.process_qwen2vl_images(embed["data"]) image, grid = qwen_vl.process_qwen2vl_images(embed["data"])
@ -911,7 +1046,7 @@ class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
return super().forward(x, attention_mask=attention_mask, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=final_layer_norm_intermediate, dtype=dtype, position_ids=position_ids) return super().forward(x, attention_mask=attention_mask, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=final_layer_norm_intermediate, dtype=dtype, position_ids=position_ids)
class Gemma2_2B(BaseLlama, torch.nn.Module): class Gemma2_2B(BaseLlama, BaseGenerate, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations): def __init__(self, config_dict, dtype, device, operations):
super().__init__() super().__init__()
config = Gemma2_2B_Config(**config_dict) config = Gemma2_2B_Config(**config_dict)
@ -920,7 +1055,7 @@ class Gemma2_2B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype self.dtype = dtype
class Gemma3_4B(BaseLlama, torch.nn.Module): class Gemma3_4B(BaseLlama, BaseGenerate, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations): def __init__(self, config_dict, dtype, device, operations):
super().__init__() super().__init__()
config = Gemma3_4B_Config(**config_dict) config = Gemma3_4B_Config(**config_dict)
@ -929,7 +1064,25 @@ class Gemma3_4B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype self.dtype = dtype
class Gemma3_12B(BaseLlama, torch.nn.Module): class Gemma3_4B_Vision(BaseLlama, BaseGenerate, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Gemma3_4B_Vision_Config(**config_dict)
self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
self.multi_modal_projector = Gemma3MultiModalProjector(config, dtype, device, operations)
self.vision_model = comfy.clip_model.CLIPVision(config.vision_config, dtype, device, operations)
self.image_size = config.vision_config["image_size"]
def preprocess_embed(self, embed, device):
if embed["type"] == "image":
image = comfy.clip_model.clip_preprocess(embed["data"], size=self.image_size, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], crop=True)
return self.multi_modal_projector(self.vision_model(image.to(device, dtype=torch.float32))[0]), None
return None, None
class Gemma3_12B(BaseLlama, BaseGenerate, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations): def __init__(self, config_dict, dtype, device, operations):
super().__init__() super().__init__()
config = Gemma3_12B_Config(**config_dict) config = Gemma3_12B_Config(**config_dict)

View File

@ -3,9 +3,9 @@ import os
from transformers import T5TokenizerFast from transformers import T5TokenizerFast
from .spiece_tokenizer import SPieceTokenizer from .spiece_tokenizer import SPieceTokenizer
import comfy.text_encoders.genmo import comfy.text_encoders.genmo
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
import torch import torch
import comfy.utils import comfy.utils
import math
class T5XXLTokenizer(sd1_clip.SDTokenizer): class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}): def __init__(self, embedding_directory=None, tokenizer_data={}):
@ -22,46 +22,86 @@ def ltxv_te(*args, **kwargs):
return comfy.text_encoders.genmo.mochi_te(*args, **kwargs) return comfy.text_encoders.genmo.mochi_te(*args, **kwargs)
class Gemma3_12BTokenizer(sd1_clip.SDTokenizer): class Gemma3_Tokenizer():
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
def state_dict(self): def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()} return {"spiece_model": self.tokenizer.serialize_model()}
def tokenize_with_weights(self, text, return_word_ids=False, image=None, llama_template=None, skip_template=True, **kwargs):
self.llama_template = "<start_of_turn>system\nYou are a helpful assistant.<end_of_turn>\n<start_of_turn>user\n{}<end_of_turn>\n<start_of_turn>model\n"
self.llama_template_images = "<start_of_turn>system\nYou are a helpful assistant.<end_of_turn>\n<start_of_turn>user\n\n<image_soft_token>{}<end_of_turn>\n\n<start_of_turn>model\n"
if image is None:
images = []
else:
samples = image.movedim(-1, 1)
total = int(896 * 896)
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
width = round(samples.shape[3] * scale_by)
height = round(samples.shape[2] * scale_by)
s = comfy.utils.common_upscale(samples, width, height, "area", "disabled").movedim(1, -1)
images = [s[:, :, :, :3]]
if text.startswith('<start_of_turn>'):
skip_template = True
if skip_template:
llama_text = text
else:
if llama_template is None:
if len(images) > 0:
llama_text = self.llama_template_images.format(text)
else:
llama_text = self.llama_template.format(text)
else:
llama_text = llama_template.format(text)
text_tokens = super().tokenize_with_weights(llama_text, return_word_ids)
if len(images) > 0:
embed_count = 0
for r in text_tokens:
for i, token in enumerate(r):
if token[0] == 262144 and embed_count < len(images):
r[i] = ({"type": "image", "data": images[embed_count]},) + token[1:]
embed_count += 1
return text_tokens
class Gemma3_12BTokenizer(Gemma3_Tokenizer, sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
special_tokens = {"<image_soft_token>": 262144, "<end_of_turn>": 106}
super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_left=True, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, tokenizer_data=tokenizer_data)
class LTXAVGemmaTokenizer(sd1_clip.SD1Tokenizer): class LTXAVGemmaTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}): def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma3_12b", tokenizer=Gemma3_12BTokenizer) super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma3_12b", tokenizer=Gemma3_12BTokenizer)
class Gemma3_12BModel(sd1_clip.SDClipModel): class Gemma3_12BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}): def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None) llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
if llama_quantization_metadata is not None: if llama_quantization_metadata is not None:
model_options = model_options.copy() model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata model_options["quantization_metadata"] = llama_quantization_metadata
self.dtypes = set()
self.dtypes.add(dtype)
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_12B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_12B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
def tokenize_with_weights(self, text, return_word_ids=False, llama_template="{}", image_embeds=None, **kwargs): def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
text = llama_template.format(text) tokens_only = [[t[0] for t in b] for b in tokens]
text_tokens = super().tokenize_with_weights(text, return_word_ids) embeds, _, _, embeds_info = self.process_tokens(tokens_only, self.execution_device)
embed_count = 0 comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
for k in text_tokens: return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106]) # 106 is <end_of_turn>
tt = text_tokens[k]
for r in tt:
for i in range(len(r)):
if r[i][0] == 262144:
if image_embeds is not None and embed_count < image_embeds.shape[0]:
r[i] = ({"type": "embedding", "data": image_embeds[embed_count], "original_type": "image"},) + r[i][1:]
embed_count += 1
return text_tokens
class LTXAVTEModel(torch.nn.Module): class LTXAVTEModel(torch.nn.Module):
def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}): def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
super().__init__() super().__init__()
self.dtypes = set() self.dtypes = set()
self.dtypes.add(dtype) self.dtypes.add(dtype)
self.compat_mode = False
self.gemma3_12b = Gemma3_12BModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None) self.gemma3_12b = Gemma3_12BModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None)
self.dtypes.add(dtype_llama) self.dtypes.add(dtype_llama)
@ -69,6 +109,11 @@ class LTXAVTEModel(torch.nn.Module):
operations = self.gemma3_12b.operations # TODO operations = self.gemma3_12b.operations # TODO
self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device) self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device)
def enable_compat_mode(self): # TODO: remove
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
operations = self.gemma3_12b.operations
dtype = self.text_embedding_projection.weight.dtype
device = self.text_embedding_projection.weight.device
self.audio_embeddings_connector = Embeddings1DConnector( self.audio_embeddings_connector = Embeddings1DConnector(
split_rope=True, split_rope=True,
double_precision_rope=True, double_precision_rope=True,
@ -84,6 +129,7 @@ class LTXAVTEModel(torch.nn.Module):
device=device, device=device,
operations=operations, operations=operations,
) )
self.compat_mode = True
def set_clip_options(self, options): def set_clip_options(self, options):
self.execution_device = options.get("execution_device", self.execution_device) self.execution_device = options.get("execution_device", self.execution_device)
@ -97,6 +143,7 @@ class LTXAVTEModel(torch.nn.Module):
token_weight_pairs = token_weight_pairs["gemma3_12b"] token_weight_pairs = token_weight_pairs["gemma3_12b"]
out, pooled, extra = self.gemma3_12b.encode_token_weights(token_weight_pairs) out, pooled, extra = self.gemma3_12b.encode_token_weights(token_weight_pairs)
out = out[:, :, -torch.sum(extra["attention_mask"]).item():]
out_device = out.device out_device = out.device
if comfy.model_management.should_use_bf16(self.execution_device): if comfy.model_management.should_use_bf16(self.execution_device):
out = out.to(device=self.execution_device, dtype=torch.bfloat16) out = out.to(device=self.execution_device, dtype=torch.bfloat16)
@ -105,30 +152,45 @@ class LTXAVTEModel(torch.nn.Module):
out = out.reshape((out.shape[0], out.shape[1], -1)) out = out.reshape((out.shape[0], out.shape[1], -1))
out = self.text_embedding_projection(out) out = self.text_embedding_projection(out)
out = out.float() out = out.float()
out_vid = self.video_embeddings_connector(out)[0]
out_audio = self.audio_embeddings_connector(out)[0] if self.compat_mode:
out = torch.concat((out_vid, out_audio), dim=-1) out_vid = self.video_embeddings_connector(out)[0]
out_audio = self.audio_embeddings_connector(out)[0]
out = torch.concat((out_vid, out_audio), dim=-1)
return out.to(out_device), pooled return out.to(out_device), pooled
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
return self.gemma3_12b.generate(tokens["gemma3_12b"], do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed)
def load_sd(self, sd): def load_sd(self, sd):
if "model.layers.47.self_attn.q_norm.weight" in sd: if "model.layers.47.self_attn.q_norm.weight" in sd:
return self.gemma3_12b.load_sd(sd) return self.gemma3_12b.load_sd(sd)
else: else:
sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "model.diffusion_model.video_embeddings_connector.": "video_embeddings_connector.", "model.diffusion_model.audio_embeddings_connector.": "audio_embeddings_connector."}, filter_keys=True) sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight"}, filter_keys=True)
if len(sdo) == 0: if len(sdo) == 0:
sdo = sd sdo = sd
missing_all = [] missing_all = []
unexpected_all = [] unexpected_all = []
for prefix, component in [("text_embedding_projection.", self.text_embedding_projection), ("video_embeddings_connector.", self.video_embeddings_connector), ("audio_embeddings_connector.", self.audio_embeddings_connector)]: for prefix, component in [("text_embedding_projection.", self.text_embedding_projection)]:
component_sd = {k.replace(prefix, ""): v for k, v in sdo.items() if k.startswith(prefix)} component_sd = {k.replace(prefix, ""): v for k, v in sdo.items() if k.startswith(prefix)}
if component_sd: if component_sd:
missing, unexpected = component.load_state_dict(component_sd, strict=False, assign=getattr(self, "can_assign_sd", False)) missing, unexpected = component.load_state_dict(component_sd, strict=False, assign=getattr(self, "can_assign_sd", False))
missing_all.extend([f"{prefix}{k}" for k in missing]) missing_all.extend([f"{prefix}{k}" for k in missing])
unexpected_all.extend([f"{prefix}{k}" for k in unexpected]) unexpected_all.extend([f"{prefix}{k}" for k in unexpected])
if "model.diffusion_model.audio_embeddings_connector.transformer_1d_blocks.2.attn1.to_q.bias" not in sd: # TODO: remove
ww = sd.get("model.diffusion_model.audio_embeddings_connector.transformer_1d_blocks.0.attn1.to_q.bias", None)
if ww is not None:
if ww.shape[0] == 3840:
self.enable_compat_mode()
sdv = comfy.utils.state_dict_prefix_replace(sd, {"model.diffusion_model.video_embeddings_connector.": ""}, filter_keys=True)
self.video_embeddings_connector.load_state_dict(sdv, strict=False, assign=getattr(self, "can_assign_sd", False))
sda = comfy.utils.state_dict_prefix_replace(sd, {"model.diffusion_model.audio_embeddings_connector.": ""}, filter_keys=True)
self.audio_embeddings_connector.load_state_dict(sda, strict=False, assign=getattr(self, "can_assign_sd", False))
return (missing_all, unexpected_all) return (missing_all, unexpected_all)
def memory_estimation_function(self, token_weight_pairs, device=None): def memory_estimation_function(self, token_weight_pairs, device=None):
@ -138,6 +200,7 @@ class LTXAVTEModel(torch.nn.Module):
token_weight_pairs = token_weight_pairs.get("gemma3_12b", []) token_weight_pairs = token_weight_pairs.get("gemma3_12b", [])
num_tokens = sum(map(lambda a: len(a), token_weight_pairs)) num_tokens = sum(map(lambda a: len(a), token_weight_pairs))
num_tokens = max(num_tokens, 64)
return num_tokens * constant * 1024 * 1024 return num_tokens * constant * 1024 * 1024
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None): def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
@ -150,3 +213,14 @@ def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
dtype = dtype_llama dtype = dtype_llama
super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options) super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
return LTXAVTEModel_ return LTXAVTEModel_
def gemma3_te(dtype_llama=None, llama_quantization_metadata=None):
class Gemma3_12BModel_(Gemma3_12BModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["llama_quantization_metadata"] = llama_quantization_metadata
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, model_options=model_options)
return Gemma3_12BModel_

View File

@ -1,23 +1,23 @@
from comfy import sd1_clip from comfy import sd1_clip
from .spiece_tokenizer import SPieceTokenizer from .spiece_tokenizer import SPieceTokenizer
import comfy.text_encoders.llama import comfy.text_encoders.llama
from comfy.text_encoders.lt import Gemma3_Tokenizer
import comfy.utils
class Gemma2BTokenizer(sd1_clip.SDTokenizer): class Gemma2BTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}): def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None) tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data) special_tokens = {"<end_of_turn>": 107}
super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, tokenizer_data=tokenizer_data)
def state_dict(self): def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()} return {"spiece_model": self.tokenizer.serialize_model()}
class Gemma3_4BTokenizer(sd1_clip.SDTokenizer): class Gemma3_4BTokenizer(Gemma3_Tokenizer, sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}): def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None) tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, disable_weights=True, tokenizer_data=tokenizer_data) special_tokens = {"<image_soft_token>": 262144, "<end_of_turn>": 106}
super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, disable_weights=True, tokenizer_data=tokenizer_data)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
class LuminaTokenizer(sd1_clip.SD1Tokenizer): class LuminaTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}): def __init__(self, embedding_directory=None, tokenizer_data={}):
@ -40,6 +40,20 @@ class Gemma3_4BModel(sd1_clip.SDClipModel):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class Gemma3_4B_Vision_Model(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B_Vision, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
def process_tokens(self, tokens, device):
embeds, _, _, embeds_info = super().process_tokens(tokens, device)
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
return embeds
class LuminaModel(sd1_clip.SD1ClipModel): class LuminaModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}, name="gemma2_2b", clip_model=Gemma2_2BModel): def __init__(self, device="cpu", dtype=None, model_options={}, name="gemma2_2b", clip_model=Gemma2_2BModel):
super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options) super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options)
@ -50,6 +64,8 @@ def te(dtype_llama=None, llama_quantization_metadata=None, model_type="gemma2_2b
model = Gemma2_2BModel model = Gemma2_2BModel
elif model_type == "gemma3_4b": elif model_type == "gemma3_4b":
model = Gemma3_4BModel model = Gemma3_4BModel
elif model_type == "gemma3_4b_vision":
model = Gemma3_4B_Vision_Model
class LuminaTEModel_(LuminaModel): class LuminaTEModel_(LuminaModel):
def __init__(self, device="cpu", dtype=None, model_options={}): def __init__(self, device="cpu", dtype=None, model_options={}):

View File

@ -6,9 +6,10 @@ class SPieceTokenizer:
def from_pretrained(path, **kwargs): def from_pretrained(path, **kwargs):
return SPieceTokenizer(path, **kwargs) return SPieceTokenizer(path, **kwargs)
def __init__(self, tokenizer_path, add_bos=False, add_eos=True): def __init__(self, tokenizer_path, add_bos=False, add_eos=True, special_tokens=None):
self.add_bos = add_bos self.add_bos = add_bos
self.add_eos = add_eos self.add_eos = add_eos
self.special_tokens = special_tokens
import sentencepiece import sentencepiece
if torch.is_tensor(tokenizer_path): if torch.is_tensor(tokenizer_path):
tokenizer_path = tokenizer_path.numpy().tobytes() tokenizer_path = tokenizer_path.numpy().tobytes()
@ -27,8 +28,32 @@ class SPieceTokenizer:
return out return out
def __call__(self, string): def __call__(self, string):
if self.special_tokens is not None:
import re
special_tokens_pattern = '|'.join(re.escape(token) for token in self.special_tokens.keys())
if special_tokens_pattern and re.search(special_tokens_pattern, string):
parts = re.split(f'({special_tokens_pattern})', string)
result = []
for part in parts:
if not part:
continue
if part in self.special_tokens:
result.append(self.special_tokens[part])
else:
encoded = self.tokenizer.encode(part, add_bos=False, add_eos=False)
result.extend(encoded)
return {"input_ids": result}
out = self.tokenizer.encode(string) out = self.tokenizer.encode(string)
return {"input_ids": out} return {"input_ids": out}
def decode(self, token_ids, skip_special_tokens=False):
if skip_special_tokens and self.special_tokens:
special_token_ids = set(self.special_tokens.values())
token_ids = [tid for tid in token_ids if tid not in special_token_ids]
return self.tokenizer.decode(token_ids)
def serialize_model(self): def serialize_model(self):
return torch.ByteTensor(list(self.tokenizer.serialized_model_proto())) return torch.ByteTensor(list(self.tokenizer.serialized_model_proto()))

View File

@ -20,7 +20,7 @@
import torch import torch
import math import math
import struct import struct
import comfy.checkpoint_pickle import comfy.memory_management
import safetensors.torch import safetensors.torch
import numpy as np import numpy as np
from PIL import Image from PIL import Image
@ -38,26 +38,26 @@ import warnings
MMAP_TORCH_FILES = args.mmap_torch_files MMAP_TORCH_FILES = args.mmap_torch_files
DISABLE_MMAP = args.disable_mmap DISABLE_MMAP = args.disable_mmap
ALWAYS_SAFE_LOAD = False
if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated if True: # ckpt/pt file whitelist for safe loading of old sd files
class ModelCheckpoint: class ModelCheckpoint:
pass pass
ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint" ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint"
def scalar(*args, **kwargs): def scalar(*args, **kwargs):
from numpy.core.multiarray import scalar as sc return None
return sc(*args, **kwargs)
scalar.__module__ = "numpy.core.multiarray" scalar.__module__ = "numpy.core.multiarray"
from numpy import dtype from numpy import dtype
from numpy.dtypes import Float64DType from numpy.dtypes import Float64DType
from _codecs import encode
def encode(*args, **kwargs): # no longer necessary on newer torch
return None
encode.__module__ = "_codecs"
torch.serialization.add_safe_globals([ModelCheckpoint, scalar, dtype, Float64DType, encode]) torch.serialization.add_safe_globals([ModelCheckpoint, scalar, dtype, Float64DType, encode])
ALWAYS_SAFE_LOAD = True
logging.info("Checkpoint files will always be loaded safely.") logging.info("Checkpoint files will always be loaded safely.")
else:
logging.warning("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended as older versions of pytorch are no longer supported.")
# Current as of safetensors 0.7.0 # Current as of safetensors 0.7.0
_TYPES = { _TYPES = {
@ -140,11 +140,8 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
if MMAP_TORCH_FILES: if MMAP_TORCH_FILES:
torch_args["mmap"] = True torch_args["mmap"] = True
if safe_load or ALWAYS_SAFE_LOAD: pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
else:
logging.warning("WARNING: loading {} unsafely, upgrade your pytorch to 2.4 or newer to load this file safely.".format(ckpt))
pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
if "state_dict" in pl_sd: if "state_dict" in pl_sd:
sd = pl_sd["state_dict"] sd = pl_sd["state_dict"]
else: else:
@ -675,10 +672,10 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
"ff_context.linear_in.bias": "txt_mlp.0.bias", "ff_context.linear_in.bias": "txt_mlp.0.bias",
"ff_context.linear_out.weight": "txt_mlp.2.weight", "ff_context.linear_out.weight": "txt_mlp.2.weight",
"ff_context.linear_out.bias": "txt_mlp.2.bias", "ff_context.linear_out.bias": "txt_mlp.2.bias",
"attn.norm_q.weight": "img_attn.norm.query_norm.scale", "attn.norm_q.weight": "img_attn.norm.query_norm.weight",
"attn.norm_k.weight": "img_attn.norm.key_norm.scale", "attn.norm_k.weight": "img_attn.norm.key_norm.weight",
"attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale", "attn.norm_added_q.weight": "txt_attn.norm.query_norm.weight",
"attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale", "attn.norm_added_k.weight": "txt_attn.norm.key_norm.weight",
} }
for k in block_map: for k in block_map:
@ -701,8 +698,8 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
"norm.linear.bias": "modulation.lin.bias", "norm.linear.bias": "modulation.lin.bias",
"proj_out.weight": "linear2.weight", "proj_out.weight": "linear2.weight",
"proj_out.bias": "linear2.bias", "proj_out.bias": "linear2.bias",
"attn.norm_q.weight": "norm.query_norm.scale", "attn.norm_q.weight": "norm.query_norm.weight",
"attn.norm_k.weight": "norm.key_norm.scale", "attn.norm_k.weight": "norm.key_norm.weight",
"attn.to_qkv_mlp_proj.weight": "linear1.weight", # Flux 2 "attn.to_qkv_mlp_proj.weight": "linear1.weight", # Flux 2
"attn.to_out.weight": "linear2.weight", # Flux 2 "attn.to_out.weight": "linear2.weight", # Flux 2
} }
@ -1157,7 +1154,7 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am
return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar) return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar)
def model_trange(*args, **kwargs): def model_trange(*args, **kwargs):
if comfy.memory_management.aimdo_allocator is None: if not comfy.memory_management.aimdo_enabled:
return trange(*args, **kwargs) return trange(*args, **kwargs)
pbar = trange(*args, **kwargs, smoothing=1.0) pbar = trange(*args, **kwargs, smoothing=1.0)
@ -1421,3 +1418,11 @@ def deepcopy_list_dict(obj, memo=None):
memo[obj_id] = res memo[obj_id] = res
return res return res
def normalize_image_embeddings(embeds, embeds_info, scale_factor):
"""Normalize image embeddings to match text embedding scale"""
for info in embeds_info:
if info.get("type") == "image":
start_idx = info["index"]
end_idx = start_idx + info["size"]
embeds[:, start_idx:end_idx, :] /= scale_factor

View File

@ -49,6 +49,12 @@ class WeightAdapterBase:
""" """
raise NotImplementedError raise NotImplementedError
def calculate_shape(
self,
key
):
return None
def calculate_weight( def calculate_weight(
self, self,
weight, weight,

View File

@ -214,6 +214,13 @@ class LoRAAdapter(WeightAdapterBase):
else: else:
return None return None
def calculate_shape(
self,
key
):
reshape = self.weights[5]
return tuple(reshape) if reshape is not None else None
def calculate_weight( def calculate_weight(
self, self,
weight, weight,

View File

@ -14,6 +14,7 @@ SERVER_FEATURE_FLAGS: dict[str, Any] = {
"supports_preview_metadata": True, "supports_preview_metadata": True,
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes "max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
"extension": {"manager": {"supports_v4": True}}, "extension": {"manager": {"supports_v4": True}},
"node_replacements": True,
} }

View File

@ -21,6 +21,17 @@ class ComfyAPI_latest(ComfyAPIBase):
VERSION = "latest" VERSION = "latest"
STABLE = False STABLE = False
def __init__(self):
super().__init__()
self.node_replacement = self.NodeReplacement()
self.execution = self.Execution()
class NodeReplacement(ProxiedSingleton):
async def register(self, node_replace: io.NodeReplace) -> None:
"""Register a node replacement mapping."""
from server import PromptServer
PromptServer.instance.node_replace_manager.register(node_replace)
class Execution(ProxiedSingleton): class Execution(ProxiedSingleton):
async def set_progress( async def set_progress(
self, self,
@ -73,8 +84,6 @@ class ComfyAPI_latest(ComfyAPIBase):
image=to_display, image=to_display,
) )
execution: Execution
class ComfyExtension(ABC): class ComfyExtension(ABC):
async def on_load(self) -> None: async def on_load(self) -> None:
""" """

View File

@ -444,7 +444,7 @@ class VideoFromComponents(VideoInput):
output.mux(packet) output.mux(packet)
if audio_stream and self.__components.audio: if audio_stream and self.__components.audio:
frame = av.AudioFrame.from_ndarray(waveform.float().cpu().numpy(), format='fltp', layout=layout) frame = av.AudioFrame.from_ndarray(waveform.float().cpu().contiguous().numpy(), format='fltp', layout=layout)
frame.sample_rate = audio_sample_rate frame.sample_rate = audio_sample_rate
frame.pts = 0 frame.pts = 0
output.mux(audio_stream.encode(frame)) output.mux(audio_stream.encode(frame))

View File

@ -73,8 +73,15 @@ class RemoteOptions:
class NumberDisplay(str, Enum): class NumberDisplay(str, Enum):
number = "number" number = "number"
slider = "slider" slider = "slider"
gradient_slider = "gradientslider"
class ControlAfterGenerate(str, Enum):
fixed = "fixed"
increment = "increment"
decrement = "decrement"
randomize = "randomize"
class _ComfyType(ABC): class _ComfyType(ABC):
Type = Any Type = Any
io_type: str = None io_type: str = None
@ -263,7 +270,7 @@ class Int(ComfyTypeIO):
class Input(WidgetInput): class Input(WidgetInput):
'''Integer input.''' '''Integer input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None, default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool | ControlAfterGenerate=None,
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced) super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced)
self.min = min self.min = min
@ -290,13 +297,15 @@ class Float(ComfyTypeIO):
'''Float input.''' '''Float input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: float=None, min: float=None, max: float=None, step: float=None, round: float=None, default: float=None, min: float=None, max: float=None, step: float=None, round: float=None,
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): display_mode: NumberDisplay=None, gradient_stops: list[list[float]]=None,
socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced) super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced)
self.min = min self.min = min
self.max = max self.max = max
self.step = step self.step = step
self.round = round self.round = round
self.display_mode = display_mode self.display_mode = display_mode
self.gradient_stops = gradient_stops
self.default: float self.default: float
def as_dict(self): def as_dict(self):
@ -306,6 +315,7 @@ class Float(ComfyTypeIO):
"step": self.step, "step": self.step,
"round": self.round, "round": self.round,
"display": self.display_mode, "display": self.display_mode,
"gradient_stops": self.gradient_stops,
}) })
@comfytype(io_type="STRING") @comfytype(io_type="STRING")
@ -345,7 +355,7 @@ class Combo(ComfyTypeIO):
tooltip: str=None, tooltip: str=None,
lazy: bool=None, lazy: bool=None,
default: str | int | Enum = None, default: str | int | Enum = None,
control_after_generate: bool=None, control_after_generate: bool | ControlAfterGenerate=None,
upload: UploadType=None, upload: UploadType=None,
image_folder: FolderType=None, image_folder: FolderType=None,
remote: RemoteOptions=None, remote: RemoteOptions=None,
@ -389,7 +399,7 @@ class MultiCombo(ComfyTypeI):
Type = list[str] Type = list[str]
class Input(Combo.Input): class Input(Combo.Input):
def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None, default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool | ControlAfterGenerate=None,
socketless: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): socketless: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless, extra_dict=extra_dict, raw_link=raw_link, advanced=advanced) super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless, extra_dict=extra_dict, raw_link=raw_link, advanced=advanced)
self.multiselect = True self.multiselect = True
@ -1203,6 +1213,30 @@ class Color(ComfyTypeIO):
def as_dict(self): def as_dict(self):
return super().as_dict() return super().as_dict()
@comfytype(io_type="BOUNDING_BOX")
class BoundingBox(ComfyTypeIO):
class BoundingBoxDict(TypedDict):
x: int
y: int
width: int
height: int
Type = BoundingBoxDict
class Input(WidgetInput):
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
socketless: bool=True, default: dict=None, component: str=None):
super().__init__(id, display_name, optional, tooltip, None, default, socketless)
self.component = component
if default is None:
self.default = {"x": 0, "y": 0, "width": 512, "height": 512}
def as_dict(self):
d = super().as_dict()
if self.component:
d["component"] = self.component
return d
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {} DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]): def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
DYNAMIC_INPUT_LOOKUP[io_type] = func DYNAMIC_INPUT_LOOKUP[io_type] = func
@ -1309,6 +1343,7 @@ class NodeInfoV1:
api_node: bool=None api_node: bool=None
price_badge: dict | None = None price_badge: dict | None = None
search_aliases: list[str]=None search_aliases: list[str]=None
essentials_category: str=None
@dataclass @dataclass
@ -1430,6 +1465,8 @@ class Schema:
"""Flags a node as expandable, allowing NodeOutput to include 'expand' property.""" """Flags a node as expandable, allowing NodeOutput to include 'expand' property."""
accept_all_inputs: bool=False accept_all_inputs: bool=False
"""When True, all inputs from the prompt will be passed to the node as kwargs, even if not defined in the schema.""" """When True, all inputs from the prompt will be passed to the node as kwargs, even if not defined in the schema."""
essentials_category: str | None = None
"""Optional category for the Essentials tab. Path-based like category field (e.g., 'Basic', 'Image Tools/Editing')."""
def validate(self): def validate(self):
'''Validate the schema: '''Validate the schema:
@ -1536,6 +1573,7 @@ class Schema:
python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes"), python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes"),
price_badge=self.price_badge.as_dict(self.inputs) if self.price_badge is not None else None, price_badge=self.price_badge.as_dict(self.inputs) if self.price_badge is not None else None,
search_aliases=self.search_aliases if self.search_aliases else None, search_aliases=self.search_aliases if self.search_aliases else None,
essentials_category=self.essentials_category,
) )
return info return info
@ -2030,11 +2068,74 @@ class _UIOutput(ABC):
... ...
class InputMapOldId(TypedDict):
"""Map an old node input to a new node input by ID."""
new_id: str
old_id: str
class InputMapSetValue(TypedDict):
"""Set a specific value for a new node input."""
new_id: str
set_value: Any
InputMap = InputMapOldId | InputMapSetValue
"""
Input mapping for node replacement. Type is inferred by dictionary keys:
- {"new_id": str, "old_id": str} - maps old input to new input
- {"new_id": str, "set_value": Any} - sets a specific value for new input
"""
class OutputMap(TypedDict):
"""Map outputs of node replacement via indexes."""
new_idx: int
old_idx: int
class NodeReplace:
"""
Defines a possible node replacement, mapping inputs and outputs of the old node to the new node.
Also supports assigning specific values to the input widgets of the new node.
Args:
new_node_id: The class name of the new replacement node.
old_node_id: The class name of the deprecated node.
old_widget_ids: Ordered list of input IDs for widgets that may not have an input slot
connected. The workflow JSON stores widget values by their relative position index,
not by ID. This list maps those positional indexes to input IDs, enabling the
replacement system to correctly identify widget values during node migration.
input_mapping: List of input mappings from old node to new node.
output_mapping: List of output mappings from old node to new node.
"""
def __init__(self,
new_node_id: str,
old_node_id: str,
old_widget_ids: list[str] | None=None,
input_mapping: list[InputMap] | None=None,
output_mapping: list[OutputMap] | None=None,
):
self.new_node_id = new_node_id
self.old_node_id = old_node_id
self.old_widget_ids = old_widget_ids
self.input_mapping = input_mapping
self.output_mapping = output_mapping
def as_dict(self):
"""Create serializable representation of the node replacement."""
return {
"new_node_id": self.new_node_id,
"old_node_id": self.old_node_id,
"old_widget_ids": self.old_widget_ids,
"input_mapping": list(self.input_mapping) if self.input_mapping else None,
"output_mapping": list(self.output_mapping) if self.output_mapping else None,
}
__all__ = [ __all__ = [
"FolderType", "FolderType",
"UploadType", "UploadType",
"RemoteOptions", "RemoteOptions",
"NumberDisplay", "NumberDisplay",
"ControlAfterGenerate",
"comfytype", "comfytype",
"Custom", "Custom",
@ -2121,4 +2222,6 @@ __all__ = [
"ImageCompare", "ImageCompare",
"PriceBadgeDepends", "PriceBadgeDepends",
"PriceBadge", "PriceBadge",
"BoundingBox",
"NodeReplace",
] ]

View File

@ -45,17 +45,55 @@ class BriaEditImageRequest(BaseModel):
) )
class BriaRemoveBackgroundRequest(BaseModel):
image: str = Field(...)
sync: bool = Field(False)
visual_input_content_moderation: bool = Field(
False, description="If true, returns 422 on input image moderation failure."
)
visual_output_content_moderation: bool = Field(
False, description="If true, returns 422 on visual output moderation failure."
)
seed: int = Field(...)
class BriaStatusResponse(BaseModel): class BriaStatusResponse(BaseModel):
request_id: str = Field(...) request_id: str = Field(...)
status_url: str = Field(...) status_url: str = Field(...)
warning: str | None = Field(None) warning: str | None = Field(None)
class BriaResult(BaseModel): class BriaRemoveBackgroundResult(BaseModel):
image_url: str = Field(...)
class BriaRemoveBackgroundResponse(BaseModel):
status: str = Field(...)
result: BriaRemoveBackgroundResult | None = Field(None)
class BriaImageEditResult(BaseModel):
structured_prompt: str = Field(...) structured_prompt: str = Field(...)
image_url: str = Field(...) image_url: str = Field(...)
class BriaResponse(BaseModel): class BriaImageEditResponse(BaseModel):
status: str = Field(...) status: str = Field(...)
result: BriaResult | None = Field(None) result: BriaImageEditResult | None = Field(None)
class BriaRemoveVideoBackgroundRequest(BaseModel):
video: str = Field(...)
background_color: str = Field(default="transparent", description="Background color for the output video.")
output_container_and_codec: str = Field(...)
preserve_audio: bool = Field(True)
seed: int = Field(...)
class BriaRemoveVideoBackgroundResult(BaseModel):
video_url: str = Field(...)
class BriaRemoveVideoBackgroundResponse(BaseModel):
status: str = Field(...)
result: BriaRemoveVideoBackgroundResult | None = Field(None)

View File

@ -27,6 +27,7 @@ class Seedream4TaskCreationRequest(BaseModel):
sequential_image_generation: str = Field("disabled") sequential_image_generation: str = Field("disabled")
sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15)) sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15))
watermark: bool = Field(False) watermark: bool = Field(False)
output_format: str | None = None
class ImageTaskCreationResponse(BaseModel): class ImageTaskCreationResponse(BaseModel):
@ -106,6 +107,7 @@ RECOMMENDED_PRESETS_SEEDREAM_4 = [
("2496x1664 (3:2)", 2496, 1664), ("2496x1664 (3:2)", 2496, 1664),
("1664x2496 (2:3)", 1664, 2496), ("1664x2496 (2:3)", 1664, 2496),
("3024x1296 (21:9)", 3024, 1296), ("3024x1296 (21:9)", 3024, 1296),
("3072x3072 (1:1)", 3072, 3072),
("4096x4096 (1:1)", 4096, 4096), ("4096x4096 (1:1)", 4096, 4096),
("Custom", None, None), ("Custom", None, None),
] ]

View File

@ -0,0 +1,88 @@
from pydantic import BaseModel, Field
class SpeechToTextRequest(BaseModel):
model_id: str = Field(...)
cloud_storage_url: str = Field(...)
language_code: str | None = Field(None, description="ISO-639-1 or ISO-639-3 language code")
tag_audio_events: bool | None = Field(None, description="Annotate sounds like (laughter) in transcript")
num_speakers: int | None = Field(None, description="Max speakers predicted")
timestamps_granularity: str = Field(default="word", description="Timing precision: none, word, or character")
diarize: bool | None = Field(None, description="Annotate which speaker is talking")
diarization_threshold: float | None = Field(None, description="Speaker separation sensitivity")
temperature: float | None = Field(None, description="Randomness control")
seed: int = Field(..., description="Seed for deterministic sampling")
class SpeechToTextWord(BaseModel):
text: str = Field(..., description="The word text")
type: str = Field(default="word", description="Type of text element (word, spacing, etc.)")
start: float | None = Field(None, description="Start time in seconds (when timestamps enabled)")
end: float | None = Field(None, description="End time in seconds (when timestamps enabled)")
speaker_id: str | None = Field(None, description="Speaker identifier when diarization is enabled")
logprob: float | None = Field(None, description="Log probability of the word")
class SpeechToTextResponse(BaseModel):
language_code: str = Field(..., description="Detected or specified language code")
language_probability: float | None = Field(None, description="Confidence of language detection")
text: str = Field(..., description="Full transcript text")
words: list[SpeechToTextWord] | None = Field(None, description="Word-level timing information")
class TextToSpeechVoiceSettings(BaseModel):
stability: float | None = Field(None, description="Voice stability")
similarity_boost: float | None = Field(None, description="Similarity boost")
style: float | None = Field(None, description="Style exaggeration")
use_speaker_boost: bool | None = Field(None, description="Boost similarity to original speaker")
speed: float | None = Field(None, description="Speech speed")
class TextToSpeechRequest(BaseModel):
text: str = Field(..., description="Text to convert to speech")
model_id: str = Field(..., description="Model ID for TTS")
language_code: str | None = Field(None, description="ISO-639-1 or ISO-639-3 language code")
voice_settings: TextToSpeechVoiceSettings | None = Field(None, description="Voice settings")
seed: int = Field(..., description="Seed for deterministic sampling")
apply_text_normalization: str | None = Field(None, description="Text normalization mode: auto, on, off")
class TextToSoundEffectsRequest(BaseModel):
text: str = Field(..., description="Text prompt to convert into a sound effect")
duration_seconds: float = Field(..., description="Duration of generated sound in seconds")
prompt_influence: float = Field(..., description="How closely generation follows the prompt")
loop: bool | None = Field(None, description="Whether to create a smoothly looping sound effect")
class AddVoiceRequest(BaseModel):
name: str = Field(..., description="Name that identifies the voice")
remove_background_noise: bool = Field(..., description="Remove background noise from voice samples")
class AddVoiceResponse(BaseModel):
voice_id: str = Field(..., description="The newly created voice's unique identifier")
class SpeechToSpeechRequest(BaseModel):
model_id: str = Field(..., description="Model ID for speech-to-speech")
voice_settings: str = Field(..., description="JSON string of voice settings")
seed: int = Field(..., description="Seed for deterministic sampling")
remove_background_noise: bool = Field(..., description="Remove background noise from input audio")
class DialogueInput(BaseModel):
text: str = Field(..., description="Text content to convert to speech")
voice_id: str = Field(..., description="Voice identifier for this dialogue segment")
class DialogueSettings(BaseModel):
stability: float | None = Field(None, description="Voice stability (0-1)")
class TextToDialogueRequest(BaseModel):
inputs: list[DialogueInput] = Field(..., description="List of dialogue segments")
model_id: str = Field(..., description="Model ID for dialogue generation")
language_code: str | None = Field(None, description="ISO-639-1 language code")
settings: DialogueSettings | None = Field(None, description="Voice settings")
seed: int | None = Field(None, description="Seed for deterministic sampling")
apply_text_normalization: str | None = Field(None, description="Text normalization mode: auto, on, off")

View File

@ -116,9 +116,15 @@ class GeminiGenerationConfig(BaseModel):
topP: float | None = Field(None, ge=0.0, le=1.0) topP: float | None = Field(None, ge=0.0, le=1.0)
class GeminiImageOutputOptions(BaseModel):
mimeType: str = Field("image/png")
compressionQuality: int | None = Field(None)
class GeminiImageConfig(BaseModel): class GeminiImageConfig(BaseModel):
aspectRatio: str | None = Field(None) aspectRatio: str | None = Field(None)
imageSize: str | None = Field(None) imageSize: str | None = Field(None)
imageOutputOptions: GeminiImageOutputOptions = Field(default_factory=GeminiImageOutputOptions)
class GeminiImageGenerationConfig(GeminiGenerationConfig): class GeminiImageGenerationConfig(GeminiGenerationConfig):

View File

@ -64,3 +64,23 @@ class To3DProTaskResultResponse(BaseModel):
class To3DProTaskQueryRequest(BaseModel): class To3DProTaskQueryRequest(BaseModel):
JobId: str = Field(...) JobId: str = Field(...)
class To3DUVFileInput(BaseModel):
Type: str = Field(..., description="File type: GLB, OBJ, or FBX")
Url: str = Field(...)
class To3DUVTaskRequest(BaseModel):
File: To3DUVFileInput = Field(...)
class TextureEditImageInfo(BaseModel):
Url: str = Field(...)
class TextureEditTaskRequest(BaseModel):
File3D: To3DUVFileInput = Field(...)
Image: TextureEditImageInfo | None = Field(None)
Prompt: str | None = Field(None)
EnablePBR: bool | None = Field(None)

View File

@ -134,6 +134,13 @@ class ImageToVideoWithAudioRequest(BaseModel):
shot_type: str | None = Field(None) shot_type: str | None = Field(None)
class KlingAvatarRequest(BaseModel):
image: str = Field(...)
sound_file: str = Field(...)
prompt: str | None = Field(None)
mode: str = Field(...)
class MotionControlRequest(BaseModel): class MotionControlRequest(BaseModel):
prompt: str = Field(...) prompt: str = Field(...)
image_url: str = Field(...) image_url: str = Field(...)

Some files were not shown because too many files have changed in this diff Show More