mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-27 19:02:31 +08:00
Merge branch 'master' into FP8-Fast-Fix
This commit is contained in:
commit
56e9de4f9f
127
.coderabbit.yaml
Normal file
127
.coderabbit.yaml
Normal file
@ -0,0 +1,127 @@
|
|||||||
|
# yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json
|
||||||
|
language: "en-US"
|
||||||
|
early_access: false
|
||||||
|
tone_instructions: "Only comment on issues introduced by this PR's changes. Do not flag pre-existing problems in moved, re-indented, or reformatted code."
|
||||||
|
|
||||||
|
reviews:
|
||||||
|
profile: "chill"
|
||||||
|
request_changes_workflow: false
|
||||||
|
high_level_summary: false
|
||||||
|
poem: false
|
||||||
|
review_status: false
|
||||||
|
review_details: false
|
||||||
|
commit_status: true
|
||||||
|
collapse_walkthrough: true
|
||||||
|
changed_files_summary: false
|
||||||
|
sequence_diagrams: false
|
||||||
|
estimate_code_review_effort: false
|
||||||
|
assess_linked_issues: false
|
||||||
|
related_issues: false
|
||||||
|
related_prs: false
|
||||||
|
suggested_labels: false
|
||||||
|
auto_apply_labels: false
|
||||||
|
suggested_reviewers: false
|
||||||
|
auto_assign_reviewers: false
|
||||||
|
in_progress_fortune: false
|
||||||
|
enable_prompt_for_ai_agents: true
|
||||||
|
|
||||||
|
path_filters:
|
||||||
|
- "!comfy_api_nodes/apis/**"
|
||||||
|
- "!**/generated/*.pyi"
|
||||||
|
- "!.ci/**"
|
||||||
|
- "!script_examples/**"
|
||||||
|
- "!**/__pycache__/**"
|
||||||
|
- "!**/*.ipynb"
|
||||||
|
- "!**/*.png"
|
||||||
|
- "!**/*.bat"
|
||||||
|
|
||||||
|
path_instructions:
|
||||||
|
- path: "**"
|
||||||
|
instructions: |
|
||||||
|
IMPORTANT: Only comment on issues directly introduced by this PR's code changes.
|
||||||
|
Do NOT flag pre-existing issues in code that was merely moved, re-indented,
|
||||||
|
de-indented, or reformatted without logic changes. If code appears in the diff
|
||||||
|
only due to whitespace or structural reformatting (e.g., removing a `with:` block),
|
||||||
|
treat it as unchanged. Contributors should not feel obligated to address
|
||||||
|
pre-existing issues outside the scope of their contribution.
|
||||||
|
- path: "comfy/**"
|
||||||
|
instructions: |
|
||||||
|
Core ML/diffusion engine. Focus on:
|
||||||
|
- Backward compatibility (breaking changes affect all custom nodes)
|
||||||
|
- Memory management and GPU resource handling
|
||||||
|
- Performance implications in hot paths
|
||||||
|
- Thread safety for concurrent execution
|
||||||
|
- path: "comfy_api_nodes/**"
|
||||||
|
instructions: |
|
||||||
|
Third-party API integration nodes. Focus on:
|
||||||
|
- No hardcoded API keys or secrets
|
||||||
|
- Proper error handling for API failures (timeouts, rate limits, auth errors)
|
||||||
|
- Correct Pydantic model usage
|
||||||
|
- Security of user data passed to external APIs
|
||||||
|
- path: "comfy_extras/**"
|
||||||
|
instructions: |
|
||||||
|
Community-contributed extra nodes. Focus on:
|
||||||
|
- Consistency with node patterns (INPUT_TYPES, RETURN_TYPES, FUNCTION, CATEGORY)
|
||||||
|
- No breaking changes to existing node interfaces
|
||||||
|
- path: "comfy_execution/**"
|
||||||
|
instructions: |
|
||||||
|
Execution engine (graph execution, caching, jobs). Focus on:
|
||||||
|
- Caching correctness
|
||||||
|
- Concurrent execution safety
|
||||||
|
- Graph validation edge cases
|
||||||
|
- path: "nodes.py"
|
||||||
|
instructions: |
|
||||||
|
Core node definitions (2500+ lines). Focus on:
|
||||||
|
- Backward compatibility of NODE_CLASS_MAPPINGS
|
||||||
|
- Consistency of INPUT_TYPES return format
|
||||||
|
- path: "alembic_db/**"
|
||||||
|
instructions: |
|
||||||
|
Database migrations. Focus on:
|
||||||
|
- Migration safety and rollback support
|
||||||
|
- Data preservation during schema changes
|
||||||
|
|
||||||
|
auto_review:
|
||||||
|
enabled: true
|
||||||
|
auto_incremental_review: true
|
||||||
|
drafts: false
|
||||||
|
ignore_title_keywords:
|
||||||
|
- "WIP"
|
||||||
|
- "DO NOT REVIEW"
|
||||||
|
- "DO NOT MERGE"
|
||||||
|
|
||||||
|
finishing_touches:
|
||||||
|
docstrings:
|
||||||
|
enabled: false
|
||||||
|
unit_tests:
|
||||||
|
enabled: false
|
||||||
|
|
||||||
|
tools:
|
||||||
|
ruff:
|
||||||
|
enabled: false
|
||||||
|
pylint:
|
||||||
|
enabled: false
|
||||||
|
flake8:
|
||||||
|
enabled: false
|
||||||
|
gitleaks:
|
||||||
|
enabled: true
|
||||||
|
shellcheck:
|
||||||
|
enabled: false
|
||||||
|
markdownlint:
|
||||||
|
enabled: false
|
||||||
|
yamllint:
|
||||||
|
enabled: false
|
||||||
|
languagetool:
|
||||||
|
enabled: false
|
||||||
|
github-checks:
|
||||||
|
enabled: true
|
||||||
|
timeout_ms: 90000
|
||||||
|
ast-grep:
|
||||||
|
essential_rules: true
|
||||||
|
|
||||||
|
chat:
|
||||||
|
auto_reply: true
|
||||||
|
|
||||||
|
knowledge_base:
|
||||||
|
opt_out: false
|
||||||
|
learnings:
|
||||||
|
scope: "auto"
|
||||||
2
.gitignore
vendored
2
.gitignore
vendored
@ -11,7 +11,7 @@ extra_model_paths.yaml
|
|||||||
/.vs
|
/.vs
|
||||||
.vscode/
|
.vscode/
|
||||||
.idea/
|
.idea/
|
||||||
venv/
|
venv*/
|
||||||
.venv/
|
.venv/
|
||||||
/web/extensions/*
|
/web/extensions/*
|
||||||
!/web/extensions/logging.js.example
|
!/web/extensions/logging.js.example
|
||||||
|
|||||||
@ -227,11 +227,11 @@ Put your VAE in: models/vae
|
|||||||
|
|
||||||
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
|
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
|
||||||
|
|
||||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4```
|
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm7.1```
|
||||||
|
|
||||||
This is the command to install the nightly with ROCm 7.1 which might have some performance improvements:
|
This is the command to install the nightly with ROCm 7.2 which might have some performance improvements:
|
||||||
|
|
||||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.1```
|
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.2```
|
||||||
|
|
||||||
|
|
||||||
### AMD GPUs (Experimental: Windows and Linux), RDNA 3, 3.5 and 4 only.
|
### AMD GPUs (Experimental: Windows and Linux), RDNA 3, 3.5 and 4 only.
|
||||||
|
|||||||
107
app/node_replace_manager.py
Normal file
107
app/node_replace_manager.py
Normal file
@ -0,0 +1,107 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, TypedDict
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from comfy_api.latest._io_public import NodeReplace
|
||||||
|
|
||||||
|
from comfy_execution.graph_utils import is_link
|
||||||
|
import nodes
|
||||||
|
|
||||||
|
class NodeStruct(TypedDict):
    """Shape of a single node entry in an API-format prompt dict."""

    # Input values: either literals (str/int/float/bool) or links, where a
    # link is (source node id, source output index).
    # NOTE(review): the annotation says tuple, but links decoded from JSON
    # are typically 2-element lists — confirm against callers.
    inputs: dict[str, str | int | float | bool | tuple[str, int]]
    # Node class name; key into nodes.NODE_CLASS_MAPPINGS.
    class_type: str
    # Auxiliary metadata (e.g. a display name/title) used for error reporting.
    _meta: dict[str, str]
|
||||||
|
|
||||||
|
def copy_node_struct(node_struct: NodeStruct, empty_inputs: bool = False) -> NodeStruct:
    """Return a copy of *node_struct* with its own ``inputs`` and ``_meta`` dicts.

    The copy is shallow one level deep: the top-level mapping and the
    ``inputs``/``_meta`` sub-dicts are new objects, but the values they hold
    are shared with the original. When ``empty_inputs`` is true the copy
    starts with an empty ``inputs`` dict instead of a copy of the original's.
    """
    duplicate = node_struct.copy()
    duplicate["inputs"] = {} if empty_inputs else dict(node_struct["inputs"])
    duplicate["_meta"] = dict(node_struct["_meta"])
    return duplicate
|
||||||
|
|
||||||
|
|
||||||
|
class NodeReplaceManager:
    """Manages node replacement registrations.

    Maps deprecated node class names (``old_node_id``) to one or more
    registered ``NodeReplace`` entries, and can rewrite an API-format prompt
    in place so that unknown node types are swapped for their replacements.
    """

    def __init__(self):
        # old node id -> replacements, in registration order.
        self._replacements: dict[str, list[NodeReplace]] = {}

    def register(self, node_replace: NodeReplace):
        """Register a node replacement mapping."""
        self._replacements.setdefault(node_replace.old_node_id, []).append(node_replace)

    def get_replacement(self, old_node_id: str) -> list[NodeReplace] | None:
        """Get replacements for an old node ID, or None when nothing is registered."""
        return self._replacements.get(old_node_id)

    def has_replacement(self, old_node_id: str) -> bool:
        """Check if a replacement exists for an old node ID."""
        return old_node_id in self._replacements

    def apply_replacements(self, prompt: dict[str, NodeStruct]):
        """Mutate *prompt* in place, replacing nodes whose ``class_type`` is
        unknown (absent from ``nodes.NODE_CLASS_MAPPINGS``) but has a
        registered replacement.

        Inputs are re-mapped per the replacement's ``input_mapping``; when an
        ``output_mapping`` is present, the link output indices of downstream
        consumer nodes are rewritten as well.
        """
        # source node id -> [(consumer node id, consumer input name, output index), ...]
        connections: dict[str, list[tuple[str, str, int]]] = {}
        need_replacement: set[str] = set()
        for node_number, node_struct in prompt.items():
            # Skip malformed entries rather than raising.
            if "class_type" not in node_struct or "inputs" not in node_struct:
                continue
            class_type = node_struct["class_type"]
            # need replacement if not in NODE_CLASS_MAPPINGS and has replacement
            if class_type not in nodes.NODE_CLASS_MAPPINGS and self.has_replacement(class_type):
                need_replacement.add(node_number)
            # keep track of connections
            for input_id, input_value in node_struct["inputs"].items():
                if is_link(input_value):
                    conn_number = input_value[0]
                    connections.setdefault(conn_number, []).append((node_number, input_id, input_value[1]))
        for node_number in need_replacement:
            node_struct = prompt[node_number]
            class_type = node_struct["class_type"]
            replacements = self.get_replacement(class_type)
            if replacements is None:
                continue
            # just use the first replacement
            replacement = replacements[0]
            new_node_id = replacement.new_node_id
            # if replacement is not a valid node, skip trying to replace it as will only cause confusion
            if new_node_id not in nodes.NODE_CLASS_MAPPINGS:
                continue
            # first, replace node id (class_type)
            new_node_struct = copy_node_struct(node_struct, empty_inputs=True)
            new_node_struct["class_type"] = new_node_id
            # TODO: consider replacing display_name in _meta as well for error reporting purposes; would need to query node schema
            # second, replace inputs
            if replacement.input_mapping is not None:
                for input_map in replacement.input_mapping:
                    if "set_value" in input_map:
                        new_node_struct["inputs"][input_map["new_id"]] = input_map["set_value"]
                    elif "old_id" in input_map:
                        # NOTE(review): raises KeyError if the old prompt lacks
                        # this input — presumably mappings only name required
                        # inputs; confirm.
                        new_node_struct["inputs"][input_map["new_id"]] = node_struct["inputs"][input_map["old_id"]]
            # finalize input replacement
            prompt[node_number] = new_node_struct
            # third, replace outputs
            if replacement.output_mapping is not None:
                # re-mapping outputs requires changing the input values of nodes that receive connections from this one
                if node_number in connections:
                    for conns in connections[node_number]:
                        conn_node_number, conn_input_id, old_output_idx = conns
                        for output_map in replacement.output_mapping:
                            if output_map["old_idx"] == old_output_idx:
                                new_output_idx = output_map["new_idx"]
                                # NOTE(review): assumes links are mutable lists
                                # (JSON-decoded); the NodeStruct annotation says
                                # tuple[str, int], which would raise here — confirm.
                                previous_input = prompt[conn_node_number]["inputs"][conn_input_id]
                                previous_input[1] = new_output_idx

    def as_dict(self):
        """Serialize all replacements to a dict: old node id -> list of dicts."""
        return {
            k: [v.as_dict() for v in v_list]
            for k, v_list in self._replacements.items()
        }

    def add_routes(self, routes):
        """Attach the ``GET /node_replacements`` route exposing as_dict()."""
        @routes.get("/node_replacements")
        async def get_node_replacements(request):
            return web.json_response(self.as_dict())
|
||||||
@ -53,7 +53,7 @@ class SubgraphManager:
|
|||||||
return entry_id, entry
|
return entry_id, entry
|
||||||
|
|
||||||
async def load_entry_data(self, entry: SubgraphEntry):
|
async def load_entry_data(self, entry: SubgraphEntry):
|
||||||
with open(entry['path'], 'r') as f:
|
with open(entry['path'], 'r', encoding='utf-8') as f:
|
||||||
entry['data'] = f.read()
|
entry['data'] = f.read()
|
||||||
return entry
|
return entry
|
||||||
|
|
||||||
|
|||||||
44
blueprints/.glsl/Brightness_and_Contrast_1.frag
Normal file
44
blueprints/.glsl/Brightness_and_Contrast_1.frag
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
#version 300 es
precision highp float;

// Brightness/Contrast adjustment. Slider inputs (-100..100) are mapped to a
// linear-space additive brightness and a contrast factor pivoting around
// 18% mid gray.

uniform sampler2D u_image0;
uniform float u_float0; // Brightness slider -100..100
uniform float u_float1; // Contrast slider -100..100

in vec2 v_texCoord;
out vec4 fragColor;

const float MID_GRAY = 0.18; // 18% reflectance

// sRGB gamma 2.2 approximation
vec3 srgbToLinear(vec3 c) {
    return pow(max(c, 0.0), vec3(2.2));
}

vec3 linearToSrgb(vec3 c) {
    return pow(max(c, 0.0), vec3(1.0/2.2));
}

// -100..100 -> -1..1
float mapBrightness(float b) {
    return clamp(b / 100.0, -1.0, 1.0);
}

// -100..100 -> 0..2 (1.0 = unchanged)
float mapContrast(float c) {
    return clamp(c / 100.0 + 1.0, 0.0, 2.0);
}

void main() {
    vec4 orig = texture(u_image0, v_texCoord);

    float brightness = mapBrightness(u_float0);
    float contrast = mapContrast(u_float1);

    // Work in (approximately) linear light.
    vec3 lin = srgbToLinear(orig.rgb);

    // Contrast pivots around mid gray; brightness is an additive offset.
    lin = (lin - MID_GRAY) * contrast + brightness + MID_GRAY;

    // Convert back to sRGB
    vec3 result = linearToSrgb(clamp(lin, 0.0, 1.0));

    fragColor = vec4(result, orig.a); // alpha passes through untouched
}
|
||||||
72
blueprints/.glsl/Chromatic_Aberration_16.frag
Normal file
72
blueprints/.glsl/Chromatic_Aberration_16.frag
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
#version 300 es
precision highp float;

// Chromatic aberration: samples the red and blue channels with opposite
// offsets while keeping green fixed. Five modes define the offset field.

uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform int u_int0;     // Mode
uniform float u_float0; // Amount (0 to 100)

in vec2 v_texCoord;
out vec4 fragColor;

const int MODE_LINEAR = 0;
const int MODE_RADIAL = 1;
const int MODE_BARREL = 2;
const int MODE_SWIRL = 3;
const int MODE_DIAGONAL = 4;

const float AMOUNT_SCALE = 0.0005;
const float RADIAL_MULT = 4.0;
const float BARREL_MULT = 8.0;
const float INV_SQRT2 = 0.70710678118;

void main() {
    vec2 uv = v_texCoord;
    vec4 original = texture(u_image0, uv);

    float amount = u_float0 * AMOUNT_SCALE;

    // Early out: effectively-zero amount leaves the image untouched.
    if (amount < 0.000001) {
        fragColor = original;
        return;
    }

    // Aspect-corrected coordinates for circular effects
    float aspect = u_resolution.x / u_resolution.y;
    vec2 centered = uv - 0.5;
    vec2 corrected = vec2(centered.x * aspect, centered.y);
    float r = length(corrected);
    // Guard against division by zero at the exact center.
    vec2 dir = r > 0.0001 ? corrected / r : vec2(0.0);
    vec2 offset = vec2(0.0);

    if (u_int0 == MODE_LINEAR) {
        // Horizontal shift (no aspect correction needed)
        offset = vec2(amount, 0.0);
    }
    else if (u_int0 == MODE_RADIAL) {
        // Outward from center, stronger at edges
        offset = dir * r * amount * RADIAL_MULT;
        offset.x /= aspect; // Convert back to UV space
    }
    else if (u_int0 == MODE_BARREL) {
        // Lens distortion simulation (r² falloff)
        offset = dir * r * r * amount * BARREL_MULT;
        offset.x /= aspect; // Convert back to UV space
    }
    else if (u_int0 == MODE_SWIRL) {
        // Perpendicular to radial (rotational aberration)
        vec2 perp = vec2(-dir.y, dir.x);
        offset = perp * r * amount * RADIAL_MULT;
        offset.x /= aspect; // Convert back to UV space
    }
    else if (u_int0 == MODE_DIAGONAL) {
        // 45° offset (no aspect correction needed)
        offset = vec2(amount, amount) * INV_SQRT2;
    }
    // Unknown modes fall through with zero offset (no visible change).

    // Red shifted one way, blue the other, green stays put.
    float red = texture(u_image0, uv + offset).r;
    float green = original.g;
    float blue = texture(u_image0, uv - offset).b;

    fragColor = vec4(red, green, blue, original.a);
}
|
||||||
78
blueprints/.glsl/Color_Adjustment_15.frag
Normal file
78
blueprints/.glsl/Color_Adjustment_15.frag
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
#version 300 es
precision highp float;

// Color adjustment: white balance (temperature/tint), vibrance with skin
// tone protection, and saturation — applied in that order.

uniform sampler2D u_image0;
uniform float u_float0; // temperature (-100 to 100)
uniform float u_float1; // tint (-100 to 100)
uniform float u_float2; // vibrance (-100 to 100)
uniform float u_float3; // saturation (-100 to 100)

in vec2 v_texCoord;
out vec4 fragColor;

const float INPUT_SCALE = 0.01;
const float TEMP_TINT_PRIMARY = 0.3;
const float TEMP_TINT_SECONDARY = 0.15;
const float VIBRANCE_BOOST = 2.0;
const float SATURATION_BOOST = 2.0;
const float SKIN_PROTECTION = 0.5;
const float EPSILON = 0.001;
const vec3 LUMA_WEIGHTS = vec3(0.299, 0.587, 0.114);

void main() {
    vec4 tex = texture(u_image0, v_texCoord);
    vec3 color = tex.rgb;

    // Scale inputs: -100/100 → -1/1
    float temperature = u_float0 * INPUT_SCALE;
    float tint = u_float1 * INPUT_SCALE;
    float vibrance = u_float2 * INPUT_SCALE;
    float saturation = u_float3 * INPUT_SCALE;

    // Temperature (warm/cool): positive = warm, negative = cool
    color.r += temperature * TEMP_TINT_PRIMARY;
    color.b -= temperature * TEMP_TINT_PRIMARY;

    // Tint (green/magenta): positive = green, negative = magenta
    color.g += tint * TEMP_TINT_PRIMARY;
    color.r -= tint * TEMP_TINT_SECONDARY;
    color.b -= tint * TEMP_TINT_SECONDARY;

    // Single clamp after temperature/tint
    color = clamp(color, 0.0, 1.0);

    // Vibrance with skin protection
    if (vibrance != 0.0) {
        float maxC = max(color.r, max(color.g, color.b));
        float minC = min(color.r, min(color.g, color.b));
        float sat = maxC - minC; // crude saturation estimate (range of channels)
        float gray = dot(color, LUMA_WEIGHTS);

        if (vibrance < 0.0) {
            // Desaturate: -100 → gray
            color = mix(vec3(gray), color, 1.0 + vibrance);
        } else {
            // Boost less saturated colors more
            float vibranceAmt = vibrance * (1.0 - sat);

            // Branchless skin tone protection:
            // "warm" pixels (r >= g >= b) with moderate saturation are damped.
            float isWarmTone = step(color.b, color.g) * step(color.g, color.r);
            float warmth = (color.r - color.b) / max(maxC, EPSILON);
            float skinTone = isWarmTone * warmth * sat * (1.0 - sat);
            vibranceAmt *= (1.0 - skinTone * SKIN_PROTECTION);

            color = mix(vec3(gray), color, 1.0 + vibranceAmt * VIBRANCE_BOOST);
        }
    }

    // Saturation
    if (saturation != 0.0) {
        float gray = dot(color, LUMA_WEIGHTS);
        float satMix = saturation < 0.0
            ? 1.0 + saturation                      // -100 → gray
            : 1.0 + saturation * SATURATION_BOOST;  // +100 → 3x boost
        color = mix(vec3(gray), color, satMix);
    }

    fragColor = vec4(clamp(color, 0.0, 1.0), tex.a);
}
|
||||||
94
blueprints/.glsl/Edge-Preserving_Blur_128.frag
Normal file
94
blueprints/.glsl/Edge-Preserving_Blur_128.frag
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
#version 300 es
precision highp float;

// Edge-preserving (bilateral) blur: spatial Gaussian kernel weighted by a
// perceptual color-distance Gaussian, so strong edges are not smeared.

uniform sampler2D u_image0;
uniform float u_float0; // Blur radius (0–20, default ~5)
uniform float u_float1; // Edge threshold (0–100, default ~30)
uniform int u_int0;     // Step size (0/1 = every pixel, 2+ = skip pixels)

in vec2 v_texCoord;
out vec4 fragColor;

const int MAX_RADIUS = 20;
const float EPSILON = 0.0001;

// Perceptual luminance
// NOTE(review): not referenced anywhere in this shader.
float getLuminance(vec3 rgb) {
    return dot(rgb, vec3(0.299, 0.587, 0.114));
}

vec4 bilateralFilter(vec2 uv, vec2 texelSize, int radius,
                     float sigmaSpatial, float sigmaColor)
{
    vec4 center = texture(u_image0, uv);
    vec3 centerRGB = center.rgb;

    // Precompute -1/(2σ²) factors for both Gaussians.
    float invSpatial2 = -0.5 / (sigmaSpatial * sigmaSpatial);
    float invColor2 = -0.5 / (sigmaColor * sigmaColor + EPSILON);

    vec3 sumRGB = vec3(0.0);
    float sumWeight = 0.0;

    // NOTE(review): `step` shadows the GLSL built-in step() inside this
    // function — legal, but the built-in becomes unusable here.
    int step = max(u_int0, 1);
    float radius2 = float(radius * radius);

    // Constant loop bounds with runtime continue-guards — presumably to
    // satisfy GLSL ES loop restrictions on some drivers; confirm.
    for (int dy = -MAX_RADIUS; dy <= MAX_RADIUS; dy++) {
        if (dy < -radius || dy > radius) continue;
        if (abs(dy) % step != 0) continue; // honor sparse sampling step

        for (int dx = -MAX_RADIUS; dx <= MAX_RADIUS; dx++) {
            if (dx < -radius || dx > radius) continue;
            if (abs(dx) % step != 0) continue;

            vec2 offset = vec2(float(dx), float(dy));
            float dist2 = dot(offset, offset);
            if (dist2 > radius2) continue; // circular kernel

            vec3 sampleRGB = texture(u_image0, uv + offset * texelSize).rgb;

            // Spatial Gaussian
            float spatialWeight = exp(dist2 * invSpatial2);

            // Perceptual color distance (weighted RGB)
            vec3 diff = sampleRGB - centerRGB;
            float colorDist = dot(diff * diff, vec3(0.299, 0.587, 0.114));
            float colorWeight = exp(colorDist * invColor2);

            float w = spatialWeight * colorWeight;
            sumRGB += sampleRGB * w;
            sumWeight += w;
        }
    }

    vec3 resultRGB = sumRGB / max(sumWeight, EPSILON);
    return vec4(resultRGB, center.a); // preserve center alpha
}

void main() {
    vec2 texelSize = 1.0 / vec2(textureSize(u_image0, 0));

    float radiusF = clamp(u_float0, 0.0, float(MAX_RADIUS));
    int radius = int(radiusF + 0.5); // round to nearest pixel

    // Radius 0: no blur at all, pass through.
    if (radius == 0) {
        fragColor = texture(u_image0, v_texCoord);
        return;
    }

    // Edge threshold → color sigma
    // Squared curve for better low-end control
    float t = clamp(u_float1, 0.0, 100.0) / 100.0;
    t *= t;
    float sigmaColor = mix(0.01, 0.5, t);

    // Spatial sigma tied to radius
    float sigmaSpatial = max(radiusF * 0.75, 0.5);

    fragColor = bilateralFilter(
        v_texCoord,
        texelSize,
        radius,
        sigmaSpatial,
        sigmaColor
    );
}
|
||||||
124
blueprints/.glsl/Film_Grain_15.frag
Normal file
124
blueprints/.glsl/Film_Grain_15.frag
Normal file
@ -0,0 +1,124 @@
|
|||||||
|
#version 300 es
precision highp float;

// Film grain: adds hash-based ("grainy") or value-noise ("smooth") grain,
// optionally per-channel (color) and biased away from highlights.

uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform float u_float0; // grain amount [0.0 – 1.0] typical: 0.2–0.8
uniform float u_float1; // grain size [0.3 – 3.0] lower = finer grain
uniform float u_float2; // color amount [0.0 – 1.0] 0 = monochrome, 1 = RGB grain
uniform float u_float3; // luminance bias [0.0 – 1.0] 0 = uniform, 1 = shadows only
uniform int u_int0;     // noise mode [0 or 1] 0 = smooth, 1 = grainy

in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;

// High-quality integer hash (pcg-like)
uint pcg(uint v) {
    uint state = v * 747796405u + 2891336453u;
    uint word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u;
    return (word >> 22u) ^ word;
}

// 2D -> 1D hash input
uint hash2d(uvec2 p) {
    return pcg(p.x + pcg(p.y));
}

// Hash to float [0, 1]
float hashf(uvec2 p) {
    return float(hash2d(p)) / float(0xffffffffu);
}

// Hash to float with offset (for RGB channels)
float hashf(uvec2 p, uint offset) {
    return float(pcg(hash2d(p) + offset)) / float(0xffffffffu);
}

// Convert uniform [0,1] to roughly Gaussian distribution
// Using simple approximation: average of multiple samples
float toGaussian(uvec2 p) {
    float sum = hashf(p, 0u) + hashf(p, 1u) + hashf(p, 2u) + hashf(p, 3u);
    return (sum - 2.0) * 0.7; // Centered, scaled
}

// Overload with a channel offset so R/G/B can get decorrelated noise.
float toGaussian(uvec2 p, uint offset) {
    float sum = hashf(p, offset) + hashf(p, offset + 1u)
              + hashf(p, offset + 2u) + hashf(p, offset + 3u);
    return (sum - 2.0) * 0.7;
}

// Smooth noise with better interpolation
float smoothNoise(vec2 p) {
    vec2 i = floor(p);
    vec2 f = fract(p);

    // Quintic interpolation (less banding than cubic)
    f = f * f * f * (f * (f * 6.0 - 15.0) + 10.0);

    uvec2 ui = uvec2(i);
    float a = toGaussian(ui);
    float b = toGaussian(ui + uvec2(1u, 0u));
    float c = toGaussian(ui + uvec2(0u, 1u));
    float d = toGaussian(ui + uvec2(1u, 1u));

    return mix(mix(a, b, f.x), mix(c, d, f.x), f.y);
}

// Channel-offset overload of the smooth value noise above.
float smoothNoise(vec2 p, uint offset) {
    vec2 i = floor(p);
    vec2 f = fract(p);

    f = f * f * f * (f * (f * 6.0 - 15.0) + 10.0);

    uvec2 ui = uvec2(i);
    float a = toGaussian(ui, offset);
    float b = toGaussian(ui + uvec2(1u, 0u), offset);
    float c = toGaussian(ui + uvec2(0u, 1u), offset);
    float d = toGaussian(ui + uvec2(1u, 1u), offset);

    return mix(mix(a, b, f.x), mix(c, d, f.x), f.y);
}

void main() {
    vec4 color = texture(u_image0, v_texCoord);

    // Luminance (Rec.709)
    float luma = dot(color.rgb, vec3(0.2126, 0.7152, 0.0722));

    // Grain UV (resolution-independent)
    vec2 grainUV = v_texCoord * u_resolution / max(u_float1, 0.01);
    uvec2 grainPixel = uvec2(grainUV);

    float g;
    vec3 grainRGB;

    if (u_int0 == 1) {
        // Grainy mode: pure hash noise (no interpolation = no banding)
        g = toGaussian(grainPixel);
        grainRGB = vec3(
            toGaussian(grainPixel, 100u),
            toGaussian(grainPixel, 200u),
            toGaussian(grainPixel, 300u)
        );
    } else {
        // Smooth mode: interpolated with quintic curve
        g = smoothNoise(grainUV);
        grainRGB = vec3(
            smoothNoise(grainUV, 100u),
            smoothNoise(grainUV, 200u),
            smoothNoise(grainUV, 300u)
        );
    }

    // Luminance weighting (less grain in highlights)
    float lumWeight = mix(1.0, 1.0 - luma, clamp(u_float3, 0.0, 1.0));

    // Strength
    float strength = u_float0 * 0.15;

    // Color vs monochrome grain
    vec3 grainColor = mix(vec3(g), grainRGB, clamp(u_float2, 0.0, 1.0));

    color.rgb += grainColor * strength * lumWeight;
    fragColor0 = vec4(clamp(color.rgb, 0.0, 1.0), color.a);
}
|
||||||
133
blueprints/.glsl/Glow_30.frag
Normal file
133
blueprints/.glsl/Glow_30.frag
Normal file
@ -0,0 +1,133 @@
|
|||||||
|
#version 300 es
precision mediump float;

// Glow/bloom: thresholded golden-angle spiral blur of bright areas,
// composited over the source with a selectable blend mode and optional tint.

uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform int u_int0;     // Blend mode
uniform int u_int1;     // Color tint
uniform float u_float0; // Intensity
uniform float u_float1; // Radius
uniform float u_float2; // Threshold

in vec2 v_texCoord;
out vec4 fragColor;

const int BLEND_ADD = 0;
const int BLEND_SCREEN = 1;
const int BLEND_SOFT = 2;
const int BLEND_OVERLAY = 3;
const int BLEND_LIGHTEN = 4;

const float GOLDEN_ANGLE = 2.39996323;
const int MAX_SAMPLES = 48;
const vec3 LUMA = vec3(0.299, 0.587, 0.114);

// Cheap 2D hash for per-pixel jitter/dither.
float hash(vec2 p) {
    p = fract(p * vec2(123.34, 456.21));
    p += dot(p, p + 45.32);
    return fract(p.x * p.y);
}

// Unpack 0xRRGGBB into a normalized RGB color.
vec3 hexToRgb(int h) {
    return vec3(
        float((h >> 16) & 255),
        float((h >> 8) & 255),
        float(h & 255)
    ) * (1.0 / 255.0);
}

vec3 blend(vec3 base, vec3 glow, int mode) {
    if (mode == BLEND_SCREEN) {
        return 1.0 - (1.0 - base) * (1.0 - glow);
    }
    if (mode == BLEND_SOFT) {
        return mix(
            base - (1.0 - 2.0 * glow) * base * (1.0 - base),
            base + (2.0 * glow - 1.0) * (sqrt(base) - base),
            step(0.5, glow)
        );
    }
    if (mode == BLEND_OVERLAY) {
        return mix(
            2.0 * base * glow,
            1.0 - 2.0 * (1.0 - base) * (1.0 - glow),
            step(0.5, base)
        );
    }
    if (mode == BLEND_LIGHTEN) {
        return max(base, glow);
    }
    // BLEND_ADD and any unknown mode.
    return base + glow;
}

void main() {
    vec4 original = texture(u_image0, v_texCoord);

    float intensity = u_float0 * 0.05;
    float radius = u_float1 * u_float1 * 0.012;

    // Early out when the effect would be invisible.
    if (intensity < 0.001 || radius < 0.1) {
        fragColor = original;
        return;
    }

    // Soft luminance threshold band [t0, t1] around the cutoff.
    float threshold = 1.0 - u_float2 * 0.01;
    float t0 = threshold - 0.15;
    float t1 = threshold + 0.15;

    vec2 texelSize = 1.0 / u_resolution;
    float radius2 = radius * radius;

    // Fewer samples for small radii.
    float sampleScale = clamp(radius * 0.75, 0.35, 1.0);
    int samples = int(float(MAX_SAMPLES) * sampleScale);

    // Per-pixel jitter breaks up spiral banding.
    float noise = hash(gl_FragCoord.xy);
    float angleOffset = noise * GOLDEN_ANGLE;
    float radiusJitter = 0.85 + noise * 0.3;

    float ca = cos(GOLDEN_ANGLE);
    float sa = sin(GOLDEN_ANGLE);
    vec2 dir = vec2(cos(angleOffset), sin(angleOffset));

    vec3 glow = vec3(0.0);
    float totalWeight = 0.0;

    // Center tap
    float centerMask = smoothstep(t0, t1, dot(original.rgb, LUMA));
    glow += original.rgb * centerMask * 2.0;
    totalWeight += 2.0;

    for (int i = 1; i < MAX_SAMPLES; i++) {
        if (i >= samples) break;

        float fi = float(i);
        // sqrt spacing → uniform area density along the spiral.
        float dist = sqrt(fi / float(samples)) * radius * radiusJitter;

        vec2 offset = dir * dist * texelSize;
        vec3 c = texture(u_image0, v_texCoord + offset).rgb;
        float mask = smoothstep(t0, t1, dot(c, LUMA));

        // Quadratic falloff with distance.
        float w = 1.0 - (dist * dist) / (radius2 * 1.5);
        w = max(w, 0.0);
        w *= w;

        glow += c * mask * w;
        totalWeight += w;

        // Rotate direction by the golden angle for the next tap.
        dir = vec2(
            dir.x * ca - dir.y * sa,
            dir.x * sa + dir.y * ca
        );
    }

    glow *= intensity / max(totalWeight, 0.001);

    if (u_int1 > 0) {
        glow *= hexToRgb(u_int1);
    }

    vec3 result = blend(original.rgb, glow, u_int0);
    // Tiny dither to hide banding in smooth gradients.
    result += (noise - 0.5) * (1.0 / 255.0);

    fragColor = vec4(clamp(result, 0.0, 1.0), original.a);
}
|
||||||
222
blueprints/.glsl/Hue_and_Saturation_1.frag
Normal file
222
blueprints/.glsl/Hue_and_Saturation_1.frag
Normal file
@ -0,0 +1,222 @@
|
|||||||
|
#version 300 es
|
||||||
|
precision highp float;
|
||||||
|
|
||||||
|
uniform sampler2D u_image0;
|
||||||
|
uniform int u_int0; // Mode: 0=Master, 1=Reds, 2=Yellows, 3=Greens, 4=Cyans, 5=Blues, 6=Magentas, 7=Colorize
|
||||||
|
uniform int u_int1; // Color Space: 0=HSL, 1=HSB/HSV
|
||||||
|
uniform float u_float0; // Hue (-180 to 180)
|
||||||
|
uniform float u_float1; // Saturation (-100 to 100)
|
||||||
|
uniform float u_float2; // Lightness/Brightness (-100 to 100)
|
||||||
|
uniform float u_float3; // Overlap (0 to 100) - feathering between adjacent color ranges
|
||||||
|
|
||||||
|
in vec2 v_texCoord;
|
||||||
|
out vec4 fragColor;
|
||||||
|
|
||||||
|
// Color range modes
|
||||||
|
const int MODE_MASTER = 0;
|
||||||
|
const int MODE_RED = 1;
|
||||||
|
const int MODE_YELLOW = 2;
|
||||||
|
const int MODE_GREEN = 3;
|
||||||
|
const int MODE_CYAN = 4;
|
||||||
|
const int MODE_BLUE = 5;
|
||||||
|
const int MODE_MAGENTA = 6;
|
||||||
|
const int MODE_COLORIZE = 7;
|
||||||
|
|
||||||
|
// Color space modes
|
||||||
|
const int COLORSPACE_HSL = 0;
|
||||||
|
const int COLORSPACE_HSB = 1;
|
||||||
|
|
||||||
|
const float EPSILON = 0.0001;
|
||||||
|
|
||||||
|
//=============================================================================
|
||||||
|
// RGB <-> HSL Conversions
|
||||||
|
//=============================================================================
|
||||||
|
|
||||||
|
vec3 rgb2hsl(vec3 c) {
|
||||||
|
float maxC = max(max(c.r, c.g), c.b);
|
||||||
|
float minC = min(min(c.r, c.g), c.b);
|
||||||
|
float delta = maxC - minC;
|
||||||
|
|
||||||
|
float h = 0.0;
|
||||||
|
float s = 0.0;
|
||||||
|
float l = (maxC + minC) * 0.5;
|
||||||
|
|
||||||
|
if (delta > EPSILON) {
|
||||||
|
s = l < 0.5
|
||||||
|
? delta / (maxC + minC)
|
||||||
|
: delta / (2.0 - maxC - minC);
|
||||||
|
|
||||||
|
if (maxC == c.r) {
|
||||||
|
h = (c.g - c.b) / delta + (c.g < c.b ? 6.0 : 0.0);
|
||||||
|
} else if (maxC == c.g) {
|
||||||
|
h = (c.b - c.r) / delta + 2.0;
|
||||||
|
} else {
|
||||||
|
h = (c.r - c.g) / delta + 4.0;
|
||||||
|
}
|
||||||
|
h /= 6.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
return vec3(h, s, l);
|
||||||
|
}
|
||||||
|
|
||||||
|
float hue2rgb(float p, float q, float t) {
|
||||||
|
t = fract(t);
|
||||||
|
if (t < 1.0/6.0) return p + (q - p) * 6.0 * t;
|
||||||
|
if (t < 0.5) return q;
|
||||||
|
if (t < 2.0/3.0) return p + (q - p) * (2.0/3.0 - t) * 6.0;
|
||||||
|
return p;
|
||||||
|
}
|
||||||
|
|
||||||
|
vec3 hsl2rgb(vec3 hsl) {
|
||||||
|
if (hsl.y < EPSILON) return vec3(hsl.z);
|
||||||
|
|
||||||
|
float q = hsl.z < 0.5
|
||||||
|
? hsl.z * (1.0 + hsl.y)
|
||||||
|
: hsl.z + hsl.y - hsl.z * hsl.y;
|
||||||
|
float p = 2.0 * hsl.z - q;
|
||||||
|
|
||||||
|
return vec3(
|
||||||
|
hue2rgb(p, q, hsl.x + 1.0/3.0),
|
||||||
|
hue2rgb(p, q, hsl.x),
|
||||||
|
hue2rgb(p, q, hsl.x - 1.0/3.0)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
vec3 rgb2hsb(vec3 c) {
|
||||||
|
float maxC = max(max(c.r, c.g), c.b);
|
||||||
|
float minC = min(min(c.r, c.g), c.b);
|
||||||
|
float delta = maxC - minC;
|
||||||
|
|
||||||
|
float h = 0.0;
|
||||||
|
float s = (maxC > EPSILON) ? delta / maxC : 0.0;
|
||||||
|
float b = maxC;
|
||||||
|
|
||||||
|
if (delta > EPSILON) {
|
||||||
|
if (maxC == c.r) {
|
||||||
|
h = (c.g - c.b) / delta + (c.g < c.b ? 6.0 : 0.0);
|
||||||
|
} else if (maxC == c.g) {
|
||||||
|
h = (c.b - c.r) / delta + 2.0;
|
||||||
|
} else {
|
||||||
|
h = (c.r - c.g) / delta + 4.0;
|
||||||
|
}
|
||||||
|
h /= 6.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
return vec3(h, s, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
vec3 hsb2rgb(vec3 hsb) {
|
||||||
|
vec3 rgb = clamp(abs(mod(hsb.x * 6.0 + vec3(0.0, 4.0, 2.0), 6.0) - 3.0) - 1.0, 0.0, 1.0);
|
||||||
|
return hsb.z * mix(vec3(1.0), rgb, hsb.y);
|
||||||
|
}
|
||||||
|
|
||||||
|
//=============================================================================
|
||||||
|
// Color Range Weight Calculation
|
||||||
|
//=============================================================================
|
||||||
|
|
||||||
|
float hueDistance(float a, float b) {
|
||||||
|
float d = abs(a - b);
|
||||||
|
return min(d, 1.0 - d);
|
||||||
|
}
|
||||||
|
|
||||||
|
float getHueWeight(float hue, float center, float overlap) {
|
||||||
|
float baseWidth = 1.0 / 6.0;
|
||||||
|
float feather = baseWidth * overlap;
|
||||||
|
|
||||||
|
float d = hueDistance(hue, center);
|
||||||
|
|
||||||
|
float inner = baseWidth * 0.5;
|
||||||
|
float outer = inner + feather;
|
||||||
|
|
||||||
|
return 1.0 - smoothstep(inner, outer, d);
|
||||||
|
}
|
||||||
|
|
||||||
|
float getModeWeight(float hue, int mode, float overlap) {
|
||||||
|
if (mode == MODE_MASTER || mode == MODE_COLORIZE) return 1.0;
|
||||||
|
|
||||||
|
if (mode == MODE_RED) {
|
||||||
|
return max(
|
||||||
|
getHueWeight(hue, 0.0, overlap),
|
||||||
|
getHueWeight(hue, 1.0, overlap)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
float center = float(mode - 1) / 6.0;
|
||||||
|
return getHueWeight(hue, center, overlap);
|
||||||
|
}
|
||||||
|
|
||||||
|
//=============================================================================
|
||||||
|
// Adjustment Functions
|
||||||
|
//=============================================================================
|
||||||
|
|
||||||
|
float adjustLightness(float l, float amount) {
|
||||||
|
return amount > 0.0
|
||||||
|
? l + (1.0 - l) * amount
|
||||||
|
: l + l * amount;
|
||||||
|
}
|
||||||
|
|
||||||
|
float adjustBrightness(float b, float amount) {
|
||||||
|
return clamp(b + amount, 0.0, 1.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
float adjustSaturation(float s, float amount) {
|
||||||
|
return amount > 0.0
|
||||||
|
? s + (1.0 - s) * amount
|
||||||
|
: s + s * amount;
|
||||||
|
}
|
||||||
|
|
||||||
|
vec3 colorize(vec3 rgb, float hue, float sat, float light) {
|
||||||
|
float lum = dot(rgb, vec3(0.299, 0.587, 0.114));
|
||||||
|
float l = adjustLightness(lum, light);
|
||||||
|
|
||||||
|
vec3 hsl = vec3(fract(hue), clamp(sat, 0.0, 1.0), clamp(l, 0.0, 1.0));
|
||||||
|
return hsl2rgb(hsl);
|
||||||
|
}
|
||||||
|
|
||||||
|
//=============================================================================
|
||||||
|
// Main
|
||||||
|
//=============================================================================
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
vec4 original = texture(u_image0, v_texCoord);
|
||||||
|
|
||||||
|
float hueShift = u_float0 / 360.0; // -180..180 -> -0.5..0.5
|
||||||
|
float satAmount = u_float1 / 100.0; // -100..100 -> -1..1
|
||||||
|
float lightAmount= u_float2 / 100.0; // -100..100 -> -1..1
|
||||||
|
float overlap = u_float3 / 100.0; // 0..100 -> 0..1
|
||||||
|
|
||||||
|
vec3 result;
|
||||||
|
|
||||||
|
if (u_int0 == MODE_COLORIZE) {
|
||||||
|
result = colorize(original.rgb, hueShift, satAmount, lightAmount);
|
||||||
|
fragColor = vec4(result, original.a);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
vec3 hsx = (u_int1 == COLORSPACE_HSL)
|
||||||
|
? rgb2hsl(original.rgb)
|
||||||
|
: rgb2hsb(original.rgb);
|
||||||
|
|
||||||
|
float weight = getModeWeight(hsx.x, u_int0, overlap);
|
||||||
|
|
||||||
|
if (u_int0 != MODE_MASTER && hsx.y < EPSILON) {
|
||||||
|
weight = 0.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (weight > EPSILON) {
|
||||||
|
float h = fract(hsx.x + hueShift * weight);
|
||||||
|
float s = clamp(adjustSaturation(hsx.y, satAmount * weight), 0.0, 1.0);
|
||||||
|
float v = (u_int1 == COLORSPACE_HSL)
|
||||||
|
? clamp(adjustLightness(hsx.z, lightAmount * weight), 0.0, 1.0)
|
||||||
|
: clamp(adjustBrightness(hsx.z, lightAmount * weight), 0.0, 1.0);
|
||||||
|
|
||||||
|
vec3 adjusted = vec3(h, s, v);
|
||||||
|
result = (u_int1 == COLORSPACE_HSL)
|
||||||
|
? hsl2rgb(adjusted)
|
||||||
|
: hsb2rgb(adjusted);
|
||||||
|
} else {
|
||||||
|
result = original.rgb;
|
||||||
|
}
|
||||||
|
|
||||||
|
fragColor = vec4(result, original.a);
|
||||||
|
}
|
||||||
111
blueprints/.glsl/Image_Blur_1.frag
Normal file
111
blueprints/.glsl/Image_Blur_1.frag
Normal file
@ -0,0 +1,111 @@
|
|||||||
|
#version 300 es
|
||||||
|
#pragma passes 2
|
||||||
|
precision highp float;
|
||||||
|
|
||||||
|
// Blur type constants
|
||||||
|
const int BLUR_GAUSSIAN = 0;
|
||||||
|
const int BLUR_BOX = 1;
|
||||||
|
const int BLUR_RADIAL = 2;
|
||||||
|
|
||||||
|
// Radial blur config
|
||||||
|
const int RADIAL_SAMPLES = 12;
|
||||||
|
const float RADIAL_STRENGTH = 0.0003;
|
||||||
|
|
||||||
|
uniform sampler2D u_image0;
|
||||||
|
uniform vec2 u_resolution;
|
||||||
|
uniform int u_int0; // Blur type (BLUR_GAUSSIAN, BLUR_BOX, BLUR_RADIAL)
|
||||||
|
uniform float u_float0; // Blur radius/amount
|
||||||
|
uniform int u_pass; // Pass index (0 = horizontal, 1 = vertical)
|
||||||
|
|
||||||
|
in vec2 v_texCoord;
|
||||||
|
layout(location = 0) out vec4 fragColor0;
|
||||||
|
|
||||||
|
float gaussian(float x, float sigma) {
|
||||||
|
return exp(-(x * x) / (2.0 * sigma * sigma));
|
||||||
|
}
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
vec2 texelSize = 1.0 / u_resolution;
|
||||||
|
float radius = max(u_float0, 0.0);
|
||||||
|
|
||||||
|
// Radial (angular) blur - single pass, doesn't use separable
|
||||||
|
if (u_int0 == BLUR_RADIAL) {
|
||||||
|
// Only execute on first pass
|
||||||
|
if (u_pass > 0) {
|
||||||
|
fragColor0 = texture(u_image0, v_texCoord);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
vec2 center = vec2(0.5);
|
||||||
|
vec2 dir = v_texCoord - center;
|
||||||
|
float dist = length(dir);
|
||||||
|
|
||||||
|
if (dist < 1e-4) {
|
||||||
|
fragColor0 = texture(u_image0, v_texCoord);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
vec4 sum = vec4(0.0);
|
||||||
|
float totalWeight = 0.0;
|
||||||
|
float angleStep = radius * RADIAL_STRENGTH;
|
||||||
|
|
||||||
|
dir /= dist;
|
||||||
|
|
||||||
|
float cosStep = cos(angleStep);
|
||||||
|
float sinStep = sin(angleStep);
|
||||||
|
|
||||||
|
float negAngle = -float(RADIAL_SAMPLES) * angleStep;
|
||||||
|
vec2 rotDir = vec2(
|
||||||
|
dir.x * cos(negAngle) - dir.y * sin(negAngle),
|
||||||
|
dir.x * sin(negAngle) + dir.y * cos(negAngle)
|
||||||
|
);
|
||||||
|
|
||||||
|
for (int i = -RADIAL_SAMPLES; i <= RADIAL_SAMPLES; i++) {
|
||||||
|
vec2 uv = center + rotDir * dist;
|
||||||
|
float w = 1.0 - abs(float(i)) / float(RADIAL_SAMPLES);
|
||||||
|
sum += texture(u_image0, uv) * w;
|
||||||
|
totalWeight += w;
|
||||||
|
|
||||||
|
rotDir = vec2(
|
||||||
|
rotDir.x * cosStep - rotDir.y * sinStep,
|
||||||
|
rotDir.x * sinStep + rotDir.y * cosStep
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
fragColor0 = sum / max(totalWeight, 0.001);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Separable Gaussian / Box blur
|
||||||
|
int samples = int(ceil(radius));
|
||||||
|
|
||||||
|
if (samples == 0) {
|
||||||
|
fragColor0 = texture(u_image0, v_texCoord);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Direction: pass 0 = horizontal, pass 1 = vertical
|
||||||
|
vec2 dir = (u_pass == 0) ? vec2(1.0, 0.0) : vec2(0.0, 1.0);
|
||||||
|
|
||||||
|
vec4 color = vec4(0.0);
|
||||||
|
float totalWeight = 0.0;
|
||||||
|
float sigma = radius / 2.0;
|
||||||
|
|
||||||
|
for (int i = -samples; i <= samples; i++) {
|
||||||
|
vec2 offset = dir * float(i) * texelSize;
|
||||||
|
vec4 sample_color = texture(u_image0, v_texCoord + offset);
|
||||||
|
|
||||||
|
float weight;
|
||||||
|
if (u_int0 == BLUR_GAUSSIAN) {
|
||||||
|
weight = gaussian(float(i), sigma);
|
||||||
|
} else {
|
||||||
|
// BLUR_BOX
|
||||||
|
weight = 1.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
color += sample_color * weight;
|
||||||
|
totalWeight += weight;
|
||||||
|
}
|
||||||
|
|
||||||
|
fragColor0 = color / totalWeight;
|
||||||
|
}
|
||||||
19
blueprints/.glsl/Image_Channels_23.frag
Normal file
19
blueprints/.glsl/Image_Channels_23.frag
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
#version 300 es
|
||||||
|
precision highp float;
|
||||||
|
|
||||||
|
uniform sampler2D u_image0;
|
||||||
|
|
||||||
|
in vec2 v_texCoord;
|
||||||
|
layout(location = 0) out vec4 fragColor0;
|
||||||
|
layout(location = 1) out vec4 fragColor1;
|
||||||
|
layout(location = 2) out vec4 fragColor2;
|
||||||
|
layout(location = 3) out vec4 fragColor3;
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
vec4 color = texture(u_image0, v_texCoord);
|
||||||
|
// Output each channel as grayscale to separate render targets
|
||||||
|
fragColor0 = vec4(vec3(color.r), 1.0); // Red channel
|
||||||
|
fragColor1 = vec4(vec3(color.g), 1.0); // Green channel
|
||||||
|
fragColor2 = vec4(vec3(color.b), 1.0); // Blue channel
|
||||||
|
fragColor3 = vec4(vec3(color.a), 1.0); // Alpha channel
|
||||||
|
}
|
||||||
71
blueprints/.glsl/Image_Levels_1.frag
Normal file
71
blueprints/.glsl/Image_Levels_1.frag
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
#version 300 es
|
||||||
|
precision highp float;
|
||||||
|
|
||||||
|
// Levels Adjustment
|
||||||
|
// u_int0: channel (0=RGB, 1=R, 2=G, 3=B) default: 0
|
||||||
|
// u_float0: input black (0-255) default: 0
|
||||||
|
// u_float1: input white (0-255) default: 255
|
||||||
|
// u_float2: gamma (0.01-9.99) default: 1.0
|
||||||
|
// u_float3: output black (0-255) default: 0
|
||||||
|
// u_float4: output white (0-255) default: 255
|
||||||
|
|
||||||
|
uniform sampler2D u_image0;
|
||||||
|
uniform int u_int0;
|
||||||
|
uniform float u_float0;
|
||||||
|
uniform float u_float1;
|
||||||
|
uniform float u_float2;
|
||||||
|
uniform float u_float3;
|
||||||
|
uniform float u_float4;
|
||||||
|
|
||||||
|
in vec2 v_texCoord;
|
||||||
|
out vec4 fragColor;
|
||||||
|
|
||||||
|
vec3 applyLevels(vec3 color, float inBlack, float inWhite, float gamma, float outBlack, float outWhite) {
|
||||||
|
float inRange = max(inWhite - inBlack, 0.0001);
|
||||||
|
vec3 result = clamp((color - inBlack) / inRange, 0.0, 1.0);
|
||||||
|
result = pow(result, vec3(1.0 / gamma));
|
||||||
|
result = mix(vec3(outBlack), vec3(outWhite), result);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
float applySingleChannel(float value, float inBlack, float inWhite, float gamma, float outBlack, float outWhite) {
|
||||||
|
float inRange = max(inWhite - inBlack, 0.0001);
|
||||||
|
float result = clamp((value - inBlack) / inRange, 0.0, 1.0);
|
||||||
|
result = pow(result, 1.0 / gamma);
|
||||||
|
result = mix(outBlack, outWhite, result);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
vec4 texColor = texture(u_image0, v_texCoord);
|
||||||
|
vec3 color = texColor.rgb;
|
||||||
|
|
||||||
|
float inBlack = u_float0 / 255.0;
|
||||||
|
float inWhite = u_float1 / 255.0;
|
||||||
|
float gamma = u_float2;
|
||||||
|
float outBlack = u_float3 / 255.0;
|
||||||
|
float outWhite = u_float4 / 255.0;
|
||||||
|
|
||||||
|
vec3 result;
|
||||||
|
|
||||||
|
if (u_int0 == 0) {
|
||||||
|
result = applyLevels(color, inBlack, inWhite, gamma, outBlack, outWhite);
|
||||||
|
}
|
||||||
|
else if (u_int0 == 1) {
|
||||||
|
result = color;
|
||||||
|
result.r = applySingleChannel(color.r, inBlack, inWhite, gamma, outBlack, outWhite);
|
||||||
|
}
|
||||||
|
else if (u_int0 == 2) {
|
||||||
|
result = color;
|
||||||
|
result.g = applySingleChannel(color.g, inBlack, inWhite, gamma, outBlack, outWhite);
|
||||||
|
}
|
||||||
|
else if (u_int0 == 3) {
|
||||||
|
result = color;
|
||||||
|
result.b = applySingleChannel(color.b, inBlack, inWhite, gamma, outBlack, outWhite);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
result = color;
|
||||||
|
}
|
||||||
|
|
||||||
|
fragColor = vec4(result, texColor.a);
|
||||||
|
}
|
||||||
28
blueprints/.glsl/README.md
Normal file
28
blueprints/.glsl/README.md
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
# GLSL Shader Sources
|
||||||
|
|
||||||
|
This folder contains the GLSL fragment shaders extracted from blueprint JSON files for easier editing and version control.
|
||||||
|
|
||||||
|
## File Naming Convention
|
||||||
|
|
||||||
|
`{Blueprint_Name}_{node_id}.frag`
|
||||||
|
|
||||||
|
- **Blueprint_Name**: The JSON filename with spaces/special chars replaced by underscores
|
||||||
|
- **node_id**: The GLSLShader node ID within the subgraph
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Extract shaders from blueprint JSONs to this folder
|
||||||
|
python update_blueprints.py extract
|
||||||
|
|
||||||
|
# Patch edited shaders back into blueprint JSONs
|
||||||
|
python update_blueprints.py patch
|
||||||
|
```
|
||||||
|
|
||||||
|
## Workflow
|
||||||
|
|
||||||
|
1. Run `extract` to pull current shaders from JSONs
|
||||||
|
2. Edit `.frag` files
|
||||||
|
3. Run `patch` to update the blueprint JSONs
|
||||||
|
4. Test
|
||||||
|
5. Commit both `.frag` files and updated JSONs
|
||||||
28
blueprints/.glsl/Sharpen_23.frag
Normal file
28
blueprints/.glsl/Sharpen_23.frag
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
#version 300 es
|
||||||
|
precision highp float;
|
||||||
|
|
||||||
|
uniform sampler2D u_image0;
|
||||||
|
uniform vec2 u_resolution;
|
||||||
|
uniform float u_float0; // strength [0.0 – 2.0] typical: 0.3–1.0
|
||||||
|
|
||||||
|
in vec2 v_texCoord;
|
||||||
|
layout(location = 0) out vec4 fragColor0;
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
vec2 texel = 1.0 / u_resolution;
|
||||||
|
|
||||||
|
// Sample center and neighbors
|
||||||
|
vec4 center = texture(u_image0, v_texCoord);
|
||||||
|
vec4 top = texture(u_image0, v_texCoord + vec2( 0.0, -texel.y));
|
||||||
|
vec4 bottom = texture(u_image0, v_texCoord + vec2( 0.0, texel.y));
|
||||||
|
vec4 left = texture(u_image0, v_texCoord + vec2(-texel.x, 0.0));
|
||||||
|
vec4 right = texture(u_image0, v_texCoord + vec2( texel.x, 0.0));
|
||||||
|
|
||||||
|
// Edge enhancement (Laplacian)
|
||||||
|
vec4 edges = center * 4.0 - top - bottom - left - right;
|
||||||
|
|
||||||
|
// Add edges back scaled by strength
|
||||||
|
vec4 sharpened = center + edges * u_float0;
|
||||||
|
|
||||||
|
fragColor0 = vec4(clamp(sharpened.rgb, 0.0, 1.0), center.a);
|
||||||
|
}
|
||||||
61
blueprints/.glsl/Unsharp_Mask_26.frag
Normal file
61
blueprints/.glsl/Unsharp_Mask_26.frag
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
#version 300 es
|
||||||
|
precision highp float;
|
||||||
|
|
||||||
|
uniform sampler2D u_image0;
|
||||||
|
uniform vec2 u_resolution;
|
||||||
|
uniform float u_float0; // amount [0.0 - 3.0] typical: 0.5-1.5
|
||||||
|
uniform float u_float1; // radius [0.5 - 10.0] blur radius in pixels
|
||||||
|
uniform float u_float2; // threshold [0.0 - 0.1] min difference to sharpen
|
||||||
|
|
||||||
|
in vec2 v_texCoord;
|
||||||
|
layout(location = 0) out vec4 fragColor0;
|
||||||
|
|
||||||
|
float gaussian(float x, float sigma) {
|
||||||
|
return exp(-(x * x) / (2.0 * sigma * sigma));
|
||||||
|
}
|
||||||
|
|
||||||
|
float getLuminance(vec3 color) {
|
||||||
|
return dot(color, vec3(0.2126, 0.7152, 0.0722));
|
||||||
|
}
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
vec2 texel = 1.0 / u_resolution;
|
||||||
|
float radius = max(u_float1, 0.5);
|
||||||
|
float amount = u_float0;
|
||||||
|
float threshold = u_float2;
|
||||||
|
|
||||||
|
vec4 original = texture(u_image0, v_texCoord);
|
||||||
|
|
||||||
|
// Gaussian blur for the "unsharp" mask
|
||||||
|
int samples = int(ceil(radius));
|
||||||
|
float sigma = radius / 2.0;
|
||||||
|
|
||||||
|
vec4 blurred = vec4(0.0);
|
||||||
|
float totalWeight = 0.0;
|
||||||
|
|
||||||
|
for (int x = -samples; x <= samples; x++) {
|
||||||
|
for (int y = -samples; y <= samples; y++) {
|
||||||
|
vec2 offset = vec2(float(x), float(y)) * texel;
|
||||||
|
vec4 sample_color = texture(u_image0, v_texCoord + offset);
|
||||||
|
|
||||||
|
float dist = length(vec2(float(x), float(y)));
|
||||||
|
float weight = gaussian(dist, sigma);
|
||||||
|
blurred += sample_color * weight;
|
||||||
|
totalWeight += weight;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
blurred /= totalWeight;
|
||||||
|
|
||||||
|
// Unsharp mask = original - blurred
|
||||||
|
vec3 mask = original.rgb - blurred.rgb;
|
||||||
|
|
||||||
|
// Luminance-based threshold with smooth falloff
|
||||||
|
float lumaDelta = abs(getLuminance(original.rgb) - getLuminance(blurred.rgb));
|
||||||
|
float thresholdScale = smoothstep(0.0, threshold, lumaDelta);
|
||||||
|
mask *= thresholdScale;
|
||||||
|
|
||||||
|
// Sharpen: original + mask * amount
|
||||||
|
vec3 sharpened = original.rgb + mask * amount;
|
||||||
|
|
||||||
|
fragColor0 = vec4(clamp(sharpened, 0.0, 1.0), original.a);
|
||||||
|
}
|
||||||
159
blueprints/.glsl/update_blueprints.py
Normal file
159
blueprints/.glsl/update_blueprints.py
Normal file
@ -0,0 +1,159 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Shader Blueprint Updater
|
||||||
|
|
||||||
|
Syncs GLSL shader files between this folder and blueprint JSON files.
|
||||||
|
|
||||||
|
File naming convention:
|
||||||
|
{Blueprint Name}_{node_id}.frag
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python update_blueprints.py extract # Extract shaders from JSONs to here
|
||||||
|
python update_blueprints.py patch # Patch shaders back into JSONs
|
||||||
|
python update_blueprints.py # Same as patch (default)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO, format='%(message)s')
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
GLSL_DIR = Path(__file__).parent
|
||||||
|
BLUEPRINTS_DIR = GLSL_DIR.parent
|
||||||
|
|
||||||
|
|
||||||
|
def get_blueprint_files():
|
||||||
|
"""Get all blueprint JSON files."""
|
||||||
|
return sorted(BLUEPRINTS_DIR.glob("*.json"))
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_filename(name):
|
||||||
|
"""Convert blueprint name to safe filename."""
|
||||||
|
return re.sub(r'[^\w\-]', '_', name)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_shaders():
|
||||||
|
"""Extract all shaders from blueprint JSONs to this folder."""
|
||||||
|
extracted = 0
|
||||||
|
for json_path in get_blueprint_files():
|
||||||
|
blueprint_name = json_path.stem
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(json_path, 'r') as f:
|
||||||
|
data = json.load(f)
|
||||||
|
except (json.JSONDecodeError, IOError) as e:
|
||||||
|
logger.warning("Skipping %s: %s", json_path.name, e)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Find GLSLShader nodes in subgraphs
|
||||||
|
for subgraph in data.get('definitions', {}).get('subgraphs', []):
|
||||||
|
for node in subgraph.get('nodes', []):
|
||||||
|
if node.get('type') == 'GLSLShader':
|
||||||
|
node_id = node.get('id')
|
||||||
|
widgets = node.get('widgets_values', [])
|
||||||
|
|
||||||
|
# Find shader code (first string that looks like GLSL)
|
||||||
|
for widget in widgets:
|
||||||
|
if isinstance(widget, str) and widget.startswith('#version'):
|
||||||
|
safe_name = sanitize_filename(blueprint_name)
|
||||||
|
frag_name = f"{safe_name}_{node_id}.frag"
|
||||||
|
frag_path = GLSL_DIR / frag_name
|
||||||
|
|
||||||
|
with open(frag_path, 'w') as f:
|
||||||
|
f.write(widget)
|
||||||
|
|
||||||
|
logger.info(" Extracted: %s", frag_name)
|
||||||
|
extracted += 1
|
||||||
|
break
|
||||||
|
|
||||||
|
logger.info("\nExtracted %d shader(s)", extracted)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_shaders():
|
||||||
|
"""Patch shaders from this folder back into blueprint JSONs."""
|
||||||
|
# Build lookup: blueprint_name -> [(node_id, shader_code), ...]
|
||||||
|
shader_updates = {}
|
||||||
|
|
||||||
|
for frag_path in sorted(GLSL_DIR.glob("*.frag")):
|
||||||
|
# Parse filename: {blueprint_name}_{node_id}.frag
|
||||||
|
parts = frag_path.stem.rsplit('_', 1)
|
||||||
|
if len(parts) != 2:
|
||||||
|
logger.warning("Skipping %s: invalid filename format", frag_path.name)
|
||||||
|
continue
|
||||||
|
|
||||||
|
blueprint_name, node_id_str = parts
|
||||||
|
|
||||||
|
try:
|
||||||
|
node_id = int(node_id_str)
|
||||||
|
except ValueError:
|
||||||
|
logger.warning("Skipping %s: invalid node_id", frag_path.name)
|
||||||
|
continue
|
||||||
|
|
||||||
|
with open(frag_path, 'r') as f:
|
||||||
|
shader_code = f.read()
|
||||||
|
|
||||||
|
if blueprint_name not in shader_updates:
|
||||||
|
shader_updates[blueprint_name] = []
|
||||||
|
shader_updates[blueprint_name].append((node_id, shader_code))
|
||||||
|
|
||||||
|
# Apply updates to JSON files
|
||||||
|
patched = 0
|
||||||
|
for json_path in get_blueprint_files():
|
||||||
|
blueprint_name = sanitize_filename(json_path.stem)
|
||||||
|
|
||||||
|
if blueprint_name not in shader_updates:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(json_path, 'r') as f:
|
||||||
|
data = json.load(f)
|
||||||
|
except (json.JSONDecodeError, IOError) as e:
|
||||||
|
logger.error("Error reading %s: %s", json_path.name, e)
|
||||||
|
continue
|
||||||
|
|
||||||
|
modified = False
|
||||||
|
for node_id, shader_code in shader_updates[blueprint_name]:
|
||||||
|
# Find the node and update
|
||||||
|
for subgraph in data.get('definitions', {}).get('subgraphs', []):
|
||||||
|
for node in subgraph.get('nodes', []):
|
||||||
|
if node.get('id') == node_id and node.get('type') == 'GLSLShader':
|
||||||
|
widgets = node.get('widgets_values', [])
|
||||||
|
if len(widgets) > 0 and widgets[0] != shader_code:
|
||||||
|
widgets[0] = shader_code
|
||||||
|
modified = True
|
||||||
|
logger.info(" Patched: %s (node %d)", json_path.name, node_id)
|
||||||
|
patched += 1
|
||||||
|
|
||||||
|
if modified:
|
||||||
|
with open(json_path, 'w') as f:
|
||||||
|
json.dump(data, f)
|
||||||
|
|
||||||
|
if patched == 0:
|
||||||
|
logger.info("No changes to apply.")
|
||||||
|
else:
|
||||||
|
logger.info("\nPatched %d shader(s)", patched)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
if len(sys.argv) < 2:
|
||||||
|
command = "patch"
|
||||||
|
else:
|
||||||
|
command = sys.argv[1].lower()
|
||||||
|
|
||||||
|
if command == "extract":
|
||||||
|
logger.info("Extracting shaders from blueprints...")
|
||||||
|
extract_shaders()
|
||||||
|
elif command in ("patch", "update", "apply"):
|
||||||
|
logger.info("Patching shaders into blueprints...")
|
||||||
|
patch_shaders()
|
||||||
|
else:
|
||||||
|
logger.info(__doc__)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
1
blueprints/Brightness and Contrast.json
Normal file
1
blueprints/Brightness and Contrast.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Canny to Image (Z-Image-Turbo).json
Normal file
1
blueprints/Canny to Image (Z-Image-Turbo).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Canny to Video (LTX 2.0).json
Normal file
1
blueprints/Canny to Video (LTX 2.0).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Chromatic Aberration.json
Normal file
1
blueprints/Chromatic Aberration.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Color Adjustment.json
Normal file
1
blueprints/Color Adjustment.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Depth to Image (Z-Image-Turbo).json
Normal file
1
blueprints/Depth to Image (Z-Image-Turbo).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Depth to Video (ltx 2.0).json
Normal file
1
blueprints/Depth to Video (ltx 2.0).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Edge-Preserving Blur.json
Normal file
1
blueprints/Edge-Preserving Blur.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Film Grain.json
Normal file
1
blueprints/Film Grain.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Glow.json
Normal file
1
blueprints/Glow.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Hue and Saturation.json
Normal file
1
blueprints/Hue and Saturation.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image Blur.json
Normal file
1
blueprints/Image Blur.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image Captioning (gemini).json
Normal file
1
blueprints/Image Captioning (gemini).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image Channels.json
Normal file
1
blueprints/Image Channels.json
Normal file
@ -0,0 +1 @@
|
|||||||
|
{"revision": 0, "last_node_id": 29, "last_link_id": 0, "nodes": [{"id": 29, "type": "4c9d6ea4-b912-40e5-8766-6793a9758c53", "pos": [1970, -230], "size": [180, 86], "flags": {}, "order": 5, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": null}], "outputs": [{"label": "R", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": []}, {"label": "G", "localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": []}, {"label": "B", "localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": []}, {"label": "A", "localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": []}], "title": "Image Channels", "properties": {"proxyWidgets": []}, "widgets_values": []}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "4c9d6ea4-b912-40e5-8766-6793a9758c53", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 28, "lastLinkId": 39, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Image Channels", "inputNode": {"id": -10, "bounding": [1820, -185, 120, 60]}, "outputNode": {"id": -20, "bounding": [2460, -215, 120, 120]}, "inputs": [{"id": "3522932b-2d86-4a1f-a02a-cb29f3a9d7fe", "name": "images.image0", "type": "IMAGE", "linkIds": [39], "localized_name": "images.image0", "label": "image", "pos": [1920, -165]}], "outputs": [{"id": "605cb9c3-b065-4d9b-81d2-3ec331889b2b", "name": "IMAGE0", "type": "IMAGE", "linkIds": [26], "localized_name": "IMAGE0", "label": "R", "pos": [2480, -195]}, {"id": "fb44a77e-0522-43e9-9527-82e7465b3596", "name": "IMAGE1", "type": "IMAGE", "linkIds": [27], "localized_name": "IMAGE1", "label": "G", "pos": [2480, -175]}, {"id": "81460ee6-0131-402a-874f-6bf3001fc4ff", "name": "IMAGE2", "type": "IMAGE", "linkIds": [28], "localized_name": "IMAGE2", "label": "B", "pos": [2480, -155]}, {"id": "ae690246-80d4-4951-b1d9-9306d8a77417", "name": "IMAGE3", "type": "IMAGE", "linkIds": [29], "localized_name": "IMAGE3", 
"label": "A", "pos": [2480, -135]}], "widgets": [], "nodes": [{"id": 23, "type": "GLSLShader", "pos": [2000, -330], "size": [400, 172], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": 39}, {"localized_name": "fragment_shader", "name": "fragment_shader", "type": "STRING", "widget": {"name": "fragment_shader"}, "link": null}, {"localized_name": "size_mode", "name": "size_mode", "type": "COMFY_DYNAMICCOMBO_V3", "widget": {"name": "size_mode"}, "link": null}, {"label": "image1", "localized_name": "images.image1", "name": "images.image1", "shape": 7, "type": "IMAGE", "link": null}], "outputs": [{"label": "R", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": [26]}, {"label": "G", "localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": [27]}, {"label": "B", "localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": [28]}, {"label": "A", "localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": [29]}], "properties": {"Node name for S&R": "GLSLShader"}, "widgets_values": ["#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\nlayout(location = 1) out vec4 fragColor1;\nlayout(location = 2) out vec4 fragColor2;\nlayout(location = 3) out vec4 fragColor3;\n\nvoid main() {\n vec4 color = texture(u_image0, v_texCoord);\n // Output each channel as grayscale to separate render targets\n fragColor0 = vec4(vec3(color.r), 1.0); // Red channel\n fragColor1 = vec4(vec3(color.g), 1.0); // Green channel\n fragColor2 = vec4(vec3(color.b), 1.0); // Blue channel\n fragColor3 = vec4(vec3(color.a), 1.0); // Alpha channel\n}\n", "from_input"]}], "groups": [], "links": [{"id": 39, "origin_id": -10, "origin_slot": 0, "target_id": 23, "target_slot": 0, "type": "IMAGE"}, {"id": 26, "origin_id": 23, "origin_slot": 0, "target_id": -20, "target_slot": 
0, "type": "IMAGE"}, {"id": 27, "origin_id": 23, "origin_slot": 1, "target_id": -20, "target_slot": 1, "type": "IMAGE"}, {"id": 28, "origin_id": 23, "origin_slot": 2, "target_id": -20, "target_slot": 2, "type": "IMAGE"}, {"id": 29, "origin_id": 23, "origin_slot": 3, "target_id": -20, "target_slot": 3, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Image Tools/Color adjust"}]}}
|
||||||
1
blueprints/Image Edit (Flux.2 Klein 4B).json
Normal file
1
blueprints/Image Edit (Flux.2 Klein 4B).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image Edit (Qwen 2511).json
Normal file
1
blueprints/Image Edit (Qwen 2511).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image Inpainting (Qwen-image).json
Normal file
1
blueprints/Image Inpainting (Qwen-image).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image Levels.json
Normal file
1
blueprints/Image Levels.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image Outpainting (Qwen-Image).json
Normal file
1
blueprints/Image Outpainting (Qwen-Image).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image Upscale(Z-image-Turbo).json
Normal file
1
blueprints/Image Upscale(Z-image-Turbo).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image to Depth Map (Lotus).json
Normal file
1
blueprints/Image to Depth Map (Lotus).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image to Layers(Qwen-Image Layered).json
Normal file
1
blueprints/Image to Layers(Qwen-Image Layered).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image to Model (Hunyuan3d 2.1).json
Normal file
1
blueprints/Image to Model (Hunyuan3d 2.1).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Image to Video (Wan 2.2).json
Normal file
1
blueprints/Image to Video (Wan 2.2).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Pose to Image (Z-Image-Turbo).json
Normal file
1
blueprints/Pose to Image (Z-Image-Turbo).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Pose to Video (LTX 2.0).json
Normal file
1
blueprints/Pose to Video (LTX 2.0).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Prompt Enhance.json
Normal file
1
blueprints/Prompt Enhance.json
Normal file
@ -0,0 +1 @@
|
|||||||
|
{"revision": 0, "last_node_id": 15, "last_link_id": 0, "nodes": [{"id": 15, "type": "24d8bbfd-39d4-4774-bff0-3de40cc7a471", "pos": [-1490, 2040], "size": [400, 260], "flags": {}, "order": 0, "mode": 0, "inputs": [{"name": "prompt", "type": "STRING", "widget": {"name": "prompt"}, "link": null}, {"label": "reference images", "name": "images", "type": "IMAGE", "link": null}], "outputs": [{"name": "STRING", "type": "STRING", "links": null}], "title": "Prompt Enhance", "properties": {"proxyWidgets": [["-1", "prompt"]], "cnr_id": "comfy-core", "ver": "0.14.1"}, "widgets_values": [""]}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "24d8bbfd-39d4-4774-bff0-3de40cc7a471", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 15, "lastLinkId": 14, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Prompt Enhance", "inputNode": {"id": -10, "bounding": [-2170, 2110, 138.876953125, 80]}, "outputNode": {"id": -20, "bounding": [-640, 2110, 120, 60]}, "inputs": [{"id": "aeab7216-00e0-4528-a09b-bba50845c5a6", "name": "prompt", "type": "STRING", "linkIds": [11], "pos": [-2051.123046875, 2130]}, {"id": "7b73fd36-aa31-4771-9066-f6c83879994b", "name": "images", "type": "IMAGE", "linkIds": [14], "label": "reference images", "pos": [-2051.123046875, 2150]}], "outputs": [{"id": "c7b0d930-68a1-48d1-b496-0519e5837064", "name": "STRING", "type": "STRING", "linkIds": [13], "pos": [-620, 2130]}], "widgets": [], "nodes": [{"id": 11, "type": "GeminiNode", "pos": [-1560, 1990], "size": [470, 470], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "images", "name": "images", "shape": 7, "type": "IMAGE", "link": 14}, {"localized_name": "audio", "name": "audio", "shape": 7, "type": "AUDIO", "link": null}, {"localized_name": "video", "name": "video", "shape": 7, "type": "VIDEO", "link": null}, {"localized_name": "files", "name": "files", "shape": 7, "type": "GEMINI_INPUT_FILES", "link": null}, {"localized_name": "prompt", "name": "prompt", "type": 
"STRING", "widget": {"name": "prompt"}, "link": 11}, {"localized_name": "model", "name": "model", "type": "COMBO", "widget": {"name": "model"}, "link": null}, {"localized_name": "seed", "name": "seed", "type": "INT", "widget": {"name": "seed"}, "link": null}, {"localized_name": "system_prompt", "name": "system_prompt", "shape": 7, "type": "STRING", "widget": {"name": "system_prompt"}, "link": null}], "outputs": [{"localized_name": "STRING", "name": "STRING", "type": "STRING", "links": [13]}], "properties": {"cnr_id": "comfy-core", "ver": "0.14.1", "Node name for S&R": "GeminiNode"}, "widgets_values": ["", "gemini-3-pro-preview", 42, "randomize", "You are an expert in prompt writing.\nBased on the input, rewrite the user's input into a detailed prompt.\nincluding camera settings, lighting, composition, and style.\nReturn the prompt only"], "color": "#432", "bgcolor": "#653"}], "groups": [], "links": [{"id": 11, "origin_id": -10, "origin_slot": 0, "target_id": 11, "target_slot": 4, "type": "STRING"}, {"id": 13, "origin_id": 11, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "STRING"}, {"id": 14, "origin_id": -10, "origin_slot": 1, "target_id": 11, "target_slot": 0, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Text generation/Prompt enhance"}]}, "extra": {}}
|
||||||
1
blueprints/Sharpen.json
Normal file
1
blueprints/Sharpen.json
Normal file
@ -0,0 +1 @@
|
|||||||
|
{"revision": 0, "last_node_id": 25, "last_link_id": 0, "nodes": [{"id": 25, "type": "621ba4e2-22a8-482d-a369-023753198b7b", "pos": [4610, -790], "size": [230, 58], "flags": {}, "order": 4, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": null}], "outputs": [{"label": "IMAGE", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": []}], "title": "Sharpen", "properties": {"proxyWidgets": [["24", "value"]]}, "widgets_values": []}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "621ba4e2-22a8-482d-a369-023753198b7b", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 24, "lastLinkId": 36, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Sharpen", "inputNode": {"id": -10, "bounding": [4090, -825, 120, 60]}, "outputNode": {"id": -20, "bounding": [5150, -825, 120, 60]}, "inputs": [{"id": "37011fb7-14b7-4e0e-b1a0-6a02e8da1fd7", "name": "images.image0", "type": "IMAGE", "linkIds": [34], "localized_name": "images.image0", "label": "image", "pos": [4190, -805]}], "outputs": [{"id": "e9182b3f-635c-4cd4-a152-4b4be17ae4b9", "name": "IMAGE0", "type": "IMAGE", "linkIds": [35], "localized_name": "IMAGE0", "label": "IMAGE", "pos": [5170, -805]}], "widgets": [], "nodes": [{"id": 24, "type": "PrimitiveFloat", "pos": [4280, -1240], "size": [270, 58], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "strength", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [36]}], "properties": {"Node name for S&R": "PrimitiveFloat", "min": 0, "max": 3, "precision": 2, "step": 0.05}, "widgets_values": [0.5]}, {"id": 23, "type": "GLSLShader", "pos": [4570, -1240], "size": [370, 192], "flags": {}, "order": 1, "mode": 0, "inputs": [{"label": "image0", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", 
"link": 34}, {"label": "image1", "localized_name": "images.image1", "name": "images.image1", "shape": 7, "type": "IMAGE", "link": null}, {"label": "u_float0", "localized_name": "floats.u_float0", "name": "floats.u_float0", "shape": 7, "type": "FLOAT", "link": 36}, {"label": "u_float1", "localized_name": "floats.u_float1", "name": "floats.u_float1", "shape": 7, "type": "FLOAT", "link": null}, {"label": "u_int0", "localized_name": "ints.u_int0", "name": "ints.u_int0", "shape": 7, "type": "INT", "link": null}, {"localized_name": "fragment_shader", "name": "fragment_shader", "type": "STRING", "widget": {"name": "fragment_shader"}, "link": null}, {"localized_name": "size_mode", "name": "size_mode", "type": "COMFY_DYNAMICCOMBO_V3", "widget": {"name": "size_mode"}, "link": null}], "outputs": [{"localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": [35]}, {"localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": null}], "properties": {"Node name for S&R": "GLSLShader"}, "widgets_values": ["#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform vec2 u_resolution;\nuniform float u_float0; // strength [0.0 – 2.0] typical: 0.3–1.0\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nvoid main() {\n vec2 texel = 1.0 / u_resolution;\n \n // Sample center and neighbors\n vec4 center = texture(u_image0, v_texCoord);\n vec4 top = texture(u_image0, v_texCoord + vec2( 0.0, -texel.y));\n vec4 bottom = texture(u_image0, v_texCoord + vec2( 0.0, texel.y));\n vec4 left = texture(u_image0, v_texCoord + vec2(-texel.x, 0.0));\n vec4 right = texture(u_image0, v_texCoord + vec2( texel.x, 0.0));\n \n // Edge enhancement (Laplacian)\n vec4 edges = center * 4.0 - top - bottom - left - right;\n \n // Add edges back scaled by strength\n vec4 sharpened = center + edges * 
u_float0;\n \n fragColor0 = vec4(clamp(sharpened.rgb, 0.0, 1.0), center.a);\n}", "from_input"]}], "groups": [], "links": [{"id": 36, "origin_id": 24, "origin_slot": 0, "target_id": 23, "target_slot": 2, "type": "FLOAT"}, {"id": 34, "origin_id": -10, "origin_slot": 0, "target_id": 23, "target_slot": 0, "type": "IMAGE"}, {"id": 35, "origin_id": 23, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Image Tools/Sharpen"}]}}
|
||||||
1
blueprints/Text to Audio (ACE-Step 1.5).json
Normal file
1
blueprints/Text to Audio (ACE-Step 1.5).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Text to Image (Z-Image-Turbo).json
Normal file
1
blueprints/Text to Image (Z-Image-Turbo).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Text to Video (Wan 2.2).json
Normal file
1
blueprints/Text to Video (Wan 2.2).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Unsharp Mask.json
Normal file
1
blueprints/Unsharp Mask.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Video Captioning (Gemini).json
Normal file
1
blueprints/Video Captioning (Gemini).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Video Inpaint(Wan2.1 VACE).json
Normal file
1
blueprints/Video Inpaint(Wan2.1 VACE).json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Video Stitch.json
Normal file
1
blueprints/Video Stitch.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Video Upscale(GAN x4).json
Normal file
1
blueprints/Video Upscale(GAN x4).json
Normal file
@ -0,0 +1 @@
|
|||||||
|
{"revision": 0, "last_node_id": 13, "last_link_id": 0, "nodes": [{"id": 13, "type": "cf95b747-3e17-46cb-8097-cac60ff9b2e1", "pos": [1120, 330], "size": [240, 58], "flags": {}, "order": 3, "mode": 0, "inputs": [{"localized_name": "video", "name": "video", "type": "VIDEO", "link": null}, {"name": "model_name", "type": "COMBO", "widget": {"name": "model_name"}, "link": null}], "outputs": [{"localized_name": "VIDEO", "name": "VIDEO", "type": "VIDEO", "links": []}], "title": "Video Upscale(GAN x4)", "properties": {"proxyWidgets": [["-1", "model_name"]], "cnr_id": "comfy-core", "ver": "0.14.1"}, "widgets_values": ["RealESRGAN_x4plus.safetensors"]}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "cf95b747-3e17-46cb-8097-cac60ff9b2e1", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 13, "lastLinkId": 19, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Video Upscale(GAN x4)", "inputNode": {"id": -10, "bounding": [550, 460, 120, 80]}, "outputNode": {"id": -20, "bounding": [1490, 460, 120, 60]}, "inputs": [{"id": "666d633e-93e7-42dc-8d11-2b7b99b0f2a6", "name": "video", "type": "VIDEO", "linkIds": [10], "localized_name": "video", "pos": [650, 480]}, {"id": "2e23a087-caa8-4d65-99e6-662761aa905a", "name": "model_name", "type": "COMBO", "linkIds": [19], "pos": [650, 500]}], "outputs": [{"id": "0c1768ea-3ec2-412f-9af6-8e0fa36dae70", "name": "VIDEO", "type": "VIDEO", "linkIds": [15], "localized_name": "VIDEO", "pos": [1510, 480]}], "widgets": [], "nodes": [{"id": 2, "type": "ImageUpscaleWithModel", "pos": [1110, 450], "size": [320, 46], "flags": {}, "order": 1, "mode": 0, "inputs": [{"localized_name": "upscale_model", "name": "upscale_model", "type": "UPSCALE_MODEL", "link": 1}, {"localized_name": "image", "name": "image", "type": "IMAGE", "link": 14}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": [13]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": 
"ImageUpscaleWithModel"}}, {"id": 11, "type": "CreateVideo", "pos": [1110, 550], "size": [320, 78], "flags": {}, "order": 3, "mode": 0, "inputs": [{"localized_name": "images", "name": "images", "type": "IMAGE", "link": 13}, {"localized_name": "audio", "name": "audio", "shape": 7, "type": "AUDIO", "link": 16}, {"localized_name": "fps", "name": "fps", "type": "FLOAT", "widget": {"name": "fps"}, "link": 12}], "outputs": [{"localized_name": "VIDEO", "name": "VIDEO", "type": "VIDEO", "links": [15]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "CreateVideo"}, "widgets_values": [30]}, {"id": 10, "type": "GetVideoComponents", "pos": [1110, 330], "size": [320, 70], "flags": {}, "order": 2, "mode": 0, "inputs": [{"localized_name": "video", "name": "video", "type": "VIDEO", "link": 10}], "outputs": [{"localized_name": "images", "name": "images", "type": "IMAGE", "links": [14]}, {"localized_name": "audio", "name": "audio", "type": "AUDIO", "links": [16]}, {"localized_name": "fps", "name": "fps", "type": "FLOAT", "links": [12]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "GetVideoComponents"}}, {"id": 1, "type": "UpscaleModelLoader", "pos": [750, 450], "size": [280, 60], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "model_name", "name": "model_name", "type": "COMBO", "widget": {"name": "model_name"}, "link": 19}], "outputs": [{"localized_name": "UPSCALE_MODEL", "name": "UPSCALE_MODEL", "type": "UPSCALE_MODEL", "links": [1]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "UpscaleModelLoader", "models": [{"name": "RealESRGAN_x4plus.safetensors", "url": "https://huggingface.co/Comfy-Org/Real-ESRGAN_repackaged/resolve/main/RealESRGAN_x4plus.safetensors", "directory": "upscale_models"}]}, "widgets_values": ["RealESRGAN_x4plus.safetensors"]}], "groups": [], "links": [{"id": 1, "origin_id": 1, "origin_slot": 0, "target_id": 2, "target_slot": 0, "type": 
"UPSCALE_MODEL"}, {"id": 14, "origin_id": 10, "origin_slot": 0, "target_id": 2, "target_slot": 1, "type": "IMAGE"}, {"id": 13, "origin_id": 2, "origin_slot": 0, "target_id": 11, "target_slot": 0, "type": "IMAGE"}, {"id": 16, "origin_id": 10, "origin_slot": 1, "target_id": 11, "target_slot": 1, "type": "AUDIO"}, {"id": 12, "origin_id": 10, "origin_slot": 2, "target_id": 11, "target_slot": 2, "type": "FLOAT"}, {"id": 10, "origin_id": -10, "origin_slot": 0, "target_id": 10, "target_slot": 0, "type": "VIDEO"}, {"id": 15, "origin_id": 11, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "VIDEO"}, {"id": 19, "origin_id": -10, "origin_slot": 1, "target_id": 1, "target_slot": 0, "type": "COMBO"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Video generation and editing/Enhance video"}]}, "extra": {}}
|
||||||
@ -1,13 +0,0 @@
|
|||||||
import pickle
|
|
||||||
|
|
||||||
load = pickle.load
|
|
||||||
|
|
||||||
class Empty:
|
|
||||||
pass
|
|
||||||
|
|
||||||
class Unpickler(pickle.Unpickler):
|
|
||||||
def find_class(self, module, name):
|
|
||||||
#TODO: safe unpickle
|
|
||||||
if module.startswith("pytorch_lightning"):
|
|
||||||
return Empty
|
|
||||||
return super().find_class(module, name)
|
|
||||||
@ -176,6 +176,8 @@ class InputTypeOptions(TypedDict):
|
|||||||
"""COMBO type only. Specifies the configuration for a multi-select widget.
|
"""COMBO type only. Specifies the configuration for a multi-select widget.
|
||||||
Available after ComfyUI frontend v1.13.4
|
Available after ComfyUI frontend v1.13.4
|
||||||
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987"""
|
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987"""
|
||||||
|
gradient_stops: NotRequired[list[list[float]]]
|
||||||
|
"""Gradient color stops for gradientslider display mode. Each stop is [offset, r, g, b] (``FLOAT``)."""
|
||||||
|
|
||||||
|
|
||||||
class HiddenInputTypeDict(TypedDict):
|
class HiddenInputTypeDict(TypedDict):
|
||||||
|
|||||||
@ -297,6 +297,30 @@ class ControlNet(ControlBase):
|
|||||||
self.model_sampling_current = None
|
self.model_sampling_current = None
|
||||||
super().cleanup()
|
super().cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
class QwenFunControlNet(ControlNet):
|
||||||
|
def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
|
||||||
|
# Fun checkpoints are more sensitive to high strengths in the generic
|
||||||
|
# ControlNet merge path. Use a soft response curve so strength=1.0 stays
|
||||||
|
# unchanged while >1 grows more gently.
|
||||||
|
original_strength = self.strength
|
||||||
|
self.strength = math.sqrt(max(self.strength, 0.0))
|
||||||
|
try:
|
||||||
|
return super().get_control(x_noisy, t, cond, batched_number, transformer_options)
|
||||||
|
finally:
|
||||||
|
self.strength = original_strength
|
||||||
|
|
||||||
|
def pre_run(self, model, percent_to_timestep_function):
|
||||||
|
super().pre_run(model, percent_to_timestep_function)
|
||||||
|
self.set_extra_arg("base_model", model.diffusion_model)
|
||||||
|
|
||||||
|
def copy(self):
|
||||||
|
c = QwenFunControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
|
||||||
|
c.control_model = self.control_model
|
||||||
|
c.control_model_wrapped = self.control_model_wrapped
|
||||||
|
self.copy_to(c)
|
||||||
|
return c
|
||||||
|
|
||||||
class ControlLoraOps:
|
class ControlLoraOps:
|
||||||
class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
|
class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
|
||||||
def __init__(self, in_features: int, out_features: int, bias: bool = True,
|
def __init__(self, in_features: int, out_features: int, bias: bool = True,
|
||||||
@ -560,6 +584,7 @@ def load_controlnet_hunyuandit(controlnet_data, model_options={}):
|
|||||||
def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}):
|
def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}):
|
||||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
|
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
|
||||||
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||||
|
sd = model_config.process_unet_state_dict(sd)
|
||||||
control_model = controlnet_load_state_dict(control_model, sd)
|
control_model = controlnet_load_state_dict(control_model, sd)
|
||||||
extra_conds = ['y', 'guidance']
|
extra_conds = ['y', 'guidance']
|
||||||
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||||
@ -605,6 +630,53 @@ def load_controlnet_qwen_instantx(sd, model_options={}):
|
|||||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||||
return control
|
return control
|
||||||
|
|
||||||
|
|
||||||
|
def load_controlnet_qwen_fun(sd, model_options={}):
|
||||||
|
load_device = comfy.model_management.get_torch_device()
|
||||||
|
weight_dtype = comfy.utils.weight_dtype(sd)
|
||||||
|
unet_dtype = model_options.get("dtype", weight_dtype)
|
||||||
|
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
||||||
|
|
||||||
|
operations = model_options.get("custom_operations", None)
|
||||||
|
if operations is None:
|
||||||
|
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
|
||||||
|
|
||||||
|
in_features = sd["control_img_in.weight"].shape[1]
|
||||||
|
inner_dim = sd["control_img_in.weight"].shape[0]
|
||||||
|
|
||||||
|
block_weight = sd["control_blocks.0.attn.to_q.weight"]
|
||||||
|
attention_head_dim = sd["control_blocks.0.attn.norm_q.weight"].shape[0]
|
||||||
|
num_attention_heads = max(1, block_weight.shape[0] // max(1, attention_head_dim))
|
||||||
|
|
||||||
|
model = comfy.ldm.qwen_image.controlnet.QwenImageFunControlNetModel(
|
||||||
|
control_in_features=in_features,
|
||||||
|
inner_dim=inner_dim,
|
||||||
|
num_attention_heads=num_attention_heads,
|
||||||
|
attention_head_dim=attention_head_dim,
|
||||||
|
num_control_blocks=5,
|
||||||
|
main_model_double=60,
|
||||||
|
injection_layers=(0, 12, 24, 36, 48),
|
||||||
|
operations=operations,
|
||||||
|
device=comfy.model_management.unet_offload_device(),
|
||||||
|
dtype=unet_dtype,
|
||||||
|
)
|
||||||
|
model = controlnet_load_state_dict(model, sd)
|
||||||
|
|
||||||
|
latent_format = comfy.latent_formats.Wan21()
|
||||||
|
control = QwenFunControlNet(
|
||||||
|
model,
|
||||||
|
compression_ratio=1,
|
||||||
|
latent_format=latent_format,
|
||||||
|
# Fun checkpoints already expect their own 33-channel context handling.
|
||||||
|
# Enabling generic concat_mask injects an extra mask channel at apply-time
|
||||||
|
# and breaks the intended fallback packing path.
|
||||||
|
concat_mask=False,
|
||||||
|
load_device=load_device,
|
||||||
|
manual_cast_dtype=manual_cast_dtype,
|
||||||
|
extra_conds=[],
|
||||||
|
)
|
||||||
|
return control
|
||||||
|
|
||||||
def convert_mistoline(sd):
|
def convert_mistoline(sd):
|
||||||
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
|
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
|
||||||
|
|
||||||
@ -682,6 +754,8 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
|
|||||||
return load_controlnet_qwen_instantx(controlnet_data, model_options=model_options)
|
return load_controlnet_qwen_instantx(controlnet_data, model_options=model_options)
|
||||||
elif "controlnet_x_embedder.weight" in controlnet_data:
|
elif "controlnet_x_embedder.weight" in controlnet_data:
|
||||||
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
|
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
|
||||||
|
elif "control_blocks.0.after_proj.weight" in controlnet_data and "control_img_in.weight" in controlnet_data:
|
||||||
|
return load_controlnet_qwen_fun(controlnet_data, model_options=model_options)
|
||||||
|
|
||||||
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
|
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
|
||||||
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)
|
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)
|
||||||
|
|||||||
@ -179,8 +179,8 @@ class LLMAdapter(nn.Module):
|
|||||||
if source_attention_mask.ndim == 2:
|
if source_attention_mask.ndim == 2:
|
||||||
source_attention_mask = source_attention_mask.unsqueeze(1).unsqueeze(1)
|
source_attention_mask = source_attention_mask.unsqueeze(1).unsqueeze(1)
|
||||||
|
|
||||||
x = self.in_proj(self.embed(target_input_ids))
|
|
||||||
context = source_hidden_states
|
context = source_hidden_states
|
||||||
|
x = self.in_proj(self.embed(target_input_ids, out_dtype=context.dtype))
|
||||||
position_ids = torch.arange(x.shape[1], device=x.device).unsqueeze(0)
|
position_ids = torch.arange(x.shape[1], device=x.device).unsqueeze(0)
|
||||||
position_ids_context = torch.arange(context.shape[1], device=x.device).unsqueeze(0)
|
position_ids_context = torch.arange(context.shape[1], device=x.device).unsqueeze(0)
|
||||||
position_embeddings = self.rotary_emb(x, position_ids)
|
position_embeddings = self.rotary_emb(x, position_ids)
|
||||||
|
|||||||
@ -3,7 +3,6 @@ from torch import Tensor, nn
|
|||||||
|
|
||||||
from comfy.ldm.flux.layers import (
|
from comfy.ldm.flux.layers import (
|
||||||
MLPEmbedder,
|
MLPEmbedder,
|
||||||
RMSNorm,
|
|
||||||
ModulationOut,
|
ModulationOut,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -29,7 +28,7 @@ class Approximator(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.in_proj = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
|
self.in_proj = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
|
||||||
self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
|
self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
|
||||||
self.norms = nn.ModuleList([RMSNorm(hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
|
self.norms = nn.ModuleList([operations.RMSNorm(hidden_dim, dtype=dtype, device=device) for x in range( n_layers)])
|
||||||
self.out_proj = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)
|
self.out_proj = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@ -152,6 +152,7 @@ class Chroma(nn.Module):
|
|||||||
transformer_options={},
|
transformer_options={},
|
||||||
attn_mask: Tensor = None,
|
attn_mask: Tensor = None,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
|
transformer_options = transformer_options.copy()
|
||||||
patches_replace = transformer_options.get("patches_replace", {})
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
|
|
||||||
# running on sequences img
|
# running on sequences img
|
||||||
@ -228,6 +229,7 @@ class Chroma(nn.Module):
|
|||||||
|
|
||||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||||
transformer_options["block_type"] = "single"
|
transformer_options["block_type"] = "single"
|
||||||
|
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
|
||||||
for i, block in enumerate(self.single_blocks):
|
for i, block in enumerate(self.single_blocks):
|
||||||
transformer_options["block_index"] = i
|
transformer_options["block_index"] = i
|
||||||
if i not in self.skip_dit:
|
if i not in self.skip_dit:
|
||||||
|
|||||||
@ -4,8 +4,6 @@ from functools import lru_cache
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from comfy.ldm.flux.layers import RMSNorm
|
|
||||||
|
|
||||||
|
|
||||||
class NerfEmbedder(nn.Module):
|
class NerfEmbedder(nn.Module):
|
||||||
"""
|
"""
|
||||||
@ -145,7 +143,7 @@ class NerfGLUBlock(nn.Module):
|
|||||||
# We now need to generate parameters for 3 matrices.
|
# We now need to generate parameters for 3 matrices.
|
||||||
total_params = 3 * hidden_size_x**2 * mlp_ratio
|
total_params = 3 * hidden_size_x**2 * mlp_ratio
|
||||||
self.param_generator = operations.Linear(hidden_size_s, total_params, dtype=dtype, device=device)
|
self.param_generator = operations.Linear(hidden_size_s, total_params, dtype=dtype, device=device)
|
||||||
self.norm = RMSNorm(hidden_size_x, dtype=dtype, device=device, operations=operations)
|
self.norm = operations.RMSNorm(hidden_size_x, dtype=dtype, device=device)
|
||||||
self.mlp_ratio = mlp_ratio
|
self.mlp_ratio = mlp_ratio
|
||||||
|
|
||||||
|
|
||||||
@ -178,7 +176,7 @@ class NerfGLUBlock(nn.Module):
|
|||||||
class NerfFinalLayer(nn.Module):
|
class NerfFinalLayer(nn.Module):
|
||||||
def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None):
|
def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
|
self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device)
|
||||||
self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device)
|
self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
@ -190,7 +188,7 @@ class NerfFinalLayer(nn.Module):
|
|||||||
class NerfFinalLayerConv(nn.Module):
|
class NerfFinalLayerConv(nn.Module):
|
||||||
def __init__(self, hidden_size: int, out_channels: int, dtype=None, device=None, operations=None):
|
def __init__(self, hidden_size: int, out_channels: int, dtype=None, device=None, operations=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
|
self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device)
|
||||||
self.conv = operations.Conv2d(
|
self.conv = operations.Conv2d(
|
||||||
in_channels=hidden_size,
|
in_channels=hidden_size,
|
||||||
out_channels=out_channels,
|
out_channels=out_channels,
|
||||||
|
|||||||
@ -5,9 +5,9 @@ import torch
|
|||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
|
||||||
from .math import attention, rope
|
from .math import attention, rope
|
||||||
import comfy.ops
|
|
||||||
import comfy.ldm.common_dit
|
|
||||||
|
|
||||||
|
# Fix import for some custom nodes, TODO: delete eventually.
|
||||||
|
RMSNorm = None
|
||||||
|
|
||||||
class EmbedND(nn.Module):
|
class EmbedND(nn.Module):
|
||||||
def __init__(self, dim: int, theta: int, axes_dim: list):
|
def __init__(self, dim: int, theta: int, axes_dim: list):
|
||||||
@ -87,20 +87,12 @@ def build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=False, yak_mlp=False, dt
|
|||||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||||
)
|
)
|
||||||
|
|
||||||
class RMSNorm(torch.nn.Module):
|
|
||||||
def __init__(self, dim: int, dtype=None, device=None, operations=None):
|
|
||||||
super().__init__()
|
|
||||||
self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
|
|
||||||
|
|
||||||
def forward(self, x: Tensor):
|
|
||||||
return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6)
|
|
||||||
|
|
||||||
|
|
||||||
class QKNorm(torch.nn.Module):
|
class QKNorm(torch.nn.Module):
|
||||||
def __init__(self, dim: int, dtype=None, device=None, operations=None):
|
def __init__(self, dim: int, dtype=None, device=None, operations=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
|
self.query_norm = operations.RMSNorm(dim, dtype=dtype, device=device)
|
||||||
self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
|
self.key_norm = operations.RMSNorm(dim, dtype=dtype, device=device)
|
||||||
|
|
||||||
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
|
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
|
||||||
q = self.query_norm(q)
|
q = self.query_norm(q)
|
||||||
@ -169,7 +161,7 @@ class SiLUActivation(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class DoubleStreamBlock(nn.Module):
|
class DoubleStreamBlock(nn.Module):
|
||||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
|
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||||
@ -197,8 +189,6 @@ class DoubleStreamBlock(nn.Module):
|
|||||||
|
|
||||||
self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
|
self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
|
||||||
|
|
||||||
self.flipped_img_txt = flipped_img_txt
|
|
||||||
|
|
||||||
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
|
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
|
||||||
if self.modulation:
|
if self.modulation:
|
||||||
img_mod1, img_mod2 = self.img_mod(vec)
|
img_mod1, img_mod2 = self.img_mod(vec)
|
||||||
@ -206,6 +196,9 @@ class DoubleStreamBlock(nn.Module):
|
|||||||
else:
|
else:
|
||||||
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
|
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
|
||||||
|
|
||||||
|
transformer_patches = transformer_options.get("patches", {})
|
||||||
|
extra_options = transformer_options.copy()
|
||||||
|
|
||||||
# prepare image for attention
|
# prepare image for attention
|
||||||
img_modulated = self.img_norm1(img)
|
img_modulated = self.img_norm1(img)
|
||||||
img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
|
img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
|
||||||
@ -224,32 +217,23 @@ class DoubleStreamBlock(nn.Module):
|
|||||||
del txt_qkv
|
del txt_qkv
|
||||||
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
||||||
|
|
||||||
if self.flipped_img_txt:
|
q = torch.cat((txt_q, img_q), dim=2)
|
||||||
q = torch.cat((img_q, txt_q), dim=2)
|
del txt_q, img_q
|
||||||
del img_q, txt_q
|
k = torch.cat((txt_k, img_k), dim=2)
|
||||||
k = torch.cat((img_k, txt_k), dim=2)
|
del txt_k, img_k
|
||||||
del img_k, txt_k
|
v = torch.cat((txt_v, img_v), dim=2)
|
||||||
v = torch.cat((img_v, txt_v), dim=2)
|
del txt_v, img_v
|
||||||
del img_v, txt_v
|
# run actual attention
|
||||||
# run actual attention
|
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||||
attn = attention(q, k, v,
|
del q, k, v
|
||||||
pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
|
||||||
del q, k, v
|
|
||||||
|
|
||||||
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
|
if "attn1_output_patch" in transformer_patches:
|
||||||
else:
|
extra_options["img_slice"] = [txt.shape[1], attn.shape[1]]
|
||||||
q = torch.cat((txt_q, img_q), dim=2)
|
patch = transformer_patches["attn1_output_patch"]
|
||||||
del txt_q, img_q
|
for p in patch:
|
||||||
k = torch.cat((txt_k, img_k), dim=2)
|
attn = p(attn, extra_options)
|
||||||
del txt_k, img_k
|
|
||||||
v = torch.cat((txt_v, img_v), dim=2)
|
|
||||||
del txt_v, img_v
|
|
||||||
# run actual attention
|
|
||||||
attn = attention(q, k, v,
|
|
||||||
pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
|
||||||
del q, k, v
|
|
||||||
|
|
||||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
|
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
|
||||||
|
|
||||||
# calculate the img bloks
|
# calculate the img bloks
|
||||||
img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
|
img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
|
||||||
@ -328,6 +312,9 @@ class SingleStreamBlock(nn.Module):
|
|||||||
else:
|
else:
|
||||||
mod = vec
|
mod = vec
|
||||||
|
|
||||||
|
transformer_patches = transformer_options.get("patches", {})
|
||||||
|
extra_options = transformer_options.copy()
|
||||||
|
|
||||||
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim_first], dim=-1)
|
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim_first], dim=-1)
|
||||||
|
|
||||||
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||||
@ -337,6 +324,12 @@ class SingleStreamBlock(nn.Module):
|
|||||||
# compute attention
|
# compute attention
|
||||||
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||||
del q, k, v
|
del q, k, v
|
||||||
|
|
||||||
|
if "attn1_output_patch" in transformer_patches:
|
||||||
|
patch = transformer_patches["attn1_output_patch"]
|
||||||
|
for p in patch:
|
||||||
|
attn = p(attn, extra_options)
|
||||||
|
|
||||||
# compute activation in mlp stream, cat again and run second linear layer
|
# compute activation in mlp stream, cat again and run second linear layer
|
||||||
if self.yak_mlp:
|
if self.yak_mlp:
|
||||||
mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2]
|
mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2]
|
||||||
|
|||||||
@ -16,7 +16,6 @@ from .layers import (
|
|||||||
SingleStreamBlock,
|
SingleStreamBlock,
|
||||||
timestep_embedding,
|
timestep_embedding,
|
||||||
Modulation,
|
Modulation,
|
||||||
RMSNorm
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -81,7 +80,7 @@ class Flux(nn.Module):
|
|||||||
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
|
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
|
||||||
|
|
||||||
if params.txt_norm:
|
if params.txt_norm:
|
||||||
self.txt_norm = RMSNorm(params.context_in_dim, dtype=dtype, device=device, operations=operations)
|
self.txt_norm = operations.RMSNorm(params.context_in_dim, dtype=dtype, device=device)
|
||||||
else:
|
else:
|
||||||
self.txt_norm = None
|
self.txt_norm = None
|
||||||
|
|
||||||
@ -143,6 +142,7 @@ class Flux(nn.Module):
|
|||||||
attn_mask: Tensor = None,
|
attn_mask: Tensor = None,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
|
|
||||||
|
transformer_options = transformer_options.copy()
|
||||||
patches = transformer_options.get("patches", {})
|
patches = transformer_options.get("patches", {})
|
||||||
patches_replace = transformer_options.get("patches_replace", {})
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
if img.ndim != 3 or txt.ndim != 3:
|
if img.ndim != 3 or txt.ndim != 3:
|
||||||
@ -232,6 +232,7 @@ class Flux(nn.Module):
|
|||||||
|
|
||||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||||
transformer_options["block_type"] = "single"
|
transformer_options["block_type"] = "single"
|
||||||
|
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
|
||||||
for i, block in enumerate(self.single_blocks):
|
for i, block in enumerate(self.single_blocks):
|
||||||
transformer_options["block_index"] = i
|
transformer_options["block_index"] = i
|
||||||
if ("single_block", i) in blocks_replace:
|
if ("single_block", i) in blocks_replace:
|
||||||
|
|||||||
@ -241,7 +241,6 @@ class HunyuanVideo(nn.Module):
|
|||||||
self.num_heads,
|
self.num_heads,
|
||||||
mlp_ratio=params.mlp_ratio,
|
mlp_ratio=params.mlp_ratio,
|
||||||
qkv_bias=params.qkv_bias,
|
qkv_bias=params.qkv_bias,
|
||||||
flipped_img_txt=True,
|
|
||||||
dtype=dtype, device=device, operations=operations
|
dtype=dtype, device=device, operations=operations
|
||||||
)
|
)
|
||||||
for _ in range(params.depth)
|
for _ in range(params.depth)
|
||||||
@ -305,6 +304,7 @@ class HunyuanVideo(nn.Module):
|
|||||||
control=None,
|
control=None,
|
||||||
transformer_options={},
|
transformer_options={},
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
|
transformer_options = transformer_options.copy()
|
||||||
patches_replace = transformer_options.get("patches_replace", {})
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
|
|
||||||
initial_shape = list(img.shape)
|
initial_shape = list(img.shape)
|
||||||
@ -378,14 +378,14 @@ class HunyuanVideo(nn.Module):
|
|||||||
extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
|
extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
|
||||||
txt_ids = torch.cat((txt_ids, extra_txt_ids), dim=1)
|
txt_ids = torch.cat((txt_ids, extra_txt_ids), dim=1)
|
||||||
|
|
||||||
ids = torch.cat((img_ids, txt_ids), dim=1)
|
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||||
pe = self.pe_embedder(ids)
|
pe = self.pe_embedder(ids)
|
||||||
|
|
||||||
img_len = img.shape[1]
|
img_len = img.shape[1]
|
||||||
if txt_mask is not None:
|
if txt_mask is not None:
|
||||||
attn_mask_len = img_len + txt.shape[1]
|
attn_mask_len = img_len + txt.shape[1]
|
||||||
attn_mask = torch.zeros((1, 1, attn_mask_len), dtype=img.dtype, device=img.device)
|
attn_mask = torch.zeros((1, 1, attn_mask_len), dtype=img.dtype, device=img.device)
|
||||||
attn_mask[:, 0, img_len:] = txt_mask
|
attn_mask[:, 0, :txt.shape[1]] = txt_mask
|
||||||
else:
|
else:
|
||||||
attn_mask = None
|
attn_mask = None
|
||||||
|
|
||||||
@ -413,10 +413,11 @@ class HunyuanVideo(nn.Module):
|
|||||||
if add is not None:
|
if add is not None:
|
||||||
img += add
|
img += add
|
||||||
|
|
||||||
img = torch.cat((img, txt), 1)
|
img = torch.cat((txt, img), 1)
|
||||||
|
|
||||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||||
transformer_options["block_type"] = "single"
|
transformer_options["block_type"] = "single"
|
||||||
|
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
|
||||||
for i, block in enumerate(self.single_blocks):
|
for i, block in enumerate(self.single_blocks):
|
||||||
transformer_options["block_index"] = i
|
transformer_options["block_index"] = i
|
||||||
if ("single_block", i) in blocks_replace:
|
if ("single_block", i) in blocks_replace:
|
||||||
@ -435,9 +436,9 @@ class HunyuanVideo(nn.Module):
|
|||||||
if i < len(control_o):
|
if i < len(control_o):
|
||||||
add = control_o[i]
|
add = control_o[i]
|
||||||
if add is not None:
|
if add is not None:
|
||||||
img[:, : img_len] += add
|
img[:, txt.shape[1]: img_len + txt.shape[1]] += add
|
||||||
|
|
||||||
img = img[:, : img_len]
|
img = img[:, txt.shape[1]: img_len + txt.shape[1]]
|
||||||
if ref_latent is not None:
|
if ref_latent is not None:
|
||||||
img = img[:, ref_latent.shape[1]:]
|
img = img[:, ref_latent.shape[1]:]
|
||||||
|
|
||||||
|
|||||||
@ -9,6 +9,7 @@ from comfy.ldm.lightricks.model import (
|
|||||||
LTXVModel,
|
LTXVModel,
|
||||||
)
|
)
|
||||||
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
|
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
|
||||||
|
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
|
||||||
import comfy.ldm.common_dit
|
import comfy.ldm.common_dit
|
||||||
|
|
||||||
class CompressedTimestep:
|
class CompressedTimestep:
|
||||||
@ -450,6 +451,29 @@ class LTXAVModel(LTXVModel):
|
|||||||
operations=self.operations,
|
operations=self.operations,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.audio_embeddings_connector = Embeddings1DConnector(
|
||||||
|
split_rope=True,
|
||||||
|
double_precision_rope=True,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=self.operations,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.video_embeddings_connector = Embeddings1DConnector(
|
||||||
|
split_rope=True,
|
||||||
|
double_precision_rope=True,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=self.operations,
|
||||||
|
)
|
||||||
|
|
||||||
|
def preprocess_text_embeds(self, context):
|
||||||
|
if context.shape[-1] == self.caption_channels * 2:
|
||||||
|
return context
|
||||||
|
out_vid = self.video_embeddings_connector(context)[0]
|
||||||
|
out_audio = self.audio_embeddings_connector(context)[0]
|
||||||
|
return torch.concat((out_vid, out_audio), dim=-1)
|
||||||
|
|
||||||
def _init_transformer_blocks(self, device, dtype, **kwargs):
|
def _init_transformer_blocks(self, device, dtype, **kwargs):
|
||||||
"""Initialize transformer blocks for LTXAV."""
|
"""Initialize transformer blocks for LTXAV."""
|
||||||
self.transformer_blocks = nn.ModuleList(
|
self.transformer_blocks = nn.ModuleList(
|
||||||
|
|||||||
@ -157,11 +157,9 @@ class Embeddings1DConnector(nn.Module):
|
|||||||
self.num_learnable_registers = num_learnable_registers
|
self.num_learnable_registers = num_learnable_registers
|
||||||
if self.num_learnable_registers:
|
if self.num_learnable_registers:
|
||||||
self.learnable_registers = nn.Parameter(
|
self.learnable_registers = nn.Parameter(
|
||||||
torch.rand(
|
torch.empty(
|
||||||
self.num_learnable_registers, inner_dim, dtype=dtype, device=device
|
self.num_learnable_registers, inner_dim, dtype=dtype, device=device
|
||||||
)
|
)
|
||||||
* 2.0
|
|
||||||
- 1.0
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_fractional_positions(self, indices_grid):
|
def get_fractional_positions(self, indices_grid):
|
||||||
@ -234,7 +232,7 @@ class Embeddings1DConnector(nn.Module):
|
|||||||
|
|
||||||
return indices
|
return indices
|
||||||
|
|
||||||
def precompute_freqs_cis(self, indices_grid, spacing="exp"):
|
def precompute_freqs_cis(self, indices_grid, spacing="exp", out_dtype=None):
|
||||||
dim = self.inner_dim
|
dim = self.inner_dim
|
||||||
n_elem = 2 # 2 because of cos and sin
|
n_elem = 2 # 2 because of cos and sin
|
||||||
freqs = self.precompute_freqs(indices_grid, spacing)
|
freqs = self.precompute_freqs(indices_grid, spacing)
|
||||||
@ -247,7 +245,7 @@ class Embeddings1DConnector(nn.Module):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
|
cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
|
||||||
return cos_freq.to(self.dtype), sin_freq.to(self.dtype), self.split_rope
|
return cos_freq.to(dtype=out_dtype), sin_freq.to(dtype=out_dtype), self.split_rope
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -288,7 +286,7 @@ class Embeddings1DConnector(nn.Module):
|
|||||||
hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device
|
hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device
|
||||||
)
|
)
|
||||||
indices_grid = indices_grid[None, None, :]
|
indices_grid = indices_grid[None, None, :]
|
||||||
freqs_cis = self.precompute_freqs_cis(indices_grid)
|
freqs_cis = self.precompute_freqs_cis(indices_grid, out_dtype=hidden_states.dtype)
|
||||||
|
|
||||||
# 2. Blocks
|
# 2. Blocks
|
||||||
for block_idx, block in enumerate(self.transformer_1d_blocks):
|
for block_idx, block in enumerate(self.transformer_1d_blocks):
|
||||||
|
|||||||
@ -102,19 +102,7 @@ class VideoConv3d(nn.Module):
|
|||||||
return self.conv(x)
|
return self.conv(x)
|
||||||
|
|
||||||
def interpolate_up(x, scale_factor):
|
def interpolate_up(x, scale_factor):
|
||||||
try:
|
return torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="nearest")
|
||||||
return torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="nearest")
|
|
||||||
except: #operation not implemented for bf16
|
|
||||||
orig_shape = list(x.shape)
|
|
||||||
out_shape = orig_shape[:2]
|
|
||||||
for i in range(len(orig_shape) - 2):
|
|
||||||
out_shape.append(round(orig_shape[i + 2] * scale_factor[i]))
|
|
||||||
out = torch.empty(out_shape, dtype=x.dtype, layout=x.layout, device=x.device)
|
|
||||||
split = 8
|
|
||||||
l = out.shape[1] // split
|
|
||||||
for i in range(0, out.shape[1], l):
|
|
||||||
out[:,i:i+l] = torch.nn.functional.interpolate(x[:,i:i+l].to(torch.float32), scale_factor=scale_factor, mode="nearest").to(x.dtype)
|
|
||||||
return out
|
|
||||||
|
|
||||||
class Upsample(nn.Module):
|
class Upsample(nn.Module):
|
||||||
def __init__(self, in_channels, with_conv, conv_op=ops.Conv2d, scale_factor=2.0):
|
def __init__(self, in_channels, with_conv, conv_op=ops.Conv2d, scale_factor=2.0):
|
||||||
|
|||||||
@ -2,6 +2,196 @@ import torch
|
|||||||
import math
|
import math
|
||||||
|
|
||||||
from .model import QwenImageTransformer2DModel
|
from .model import QwenImageTransformer2DModel
|
||||||
|
from .model import QwenImageTransformerBlock
|
||||||
|
|
||||||
|
|
||||||
|
class QwenImageFunControlBlock(QwenImageTransformerBlock):
|
||||||
|
def __init__(self, dim, num_attention_heads, attention_head_dim, has_before_proj=False, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__(
|
||||||
|
dim=dim,
|
||||||
|
num_attention_heads=num_attention_heads,
|
||||||
|
attention_head_dim=attention_head_dim,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
self.has_before_proj = has_before_proj
|
||||||
|
if has_before_proj:
|
||||||
|
self.before_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
|
||||||
|
self.after_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class QwenImageFunControlNetModel(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
control_in_features=132,
|
||||||
|
inner_dim=3072,
|
||||||
|
num_attention_heads=24,
|
||||||
|
attention_head_dim=128,
|
||||||
|
num_control_blocks=5,
|
||||||
|
main_model_double=60,
|
||||||
|
injection_layers=(0, 12, 24, 36, 48),
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.dtype = dtype
|
||||||
|
self.main_model_double = main_model_double
|
||||||
|
self.injection_layers = tuple(injection_layers)
|
||||||
|
# Keep base hint scaling at 1.0 so user-facing strength behaves similarly
|
||||||
|
# to the reference Gen2/VideoX implementation around strength=1.
|
||||||
|
self.hint_scale = 1.0
|
||||||
|
self.control_img_in = operations.Linear(control_in_features, inner_dim, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
self.control_blocks = torch.nn.ModuleList([])
|
||||||
|
for i in range(num_control_blocks):
|
||||||
|
self.control_blocks.append(
|
||||||
|
QwenImageFunControlBlock(
|
||||||
|
dim=inner_dim,
|
||||||
|
num_attention_heads=num_attention_heads,
|
||||||
|
attention_head_dim=attention_head_dim,
|
||||||
|
has_before_proj=(i == 0),
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def _process_hint_tokens(self, hint):
|
||||||
|
if hint is None:
|
||||||
|
return None
|
||||||
|
if hint.ndim == 4:
|
||||||
|
hint = hint.unsqueeze(2)
|
||||||
|
|
||||||
|
# Fun checkpoints are trained with 33 latent channels before 2x2 packing:
|
||||||
|
# [control_latent(16), mask(1), inpaint_latent(16)] -> 132 features.
|
||||||
|
# Default behavior (no inpaint input in stock Apply ControlNet) should use
|
||||||
|
# zeros for mask/inpaint branches, matching VideoX fallback semantics.
|
||||||
|
expected_c = self.control_img_in.weight.shape[1] // 4
|
||||||
|
if hint.shape[1] == 16 and expected_c == 33:
|
||||||
|
zeros_mask = torch.zeros_like(hint[:, :1])
|
||||||
|
zeros_inpaint = torch.zeros_like(hint)
|
||||||
|
hint = torch.cat([hint, zeros_mask, zeros_inpaint], dim=1)
|
||||||
|
|
||||||
|
bs, c, t, h, w = hint.shape
|
||||||
|
hidden_states = torch.nn.functional.pad(hint, (0, w % 2, 0, h % 2))
|
||||||
|
orig_shape = hidden_states.shape
|
||||||
|
hidden_states = hidden_states.view(
|
||||||
|
orig_shape[0],
|
||||||
|
orig_shape[1],
|
||||||
|
orig_shape[-3],
|
||||||
|
orig_shape[-2] // 2,
|
||||||
|
2,
|
||||||
|
orig_shape[-1] // 2,
|
||||||
|
2,
|
||||||
|
)
|
||||||
|
hidden_states = hidden_states.permute(0, 2, 3, 5, 1, 4, 6)
|
||||||
|
hidden_states = hidden_states.reshape(
|
||||||
|
bs,
|
||||||
|
t * ((h + 1) // 2) * ((w + 1) // 2),
|
||||||
|
c * 4,
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_in = self.control_img_in.weight.shape[1]
|
||||||
|
cur_in = hidden_states.shape[-1]
|
||||||
|
if cur_in < expected_in:
|
||||||
|
pad = torch.zeros(
|
||||||
|
(hidden_states.shape[0], hidden_states.shape[1], expected_in - cur_in),
|
||||||
|
device=hidden_states.device,
|
||||||
|
dtype=hidden_states.dtype,
|
||||||
|
)
|
||||||
|
hidden_states = torch.cat([hidden_states, pad], dim=-1)
|
||||||
|
elif cur_in > expected_in:
|
||||||
|
hidden_states = hidden_states[:, :, :expected_in]
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
timesteps,
|
||||||
|
context,
|
||||||
|
attention_mask=None,
|
||||||
|
guidance: torch.Tensor = None,
|
||||||
|
hint=None,
|
||||||
|
transformer_options={},
|
||||||
|
base_model=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if base_model is None:
|
||||||
|
raise RuntimeError("Qwen Fun ControlNet requires a QwenImage base model at runtime.")
|
||||||
|
|
||||||
|
encoder_hidden_states_mask = attention_mask
|
||||||
|
# Keep attention mask disabled inside Fun control blocks to mirror
|
||||||
|
# VideoX behavior (they rely on seq lengths for RoPE, not masked attention).
|
||||||
|
encoder_hidden_states_mask = None
|
||||||
|
|
||||||
|
hidden_states, img_ids, _ = base_model.process_img(x)
|
||||||
|
hint_tokens = self._process_hint_tokens(hint)
|
||||||
|
if hint_tokens is None:
|
||||||
|
raise RuntimeError("Qwen Fun ControlNet requires a control hint image.")
|
||||||
|
|
||||||
|
if hint_tokens.shape[1] != hidden_states.shape[1]:
|
||||||
|
max_tokens = min(hint_tokens.shape[1], hidden_states.shape[1])
|
||||||
|
hint_tokens = hint_tokens[:, :max_tokens]
|
||||||
|
hidden_states = hidden_states[:, :max_tokens]
|
||||||
|
img_ids = img_ids[:, :max_tokens]
|
||||||
|
|
||||||
|
txt_start = round(
|
||||||
|
max(
|
||||||
|
((x.shape[-1] + (base_model.patch_size // 2)) // base_model.patch_size) // 2,
|
||||||
|
((x.shape[-2] + (base_model.patch_size // 2)) // base_model.patch_size) // 2,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
||||||
|
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||||
|
image_rotary_emb = base_model.pe_embedder(ids).to(x.dtype).contiguous()
|
||||||
|
|
||||||
|
hidden_states = base_model.img_in(hidden_states)
|
||||||
|
encoder_hidden_states = base_model.txt_norm(context)
|
||||||
|
encoder_hidden_states = base_model.txt_in(encoder_hidden_states)
|
||||||
|
|
||||||
|
if guidance is not None:
|
||||||
|
guidance = guidance * 1000
|
||||||
|
|
||||||
|
temb = (
|
||||||
|
base_model.time_text_embed(timesteps, hidden_states)
|
||||||
|
if guidance is None
|
||||||
|
else base_model.time_text_embed(timesteps, guidance, hidden_states)
|
||||||
|
)
|
||||||
|
|
||||||
|
c = self.control_img_in(hint_tokens)
|
||||||
|
|
||||||
|
for i, block in enumerate(self.control_blocks):
|
||||||
|
if i == 0:
|
||||||
|
c_in = block.before_proj(c) + hidden_states
|
||||||
|
all_c = []
|
||||||
|
else:
|
||||||
|
all_c = list(torch.unbind(c, dim=0))
|
||||||
|
c_in = all_c.pop(-1)
|
||||||
|
|
||||||
|
encoder_hidden_states, c_out = block(
|
||||||
|
hidden_states=c_in,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||||
|
temb=temb,
|
||||||
|
image_rotary_emb=image_rotary_emb,
|
||||||
|
transformer_options=transformer_options,
|
||||||
|
)
|
||||||
|
|
||||||
|
c_skip = block.after_proj(c_out) * self.hint_scale
|
||||||
|
all_c += [c_skip, c_out]
|
||||||
|
c = torch.stack(all_c, dim=0)
|
||||||
|
|
||||||
|
hints = torch.unbind(c, dim=0)[:-1]
|
||||||
|
|
||||||
|
controlnet_block_samples = [None] * self.main_model_double
|
||||||
|
for local_idx, base_idx in enumerate(self.injection_layers):
|
||||||
|
if local_idx < len(hints) and base_idx < len(controlnet_block_samples):
|
||||||
|
controlnet_block_samples[base_idx] = hints[local_idx]
|
||||||
|
|
||||||
|
return {"input": controlnet_block_samples}
|
||||||
|
|
||||||
|
|
||||||
class QwenImageControlNetModel(QwenImageTransformer2DModel):
|
class QwenImageControlNetModel(QwenImageTransformer2DModel):
|
||||||
|
|||||||
@ -374,6 +374,31 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten
|
|||||||
|
|
||||||
return padded_tensor
|
return padded_tensor
|
||||||
|
|
||||||
|
def calculate_shape(patches, weight, key, original_weights=None):
|
||||||
|
current_shape = weight.shape
|
||||||
|
|
||||||
|
for p in patches:
|
||||||
|
v = p[1]
|
||||||
|
offset = p[3]
|
||||||
|
|
||||||
|
# Offsets restore the old shape; lists force a diff without metadata
|
||||||
|
if offset is not None or isinstance(v, list):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if isinstance(v, weight_adapter.WeightAdapterBase):
|
||||||
|
adapter_shape = v.calculate_shape(key)
|
||||||
|
if adapter_shape is not None:
|
||||||
|
current_shape = adapter_shape
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Standard diff logic with padding
|
||||||
|
if len(v) == 2:
|
||||||
|
patch_type, patch_data = v[0], v[1]
|
||||||
|
if patch_type == "diff" and len(patch_data) > 1 and patch_data[1]['pad_weight']:
|
||||||
|
current_shape = patch_data[0].shape
|
||||||
|
|
||||||
|
return current_shape
|
||||||
|
|
||||||
def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, original_weights=None):
|
def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, original_weights=None):
|
||||||
for p in patches:
|
for p in patches:
|
||||||
strength = p[0]
|
strength = p[0]
|
||||||
|
|||||||
@ -5,7 +5,7 @@ import comfy.utils
|
|||||||
def convert_lora_bfl_control(sd): #BFL loras for Flux
|
def convert_lora_bfl_control(sd): #BFL loras for Flux
|
||||||
sd_out = {}
|
sd_out = {}
|
||||||
for k in sd:
|
for k in sd:
|
||||||
k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.scale.set_weight"))
|
k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.set_weight"))
|
||||||
sd_out[k_to] = sd[k]
|
sd_out[k_to] = sd[k]
|
||||||
|
|
||||||
sd_out["diffusion_model.img_in.reshape_weight"] = torch.tensor([sd["img_in.lora_B.weight"].shape[0], sd["img_in.lora_A.weight"].shape[1]])
|
sd_out["diffusion_model.img_in.reshape_weight"] = torch.tensor([sd["img_in.lora_B.weight"].shape[0], sd["img_in.lora_A.weight"].shape[1]])
|
||||||
|
|||||||
@ -78,4 +78,4 @@ def interpret_gathered_like(tensors, gathered):
|
|||||||
|
|
||||||
return dest_views
|
return dest_views
|
||||||
|
|
||||||
aimdo_allocator = None
|
aimdo_enabled = False
|
||||||
|
|||||||
@ -178,10 +178,7 @@ class BaseModel(torch.nn.Module):
|
|||||||
xc = torch.cat([xc] + [comfy.model_management.cast_to_device(c_concat, xc.device, xc.dtype)], dim=1)
|
xc = torch.cat([xc] + [comfy.model_management.cast_to_device(c_concat, xc.device, xc.dtype)], dim=1)
|
||||||
|
|
||||||
context = c_crossattn
|
context = c_crossattn
|
||||||
dtype = self.get_dtype()
|
dtype = self.get_dtype_inference()
|
||||||
|
|
||||||
if self.manual_cast_dtype is not None:
|
|
||||||
dtype = self.manual_cast_dtype
|
|
||||||
|
|
||||||
xc = xc.to(dtype)
|
xc = xc.to(dtype)
|
||||||
device = xc.device
|
device = xc.device
|
||||||
@ -218,6 +215,13 @@ class BaseModel(torch.nn.Module):
|
|||||||
def get_dtype(self):
|
def get_dtype(self):
|
||||||
return self.diffusion_model.dtype
|
return self.diffusion_model.dtype
|
||||||
|
|
||||||
|
def get_dtype_inference(self):
|
||||||
|
dtype = self.get_dtype()
|
||||||
|
|
||||||
|
if self.manual_cast_dtype is not None:
|
||||||
|
dtype = self.manual_cast_dtype
|
||||||
|
return dtype
|
||||||
|
|
||||||
def encode_adm(self, **kwargs):
|
def encode_adm(self, **kwargs):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -372,9 +376,7 @@ class BaseModel(torch.nn.Module):
|
|||||||
input_shapes += shape
|
input_shapes += shape
|
||||||
|
|
||||||
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
|
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
|
||||||
dtype = self.get_dtype()
|
dtype = self.get_dtype_inference()
|
||||||
if self.manual_cast_dtype is not None:
|
|
||||||
dtype = self.manual_cast_dtype
|
|
||||||
#TODO: this needs to be tweaked
|
#TODO: this needs to be tweaked
|
||||||
area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes))
|
area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes))
|
||||||
return (area * comfy.model_management.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (1024 * 1024)
|
return (area * comfy.model_management.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (1024 * 1024)
|
||||||
@ -986,10 +988,14 @@ class LTXAV(BaseModel):
|
|||||||
def extra_conds(self, **kwargs):
|
def extra_conds(self, **kwargs):
|
||||||
out = super().extra_conds(**kwargs)
|
out = super().extra_conds(**kwargs)
|
||||||
attention_mask = kwargs.get("attention_mask", None)
|
attention_mask = kwargs.get("attention_mask", None)
|
||||||
|
device = kwargs["device"]
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||||
cross_attn = kwargs.get("cross_attn", None)
|
cross_attn = kwargs.get("cross_attn", None)
|
||||||
if cross_attn is not None:
|
if cross_attn is not None:
|
||||||
|
if hasattr(self.diffusion_model, "preprocess_text_embeds"):
|
||||||
|
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()))
|
||||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||||
|
|
||||||
out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
|
out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
|
||||||
@ -1165,7 +1171,7 @@ class Anima(BaseModel):
|
|||||||
t5xxl_ids = t5xxl_ids.unsqueeze(0)
|
t5xxl_ids = t5xxl_ids.unsqueeze(0)
|
||||||
|
|
||||||
if torch.is_inference_mode_enabled(): # if not we are training
|
if torch.is_inference_mode_enabled(): # if not we are training
|
||||||
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype()))
|
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype_inference()))
|
||||||
else:
|
else:
|
||||||
out['t5xxl_ids'] = comfy.conds.CONDRegular(t5xxl_ids)
|
out['t5xxl_ids'] = comfy.conds.CONDRegular(t5xxl_ids)
|
||||||
out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights)
|
out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights)
|
||||||
|
|||||||
@ -19,6 +19,12 @@ def count_blocks(state_dict_keys, prefix_string):
|
|||||||
count += 1
|
count += 1
|
||||||
return count
|
return count
|
||||||
|
|
||||||
|
def any_suffix_in(keys, prefix, main, suffix_list=[]):
|
||||||
|
for x in suffix_list:
|
||||||
|
if "{}{}{}".format(prefix, main, x) in keys:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
|
def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
|
||||||
context_dim = None
|
context_dim = None
|
||||||
use_linear_in_transformer = False
|
use_linear_in_transformer = False
|
||||||
@ -186,7 +192,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
dit_config["meanflow_sum"] = False
|
dit_config["meanflow_sum"] = False
|
||||||
return dit_config
|
return dit_config
|
||||||
|
|
||||||
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
|
if any_suffix_in(state_dict_keys, key_prefix, 'double_blocks.0.img_attn.norm.key_norm.', ["weight", "scale"]) and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.norms.0.', ["weight", "scale"])): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
|
||||||
dit_config = {}
|
dit_config = {}
|
||||||
if '{}double_stream_modulation_img.lin.weight'.format(key_prefix) in state_dict_keys:
|
if '{}double_stream_modulation_img.lin.weight'.format(key_prefix) in state_dict_keys:
|
||||||
dit_config["image_model"] = "flux2"
|
dit_config["image_model"] = "flux2"
|
||||||
@ -241,7 +247,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
|
|
||||||
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
|
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
|
||||||
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
|
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
|
||||||
if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma
|
|
||||||
|
if any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.0.norms.0.', ["weight", "scale"]) or any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.norms.0.', ["weight", "scale"]): #Chroma
|
||||||
dit_config["image_model"] = "chroma"
|
dit_config["image_model"] = "chroma"
|
||||||
dit_config["in_channels"] = 64
|
dit_config["in_channels"] = 64
|
||||||
dit_config["out_channels"] = 64
|
dit_config["out_channels"] = 64
|
||||||
@ -249,7 +256,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
dit_config["out_dim"] = 3072
|
dit_config["out_dim"] = 3072
|
||||||
dit_config["hidden_dim"] = 5120
|
dit_config["hidden_dim"] = 5120
|
||||||
dit_config["n_layers"] = 5
|
dit_config["n_layers"] = 5
|
||||||
if f"{key_prefix}nerf_blocks.0.norm.scale" in state_dict_keys: #Chroma Radiance
|
|
||||||
|
if any_suffix_in(state_dict_keys, key_prefix, 'nerf_blocks.0.norm.', ["weight", "scale"]): #Chroma Radiance
|
||||||
dit_config["image_model"] = "chroma_radiance"
|
dit_config["image_model"] = "chroma_radiance"
|
||||||
dit_config["in_channels"] = 3
|
dit_config["in_channels"] = 3
|
||||||
dit_config["out_channels"] = 3
|
dit_config["out_channels"] = 3
|
||||||
@ -259,7 +267,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
dit_config["nerf_depth"] = 4
|
dit_config["nerf_depth"] = 4
|
||||||
dit_config["nerf_max_freqs"] = 8
|
dit_config["nerf_max_freqs"] = 8
|
||||||
dit_config["nerf_tile_size"] = 512
|
dit_config["nerf_tile_size"] = 512
|
||||||
dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
|
dit_config["nerf_final_head_type"] = "conv" if any_suffix_in(state_dict_keys, key_prefix, 'nerf_final_layer_conv.norm.', ["weight", "scale"]) else "linear"
|
||||||
dit_config["nerf_embedder_dtype"] = torch.float32
|
dit_config["nerf_embedder_dtype"] = torch.float32
|
||||||
if "{}__x0__".format(key_prefix) in state_dict_keys: # x0 pred
|
if "{}__x0__".format(key_prefix) in state_dict_keys: # x0 pred
|
||||||
dit_config["use_x0"] = True
|
dit_config["use_x0"] = True
|
||||||
@ -268,7 +276,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
else:
|
else:
|
||||||
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
|
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
|
||||||
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
|
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
|
||||||
dit_config["txt_norm"] = "{}txt_norm.scale".format(key_prefix) in state_dict_keys
|
dit_config["txt_norm"] = any_suffix_in(state_dict_keys, key_prefix, 'txt_norm.', ["weight", "scale"])
|
||||||
if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model
|
if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model
|
||||||
dit_config["txt_ids_dims"] = [1, 2]
|
dit_config["txt_ids_dims"] = [1, 2]
|
||||||
|
|
||||||
|
|||||||
@ -836,7 +836,7 @@ def unet_inital_load_device(parameters, dtype):
|
|||||||
|
|
||||||
mem_dev = get_free_memory(torch_dev)
|
mem_dev = get_free_memory(torch_dev)
|
||||||
mem_cpu = get_free_memory(cpu_dev)
|
mem_cpu = get_free_memory(cpu_dev)
|
||||||
if mem_dev > mem_cpu and model_size < mem_dev and comfy.memory_management.aimdo_allocator is None:
|
if mem_dev > mem_cpu and model_size < mem_dev and comfy.memory_management.aimdo_enabled:
|
||||||
return torch_dev
|
return torch_dev
|
||||||
else:
|
else:
|
||||||
return cpu_dev
|
return cpu_dev
|
||||||
@ -1121,7 +1121,6 @@ def get_cast_buffer(offload_stream, device, size, ref):
|
|||||||
synchronize()
|
synchronize()
|
||||||
del STREAM_CAST_BUFFERS[offload_stream]
|
del STREAM_CAST_BUFFERS[offload_stream]
|
||||||
del cast_buffer
|
del cast_buffer
|
||||||
#FIXME: This doesn't work in Aimdo because mempool cant clear cache
|
|
||||||
soft_empty_cache()
|
soft_empty_cache()
|
||||||
with wf_context:
|
with wf_context:
|
||||||
cast_buffer = torch.empty((size), dtype=torch.int8, device=device)
|
cast_buffer = torch.empty((size), dtype=torch.int8, device=device)
|
||||||
|
|||||||
@ -276,6 +276,7 @@ class ModelPatcher:
|
|||||||
self.is_clip = False
|
self.is_clip = False
|
||||||
self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed
|
self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed
|
||||||
|
|
||||||
|
self.cached_patcher_init: tuple[Callable, tuple] | None = None
|
||||||
if not hasattr(self.model, 'model_loaded_weight_memory'):
|
if not hasattr(self.model, 'model_loaded_weight_memory'):
|
||||||
self.model.model_loaded_weight_memory = 0
|
self.model.model_loaded_weight_memory = 0
|
||||||
|
|
||||||
@ -312,8 +313,15 @@ class ModelPatcher:
|
|||||||
def get_free_memory(self, device):
|
def get_free_memory(self, device):
|
||||||
return comfy.model_management.get_free_memory(device)
|
return comfy.model_management.get_free_memory(device)
|
||||||
|
|
||||||
def clone(self):
|
def clone(self, disable_dynamic=False):
|
||||||
n = self.__class__(self.model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
|
class_ = self.__class__
|
||||||
|
model = self.model
|
||||||
|
if self.is_dynamic() and disable_dynamic:
|
||||||
|
class_ = ModelPatcher
|
||||||
|
temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True)
|
||||||
|
model = temp_model_patcher.model
|
||||||
|
|
||||||
|
n = class_(model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
|
||||||
n.patches = {}
|
n.patches = {}
|
||||||
for k in self.patches:
|
for k in self.patches:
|
||||||
n.patches[k] = self.patches[k][:]
|
n.patches[k] = self.patches[k][:]
|
||||||
@ -367,6 +375,8 @@ class ModelPatcher:
|
|||||||
n.is_clip = self.is_clip
|
n.is_clip = self.is_clip
|
||||||
n.hook_mode = self.hook_mode
|
n.hook_mode = self.hook_mode
|
||||||
|
|
||||||
|
n.cached_patcher_init = self.cached_patcher_init
|
||||||
|
|
||||||
for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
|
for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
|
||||||
callback(self, n)
|
callback(self, n)
|
||||||
return n
|
return n
|
||||||
@ -411,13 +421,16 @@ class ModelPatcher:
|
|||||||
def memory_required(self, input_shape):
|
def memory_required(self, input_shape):
|
||||||
return self.model.memory_required(input_shape=input_shape)
|
return self.model.memory_required(input_shape=input_shape)
|
||||||
|
|
||||||
|
def disable_model_cfg1_optimization(self):
|
||||||
|
self.model_options["disable_cfg1_optimization"] = True
|
||||||
|
|
||||||
def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
|
def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
|
||||||
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
|
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
|
||||||
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
|
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
|
||||||
else:
|
else:
|
||||||
self.model_options["sampler_cfg_function"] = sampler_cfg_function
|
self.model_options["sampler_cfg_function"] = sampler_cfg_function
|
||||||
if disable_cfg1_optimization:
|
if disable_cfg1_optimization:
|
||||||
self.model_options["disable_cfg1_optimization"] = True
|
self.disable_model_cfg1_optimization()
|
||||||
|
|
||||||
def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
|
def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
|
||||||
self.model_options = set_model_options_post_cfg_function(self.model_options, post_cfg_function, disable_cfg1_optimization)
|
self.model_options = set_model_options_post_cfg_function(self.model_options, post_cfg_function, disable_cfg1_optimization)
|
||||||
@ -684,18 +697,19 @@ class ModelPatcher:
|
|||||||
for key in list(self.pinned):
|
for key in list(self.pinned):
|
||||||
self.unpin_weight(key)
|
self.unpin_weight(key)
|
||||||
|
|
||||||
def _load_list(self, prio_comfy_cast_weights=False):
|
def _load_list(self, prio_comfy_cast_weights=False, default_device=None):
|
||||||
loading = []
|
loading = []
|
||||||
for n, m in self.model.named_modules():
|
for n, m in self.model.named_modules():
|
||||||
params = []
|
default = False
|
||||||
skip = False
|
params = { name: param for name, param in m.named_parameters(recurse=False) }
|
||||||
for name, param in m.named_parameters(recurse=False):
|
|
||||||
params.append(name)
|
|
||||||
for name, param in m.named_parameters(recurse=True):
|
for name, param in m.named_parameters(recurse=True):
|
||||||
if name not in params:
|
if name not in params:
|
||||||
skip = True # skip random weights in non leaf modules
|
default = True # default random weights in non leaf modules
|
||||||
break
|
break
|
||||||
if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
|
if default and default_device is not None:
|
||||||
|
for param in params.values():
|
||||||
|
param.data = param.data.to(device=default_device)
|
||||||
|
if not default and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
|
||||||
module_mem = comfy.model_management.module_size(m)
|
module_mem = comfy.model_management.module_size(m)
|
||||||
module_offload_mem = module_mem
|
module_offload_mem = module_mem
|
||||||
if hasattr(m, "comfy_cast_weights"):
|
if hasattr(m, "comfy_cast_weights"):
|
||||||
@ -1500,7 +1514,7 @@ class ModelPatcherDynamic(ModelPatcher):
|
|||||||
#with pin and unpin syncrhonization which can be expensive for small weights
|
#with pin and unpin syncrhonization which can be expensive for small weights
|
||||||
#with a high layer rate (e.g. autoregressive LLMs).
|
#with a high layer rate (e.g. autoregressive LLMs).
|
||||||
#prioritize the non-comfy weights (note the order reverse).
|
#prioritize the non-comfy weights (note the order reverse).
|
||||||
loading = self._load_list(prio_comfy_cast_weights=True)
|
loading = self._load_list(prio_comfy_cast_weights=True, default_device=device_to)
|
||||||
loading.sort(reverse=True)
|
loading.sort(reverse=True)
|
||||||
|
|
||||||
for x in loading:
|
for x in loading:
|
||||||
@ -1518,8 +1532,10 @@ class ModelPatcherDynamic(ModelPatcher):
|
|||||||
|
|
||||||
weight, _, _ = get_key_weight(self.model, key)
|
weight, _, _ = get_key_weight(self.model, key)
|
||||||
if weight is None:
|
if weight is None:
|
||||||
return 0
|
return (False, 0)
|
||||||
if key in self.patches:
|
if key in self.patches:
|
||||||
|
if comfy.lora.calculate_shape(self.patches[key], weight, key) != weight.shape:
|
||||||
|
return (True, 0)
|
||||||
setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches))
|
setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches))
|
||||||
num_patches += 1
|
num_patches += 1
|
||||||
else:
|
else:
|
||||||
@ -1533,7 +1549,13 @@ class ModelPatcherDynamic(ModelPatcher):
|
|||||||
model_dtype = getattr(m, param_key + "_comfy_model_dtype", None) or weight.dtype
|
model_dtype = getattr(m, param_key + "_comfy_model_dtype", None) or weight.dtype
|
||||||
weight._model_dtype = model_dtype
|
weight._model_dtype = model_dtype
|
||||||
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
|
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
|
||||||
return comfy.memory_management.vram_aligned_size(geometry)
|
return (False, comfy.memory_management.vram_aligned_size(geometry))
|
||||||
|
|
||||||
|
def force_load_param(self, param_key, device_to):
|
||||||
|
key = key_param_name_to_key(n, param_key)
|
||||||
|
if key in self.backup:
|
||||||
|
comfy.utils.set_attr_param(self.model, key, self.backup[key].weight)
|
||||||
|
self.patch_weight_to_device(key, device_to=device_to)
|
||||||
|
|
||||||
if hasattr(m, "comfy_cast_weights"):
|
if hasattr(m, "comfy_cast_weights"):
|
||||||
m.comfy_cast_weights = True
|
m.comfy_cast_weights = True
|
||||||
@ -1541,13 +1563,19 @@ class ModelPatcherDynamic(ModelPatcher):
|
|||||||
m.seed_key = n
|
m.seed_key = n
|
||||||
set_dirty(m, dirty)
|
set_dirty(m, dirty)
|
||||||
|
|
||||||
v_weight_size = 0
|
force_load, v_weight_size = setup_param(self, m, n, "weight")
|
||||||
v_weight_size += setup_param(self, m, n, "weight")
|
force_load_bias, v_weight_bias = setup_param(self, m, n, "bias")
|
||||||
v_weight_size += setup_param(self, m, n, "bias")
|
force_load = force_load or force_load_bias
|
||||||
|
v_weight_size += v_weight_bias
|
||||||
|
|
||||||
if vbar is not None and not hasattr(m, "_v"):
|
if force_load:
|
||||||
m._v = vbar.alloc(v_weight_size)
|
logging.info(f"Module {n} has resizing Lora - force loading")
|
||||||
allocated_size += v_weight_size
|
force_load_param(self, "weight", device_to)
|
||||||
|
force_load_param(self, "bias", device_to)
|
||||||
|
else:
|
||||||
|
if vbar is not None and not hasattr(m, "_v"):
|
||||||
|
m._v = vbar.alloc(v_weight_size)
|
||||||
|
allocated_size += v_weight_size
|
||||||
|
|
||||||
else:
|
else:
|
||||||
for param in params:
|
for param in params:
|
||||||
@ -1565,6 +1593,8 @@ class ModelPatcherDynamic(ModelPatcher):
|
|||||||
allocated_size += weight_size
|
allocated_size += weight_size
|
||||||
vbar.set_watermark_limit(allocated_size)
|
vbar.set_watermark_limit(allocated_size)
|
||||||
|
|
||||||
|
move_weight_functions(m, device_to)
|
||||||
|
|
||||||
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.")
|
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.")
|
||||||
|
|
||||||
self.model.device = device_to
|
self.model.device = device_to
|
||||||
@ -1584,7 +1614,7 @@ class ModelPatcherDynamic(ModelPatcher):
|
|||||||
return 0 if vbar is None else vbar.free_memory(memory_to_free)
|
return 0 if vbar is None else vbar.free_memory(memory_to_free)
|
||||||
|
|
||||||
def partially_unload_ram(self, ram_to_unload):
|
def partially_unload_ram(self, ram_to_unload):
|
||||||
loading = self._load_list(prio_comfy_cast_weights=True)
|
loading = self._load_list(prio_comfy_cast_weights=True, default_device=self.offload_device)
|
||||||
for x in loading:
|
for x in loading:
|
||||||
_, _, _, _, m, _ = x
|
_, _, _, _, m, _ = x
|
||||||
ram_to_unload -= comfy.pinned_memory.unpin_memory(m)
|
ram_to_unload -= comfy.pinned_memory.unpin_memory(m)
|
||||||
@ -1605,6 +1635,13 @@ class ModelPatcherDynamic(ModelPatcher):
|
|||||||
if unpatch_weights:
|
if unpatch_weights:
|
||||||
self.partially_unload_ram(1e32)
|
self.partially_unload_ram(1e32)
|
||||||
self.partially_unload(None, 1e32)
|
self.partially_unload(None, 1e32)
|
||||||
|
for m in self.model.modules():
|
||||||
|
move_weight_functions(m, device_to)
|
||||||
|
|
||||||
|
keys = list(self.backup.keys())
|
||||||
|
for k in keys:
|
||||||
|
bk = self.backup[k]
|
||||||
|
comfy.utils.set_attr_param(self.model, k, bk.weight)
|
||||||
|
|
||||||
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
|
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
|
||||||
assert not force_patch_weights #See above
|
assert not force_patch_weights #See above
|
||||||
|
|||||||
27
comfy/ops.py
27
comfy/ops.py
@ -21,7 +21,6 @@ import logging
|
|||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
|
from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
|
||||||
import comfy.float
|
import comfy.float
|
||||||
import comfy.rmsnorm
|
|
||||||
import json
|
import json
|
||||||
import comfy.memory_management
|
import comfy.memory_management
|
||||||
import comfy.pinned_memory
|
import comfy.pinned_memory
|
||||||
@ -80,7 +79,7 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
|
|||||||
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
|
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
|
||||||
|
|
||||||
|
|
||||||
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype):
|
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
|
||||||
offload_stream = None
|
offload_stream = None
|
||||||
xfer_dest = None
|
xfer_dest = None
|
||||||
|
|
||||||
@ -171,10 +170,10 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
|
|||||||
#FIXME: this is not accurate, we need to be sensitive to the compute dtype
|
#FIXME: this is not accurate, we need to be sensitive to the compute dtype
|
||||||
x = lowvram_fn(x)
|
x = lowvram_fn(x)
|
||||||
if (isinstance(orig, QuantizedTensor) and
|
if (isinstance(orig, QuantizedTensor) and
|
||||||
(orig.dtype == dtype and len(fns) == 0 or update_weight)):
|
(want_requant and len(fns) == 0 or update_weight)):
|
||||||
seed = comfy.utils.string_to_seed(s.seed_key)
|
seed = comfy.utils.string_to_seed(s.seed_key)
|
||||||
y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
|
y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
|
||||||
if orig.dtype == dtype and len(fns) == 0:
|
if want_requant and len(fns) == 0:
|
||||||
#The layer actually wants our freshly saved QT
|
#The layer actually wants our freshly saved QT
|
||||||
x = y
|
x = y
|
||||||
elif update_weight:
|
elif update_weight:
|
||||||
@ -195,7 +194,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
|
|||||||
return weight, bias, (offload_stream, device if signature is not None else None, None)
|
return weight, bias, (offload_stream, device if signature is not None else None, None)
|
||||||
|
|
||||||
|
|
||||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None):
|
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None, want_requant=False):
|
||||||
# NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass
|
# NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass
|
||||||
# offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
|
# offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
|
||||||
# will add async-offload support to your cast and improve performance.
|
# will add async-offload support to your cast and improve performance.
|
||||||
@ -213,7 +212,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
|||||||
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
||||||
|
|
||||||
if hasattr(s, "_v"):
|
if hasattr(s, "_v"):
|
||||||
return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype)
|
return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant)
|
||||||
|
|
||||||
if offloadable and (device != s.weight.device or
|
if offloadable and (device != s.weight.device or
|
||||||
(s.bias is not None and device != s.bias.device)):
|
(s.bias is not None and device != s.bias.device)):
|
||||||
@ -463,7 +462,7 @@ class disable_weight_init:
|
|||||||
else:
|
else:
|
||||||
return super().forward(*args, **kwargs)
|
return super().forward(*args, **kwargs)
|
||||||
|
|
||||||
class RMSNorm(comfy.rmsnorm.RMSNorm, CastWeightBiasOp):
|
class RMSNorm(torch.nn.RMSNorm, CastWeightBiasOp):
|
||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
self.bias = None
|
self.bias = None
|
||||||
return None
|
return None
|
||||||
@ -475,8 +474,7 @@ class disable_weight_init:
|
|||||||
weight = None
|
weight = None
|
||||||
bias = None
|
bias = None
|
||||||
offload_stream = None
|
offload_stream = None
|
||||||
x = comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
|
x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
|
||||||
# x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
|
|
||||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@ -829,6 +827,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
else:
|
else:
|
||||||
sd = {}
|
sd = {}
|
||||||
|
|
||||||
|
if not hasattr(self, 'weight'):
|
||||||
|
logging.warning("Warning: state dict on uninitialized op {}".format(prefix))
|
||||||
|
return sd
|
||||||
|
|
||||||
if self.bias is not None:
|
if self.bias is not None:
|
||||||
sd["{}bias".format(prefix)] = self.bias
|
sd["{}bias".format(prefix)] = self.bias
|
||||||
|
|
||||||
@ -852,8 +854,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
def _forward(self, input, weight, bias):
|
def _forward(self, input, weight, bias):
|
||||||
return torch.nn.functional.linear(input, weight, bias)
|
return torch.nn.functional.linear(input, weight, bias)
|
||||||
|
|
||||||
def forward_comfy_cast_weights(self, input, compute_dtype=None):
|
def forward_comfy_cast_weights(self, input, compute_dtype=None, want_requant=False):
|
||||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype)
|
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype, want_requant=want_requant)
|
||||||
x = self._forward(input, weight, bias)
|
x = self._forward(input, weight, bias)
|
||||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||||
return x
|
return x
|
||||||
@ -883,8 +885,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
scale = comfy.model_management.cast_to_device(scale, input.device, None)
|
scale = comfy.model_management.cast_to_device(scale, input.device, None)
|
||||||
input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)
|
input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)
|
||||||
|
|
||||||
|
output = self.forward_comfy_cast_weights(input, compute_dtype, want_requant=isinstance(input, QuantizedTensor))
|
||||||
output = self.forward_comfy_cast_weights(input, compute_dtype)
|
|
||||||
|
|
||||||
# Reshape output back to 3D if input was 3D
|
# Reshape output back to 3D if input was 3D
|
||||||
if reshaped_3d:
|
if reshaped_3d:
|
||||||
|
|||||||
@ -1,57 +1,10 @@
|
|||||||
import torch
|
import torch
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import numbers
|
|
||||||
import logging
|
|
||||||
|
|
||||||
RMSNorm = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
rms_norm_torch = torch.nn.functional.rms_norm
|
|
||||||
RMSNorm = torch.nn.RMSNorm
|
|
||||||
except:
|
|
||||||
rms_norm_torch = None
|
|
||||||
logging.warning("Please update pytorch to use native RMSNorm")
|
|
||||||
|
|
||||||
|
RMSNorm = torch.nn.RMSNorm
|
||||||
|
|
||||||
def rms_norm(x, weight=None, eps=1e-6):
|
def rms_norm(x, weight=None, eps=1e-6):
|
||||||
if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
|
if weight is None:
|
||||||
if weight is None:
|
return torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=eps)
|
||||||
return rms_norm_torch(x, (x.shape[-1],), eps=eps)
|
|
||||||
else:
|
|
||||||
return rms_norm_torch(x, weight.shape, weight=comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
|
|
||||||
else:
|
else:
|
||||||
r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
|
return torch.nn.functional.rms_norm(x, weight.shape, weight=comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
|
||||||
if weight is None:
|
|
||||||
return r
|
|
||||||
else:
|
|
||||||
return r * comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device)
|
|
||||||
|
|
||||||
|
|
||||||
if RMSNorm is None:
|
|
||||||
class RMSNorm(torch.nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
normalized_shape,
|
|
||||||
eps=1e-6,
|
|
||||||
elementwise_affine=True,
|
|
||||||
device=None,
|
|
||||||
dtype=None,
|
|
||||||
):
|
|
||||||
factory_kwargs = {"device": device, "dtype": dtype}
|
|
||||||
super().__init__()
|
|
||||||
if isinstance(normalized_shape, numbers.Integral):
|
|
||||||
# mypy error: incompatible types in assignment
|
|
||||||
normalized_shape = (normalized_shape,) # type: ignore[assignment]
|
|
||||||
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
|
|
||||||
self.eps = eps
|
|
||||||
self.elementwise_affine = elementwise_affine
|
|
||||||
if self.elementwise_affine:
|
|
||||||
self.weight = torch.nn.Parameter(
|
|
||||||
torch.empty(self.normalized_shape, **factory_kwargs)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.register_parameter("weight", None)
|
|
||||||
self.bias = None
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return rms_norm(x, self.weight, self.eps)
|
|
||||||
|
|||||||
56
comfy/sd.py
56
comfy/sd.py
@ -423,6 +423,17 @@ class CLIP:
|
|||||||
def get_key_patches(self):
|
def get_key_patches(self):
|
||||||
return self.patcher.get_key_patches()
|
return self.patcher.get_key_patches()
|
||||||
|
|
||||||
|
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
|
||||||
|
self.cond_stage_model.reset_clip_options()
|
||||||
|
|
||||||
|
self.load_model()
|
||||||
|
self.cond_stage_model.set_clip_options({"layer": None})
|
||||||
|
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
|
||||||
|
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)
|
||||||
|
|
||||||
|
def decode(self, token_ids, skip_special_tokens=True):
|
||||||
|
return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
|
||||||
|
|
||||||
class VAE:
|
class VAE:
|
||||||
def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
|
def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
|
||||||
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
|
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
|
||||||
@ -1182,6 +1193,7 @@ class TEModel(Enum):
|
|||||||
JINA_CLIP_2 = 19
|
JINA_CLIP_2 = 19
|
||||||
QWEN3_8B = 20
|
QWEN3_8B = 20
|
||||||
QWEN3_06B = 21
|
QWEN3_06B = 21
|
||||||
|
GEMMA_3_4B_VISION = 22
|
||||||
|
|
||||||
|
|
||||||
def detect_te_model(sd):
|
def detect_te_model(sd):
|
||||||
@ -1210,7 +1222,10 @@ def detect_te_model(sd):
|
|||||||
if 'model.layers.47.self_attn.q_norm.weight' in sd:
|
if 'model.layers.47.self_attn.q_norm.weight' in sd:
|
||||||
return TEModel.GEMMA_3_12B
|
return TEModel.GEMMA_3_12B
|
||||||
if 'model.layers.0.self_attn.q_norm.weight' in sd:
|
if 'model.layers.0.self_attn.q_norm.weight' in sd:
|
||||||
return TEModel.GEMMA_3_4B
|
if 'vision_model.embeddings.patch_embedding.weight' in sd:
|
||||||
|
return TEModel.GEMMA_3_4B_VISION
|
||||||
|
else:
|
||||||
|
return TEModel.GEMMA_3_4B
|
||||||
return TEModel.GEMMA_2_2B
|
return TEModel.GEMMA_2_2B
|
||||||
if 'model.layers.0.self_attn.k_proj.bias' in sd:
|
if 'model.layers.0.self_attn.k_proj.bias' in sd:
|
||||||
weight = sd['model.layers.0.self_attn.k_proj.bias']
|
weight = sd['model.layers.0.self_attn.k_proj.bias']
|
||||||
@ -1270,6 +1285,8 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
|||||||
else:
|
else:
|
||||||
if "text_projection" in clip_data[i]:
|
if "text_projection" in clip_data[i]:
|
||||||
clip_data[i]["text_projection.weight"] = clip_data[i]["text_projection"].transpose(0, 1) #old models saved with the CLIPSave node
|
clip_data[i]["text_projection.weight"] = clip_data[i]["text_projection"].transpose(0, 1) #old models saved with the CLIPSave node
|
||||||
|
if "lm_head.weight" in clip_data[i]:
|
||||||
|
clip_data[i]["model.lm_head.weight"] = clip_data[i].pop("lm_head.weight") # prefix missing in some models
|
||||||
|
|
||||||
tokenizer_data = {}
|
tokenizer_data = {}
|
||||||
clip_target = EmptyClass()
|
clip_target = EmptyClass()
|
||||||
@ -1335,6 +1352,14 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
|||||||
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b")
|
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b")
|
||||||
clip_target.tokenizer = comfy.text_encoders.lumina2.NTokenizer
|
clip_target.tokenizer = comfy.text_encoders.lumina2.NTokenizer
|
||||||
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
||||||
|
elif te_model == TEModel.GEMMA_3_4B_VISION:
|
||||||
|
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b_vision")
|
||||||
|
clip_target.tokenizer = comfy.text_encoders.lumina2.NTokenizer
|
||||||
|
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
||||||
|
elif te_model == TEModel.GEMMA_3_12B:
|
||||||
|
clip_target.clip = comfy.text_encoders.lt.gemma3_te(**llama_detect(clip_data))
|
||||||
|
clip_target.tokenizer = comfy.text_encoders.lt.Gemma3_12BTokenizer
|
||||||
|
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
||||||
elif te_model == TEModel.LLAMA3_8:
|
elif te_model == TEModel.LLAMA3_8:
|
||||||
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
|
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
|
||||||
clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None)
|
clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None)
|
||||||
@ -1505,14 +1530,24 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
|
|||||||
|
|
||||||
return (model, clip, vae)
|
return (model, clip, vae)
|
||||||
|
|
||||||
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}):
|
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, disable_dynamic=False):
|
||||||
sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
|
sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
|
||||||
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata)
|
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic)
|
||||||
if out is None:
|
if out is None:
|
||||||
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
|
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
|
||||||
|
if output_model:
|
||||||
|
out[0].cached_patcher_init = (load_checkpoint_guess_config_model_only, (ckpt_path, embedding_directory, model_options, te_model_options))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None):
|
def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
|
||||||
|
model, *_ = load_checkpoint_guess_config(ckpt_path, False, False, False,
|
||||||
|
embedding_directory=embedding_directory,
|
||||||
|
model_options=model_options,
|
||||||
|
te_model_options=te_model_options,
|
||||||
|
disable_dynamic=disable_dynamic)
|
||||||
|
return model
|
||||||
|
|
||||||
|
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None, disable_dynamic=False):
|
||||||
clip = None
|
clip = None
|
||||||
clipvision = None
|
clipvision = None
|
||||||
vae = None
|
vae = None
|
||||||
@ -1561,7 +1596,8 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
|||||||
if output_model:
|
if output_model:
|
||||||
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
|
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
|
||||||
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
|
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
|
||||||
model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
|
ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher
|
||||||
|
model_patcher = ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
|
||||||
model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic())
|
model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic())
|
||||||
|
|
||||||
if output_vae:
|
if output_vae:
|
||||||
@ -1612,7 +1648,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
|||||||
return (model_patcher, clip, vae, clipvision)
|
return (model_patcher, clip, vae, clipvision)
|
||||||
|
|
||||||
|
|
||||||
def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
|
def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable_dynamic=False):
|
||||||
"""
|
"""
|
||||||
Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.
|
Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.
|
||||||
|
|
||||||
@ -1696,7 +1732,8 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
|
|||||||
model_config.optimizations["fp8"] = True
|
model_config.optimizations["fp8"] = True
|
||||||
|
|
||||||
model = model_config.get_model(new_sd, "")
|
model = model_config.get_model(new_sd, "")
|
||||||
model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=load_device, offload_device=offload_device)
|
ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher
|
||||||
|
model_patcher = ModelPatcher(model, load_device=load_device, offload_device=offload_device)
|
||||||
if not model_management.is_device_cpu(offload_device):
|
if not model_management.is_device_cpu(offload_device):
|
||||||
model.to(offload_device)
|
model.to(offload_device)
|
||||||
model.load_model_weights(new_sd, "", assign=model_patcher.is_dynamic())
|
model.load_model_weights(new_sd, "", assign=model_patcher.is_dynamic())
|
||||||
@ -1705,12 +1742,13 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
|
|||||||
logging.info("left over keys in diffusion model: {}".format(left_over))
|
logging.info("left over keys in diffusion model: {}".format(left_over))
|
||||||
return model_patcher
|
return model_patcher
|
||||||
|
|
||||||
def load_diffusion_model(unet_path, model_options={}):
|
def load_diffusion_model(unet_path, model_options={}, disable_dynamic=False):
|
||||||
sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True)
|
sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True)
|
||||||
model = load_diffusion_model_state_dict(sd, model_options=model_options, metadata=metadata)
|
model = load_diffusion_model_state_dict(sd, model_options=model_options, metadata=metadata, disable_dynamic=disable_dynamic)
|
||||||
if model is None:
|
if model is None:
|
||||||
logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
|
logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
|
||||||
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd)))
|
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd)))
|
||||||
|
model.cached_patcher_init = (load_diffusion_model, (unet_path, model_options))
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def load_unet(unet_path, dtype=None):
|
def load_unet(unet_path, dtype=None):
|
||||||
|
|||||||
@ -171,8 +171,9 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||||||
|
|
||||||
def process_tokens(self, tokens, device):
|
def process_tokens(self, tokens, device):
|
||||||
end_token = self.special_tokens.get("end", None)
|
end_token = self.special_tokens.get("end", None)
|
||||||
|
pad_token = self.special_tokens.get("pad", -1)
|
||||||
if end_token is None:
|
if end_token is None:
|
||||||
cmp_token = self.special_tokens.get("pad", -1)
|
cmp_token = pad_token
|
||||||
else:
|
else:
|
||||||
cmp_token = end_token
|
cmp_token = end_token
|
||||||
|
|
||||||
@ -186,15 +187,21 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||||||
other_embeds = []
|
other_embeds = []
|
||||||
eos = False
|
eos = False
|
||||||
index = 0
|
index = 0
|
||||||
|
left_pad = False
|
||||||
for y in x:
|
for y in x:
|
||||||
if isinstance(y, numbers.Integral):
|
if isinstance(y, numbers.Integral):
|
||||||
if eos:
|
token = int(y)
|
||||||
|
if index == 0 and token == pad_token:
|
||||||
|
left_pad = True
|
||||||
|
|
||||||
|
if eos or (left_pad and token == pad_token):
|
||||||
attention_mask.append(0)
|
attention_mask.append(0)
|
||||||
else:
|
else:
|
||||||
attention_mask.append(1)
|
attention_mask.append(1)
|
||||||
token = int(y)
|
left_pad = False
|
||||||
|
|
||||||
tokens_temp += [token]
|
tokens_temp += [token]
|
||||||
if not eos and token == cmp_token:
|
if not eos and token == cmp_token and not left_pad:
|
||||||
if end_token is None:
|
if end_token is None:
|
||||||
attention_mask[-1] = 0
|
attention_mask[-1] = 0
|
||||||
eos = True
|
eos = True
|
||||||
@ -301,6 +308,15 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||||||
def load_sd(self, sd):
|
def load_sd(self, sd):
|
||||||
return self.transformer.load_state_dict(sd, strict=False, assign=getattr(self, "can_assign_sd", False))
|
return self.transformer.load_state_dict(sd, strict=False, assign=getattr(self, "can_assign_sd", False))
|
||||||
|
|
||||||
|
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
|
||||||
|
if isinstance(tokens, dict):
|
||||||
|
tokens_only = next(iter(tokens.values())) # todo: get this better?
|
||||||
|
else:
|
||||||
|
tokens_only = tokens
|
||||||
|
tokens_only = [[t[0] for t in b] for b in tokens_only]
|
||||||
|
embeds = self.process_tokens(tokens_only, device=self.execution_device)[0]
|
||||||
|
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed)
|
||||||
|
|
||||||
def parse_parentheses(string):
|
def parse_parentheses(string):
|
||||||
result = []
|
result = []
|
||||||
current_item = ""
|
current_item = ""
|
||||||
@ -557,6 +573,8 @@ class SDTokenizer:
|
|||||||
min_length = tokenizer_options.get("{}_min_length".format(self.embedding_key), self.min_length)
|
min_length = tokenizer_options.get("{}_min_length".format(self.embedding_key), self.min_length)
|
||||||
min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding)
|
min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding)
|
||||||
|
|
||||||
|
min_length = kwargs.get("min_length", min_length)
|
||||||
|
|
||||||
text = escape_important(text)
|
text = escape_important(text)
|
||||||
if kwargs.get("disable_weights", self.disable_weights):
|
if kwargs.get("disable_weights", self.disable_weights):
|
||||||
parsed_weights = [(text, 1.0)]
|
parsed_weights = [(text, 1.0)]
|
||||||
@ -656,6 +674,9 @@ class SDTokenizer:
|
|||||||
def state_dict(self):
|
def state_dict(self):
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
def decode(self, token_ids, skip_special_tokens=True):
|
||||||
|
return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
|
||||||
|
|
||||||
class SD1Tokenizer:
|
class SD1Tokenizer:
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer, name=None):
|
def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer, name=None):
|
||||||
if name is not None:
|
if name is not None:
|
||||||
@ -679,6 +700,9 @@ class SD1Tokenizer:
|
|||||||
def state_dict(self):
|
def state_dict(self):
|
||||||
return getattr(self, self.clip).state_dict()
|
return getattr(self, self.clip).state_dict()
|
||||||
|
|
||||||
|
def decode(self, token_ids, skip_special_tokens=True):
|
||||||
|
return getattr(self, self.clip).decode(token_ids, skip_special_tokens=skip_special_tokens)
|
||||||
|
|
||||||
class SD1CheckpointClipModel(SDClipModel):
|
class SD1CheckpointClipModel(SDClipModel):
|
||||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
super().__init__(device=device, return_projected_pooled=False, dtype=dtype, model_options=model_options)
|
super().__init__(device=device, return_projected_pooled=False, dtype=dtype, model_options=model_options)
|
||||||
@ -715,3 +739,6 @@ class SD1ClipModel(torch.nn.Module):
|
|||||||
|
|
||||||
def load_sd(self, sd):
|
def load_sd(self, sd):
|
||||||
return getattr(self, self.clip).load_sd(sd)
|
return getattr(self, self.clip).load_sd(sd)
|
||||||
|
|
||||||
|
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
|
||||||
|
return getattr(self, self.clip).generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)
|
||||||
|
|||||||
@ -710,6 +710,15 @@ class Flux(supported_models_base.BASE):
|
|||||||
|
|
||||||
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||||
|
|
||||||
|
def process_unet_state_dict(self, state_dict):
|
||||||
|
out_sd = {}
|
||||||
|
for k in list(state_dict.keys()):
|
||||||
|
key_out = k
|
||||||
|
if key_out.endswith("_norm.scale"):
|
||||||
|
key_out = "{}.weight".format(key_out[:-len(".scale")])
|
||||||
|
out_sd[key_out] = state_dict[k]
|
||||||
|
return out_sd
|
||||||
|
|
||||||
vae_key_prefix = ["vae."]
|
vae_key_prefix = ["vae."]
|
||||||
text_encoder_key_prefix = ["text_encoders."]
|
text_encoder_key_prefix = ["text_encoders."]
|
||||||
|
|
||||||
@ -898,11 +907,13 @@ class HunyuanVideo(supported_models_base.BASE):
|
|||||||
key_out = key_out.replace("txt_in.c_embedder.linear_1.", "txt_in.c_embedder.in_layer.").replace("txt_in.c_embedder.linear_2.", "txt_in.c_embedder.out_layer.")
|
key_out = key_out.replace("txt_in.c_embedder.linear_1.", "txt_in.c_embedder.in_layer.").replace("txt_in.c_embedder.linear_2.", "txt_in.c_embedder.out_layer.")
|
||||||
key_out = key_out.replace("_mod.linear.", "_mod.lin.").replace("_attn_qkv.", "_attn.qkv.")
|
key_out = key_out.replace("_mod.linear.", "_mod.lin.").replace("_attn_qkv.", "_attn.qkv.")
|
||||||
key_out = key_out.replace("mlp.fc1.", "mlp.0.").replace("mlp.fc2.", "mlp.2.")
|
key_out = key_out.replace("mlp.fc1.", "mlp.0.").replace("mlp.fc2.", "mlp.2.")
|
||||||
key_out = key_out.replace("_attn_q_norm.weight", "_attn.norm.query_norm.scale").replace("_attn_k_norm.weight", "_attn.norm.key_norm.scale")
|
key_out = key_out.replace("_attn_q_norm.weight", "_attn.norm.query_norm.weight").replace("_attn_k_norm.weight", "_attn.norm.key_norm.weight")
|
||||||
key_out = key_out.replace(".q_norm.weight", ".norm.query_norm.scale").replace(".k_norm.weight", ".norm.key_norm.scale")
|
key_out = key_out.replace(".q_norm.weight", ".norm.query_norm.weight").replace(".k_norm.weight", ".norm.key_norm.weight")
|
||||||
key_out = key_out.replace("_attn_proj.", "_attn.proj.")
|
key_out = key_out.replace("_attn_proj.", "_attn.proj.")
|
||||||
key_out = key_out.replace(".modulation.linear.", ".modulation.lin.")
|
key_out = key_out.replace(".modulation.linear.", ".modulation.lin.")
|
||||||
key_out = key_out.replace("_in.mlp.2.", "_in.out_layer.").replace("_in.mlp.0.", "_in.in_layer.")
|
key_out = key_out.replace("_in.mlp.2.", "_in.out_layer.").replace("_in.mlp.0.", "_in.in_layer.")
|
||||||
|
if key_out.endswith(".scale"):
|
||||||
|
key_out = "{}.weight".format(key_out[:-len(".scale")])
|
||||||
out_sd[key_out] = state_dict[k]
|
out_sd[key_out] = state_dict[k]
|
||||||
return out_sd
|
return out_sd
|
||||||
|
|
||||||
@ -1264,6 +1275,15 @@ class Hunyuan3Dv2(supported_models_base.BASE):
|
|||||||
|
|
||||||
latent_format = latent_formats.Hunyuan3Dv2
|
latent_format = latent_formats.Hunyuan3Dv2
|
||||||
|
|
||||||
|
def process_unet_state_dict(self, state_dict):
|
||||||
|
out_sd = {}
|
||||||
|
for k in list(state_dict.keys()):
|
||||||
|
key_out = k
|
||||||
|
if key_out.endswith(".scale"):
|
||||||
|
key_out = "{}.weight".format(key_out[:-len(".scale")])
|
||||||
|
out_sd[key_out] = state_dict[k]
|
||||||
|
return out_sd
|
||||||
|
|
||||||
def process_unet_state_dict_for_saving(self, state_dict):
|
def process_unet_state_dict_for_saving(self, state_dict):
|
||||||
replace_prefix = {"": "model."}
|
replace_prefix = {"": "model."}
|
||||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||||
@ -1341,6 +1361,14 @@ class Chroma(supported_models_base.BASE):
|
|||||||
|
|
||||||
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||||
|
|
||||||
|
def process_unet_state_dict(self, state_dict):
|
||||||
|
out_sd = {}
|
||||||
|
for k in list(state_dict.keys()):
|
||||||
|
key_out = k
|
||||||
|
if key_out.endswith(".scale"):
|
||||||
|
key_out = "{}.weight".format(key_out[:-len(".scale")])
|
||||||
|
out_sd[key_out] = state_dict[k]
|
||||||
|
return out_sd
|
||||||
|
|
||||||
def get_model(self, state_dict, prefix="", device=None):
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
out = model_base.Chroma(self, device=device)
|
out = model_base.Chroma(self, device=device)
|
||||||
|
|||||||
@ -10,7 +10,6 @@ import comfy.utils
|
|||||||
def sample_manual_loop_no_classes(
|
def sample_manual_loop_no_classes(
|
||||||
model,
|
model,
|
||||||
ids=None,
|
ids=None,
|
||||||
paddings=[],
|
|
||||||
execution_dtype=None,
|
execution_dtype=None,
|
||||||
cfg_scale: float = 2.0,
|
cfg_scale: float = 2.0,
|
||||||
temperature: float = 0.85,
|
temperature: float = 0.85,
|
||||||
@ -36,9 +35,6 @@ def sample_manual_loop_no_classes(
|
|||||||
|
|
||||||
embeds, attention_mask, num_tokens, embeds_info = model.process_tokens(ids, device)
|
embeds, attention_mask, num_tokens, embeds_info = model.process_tokens(ids, device)
|
||||||
embeds_batch = embeds.shape[0]
|
embeds_batch = embeds.shape[0]
|
||||||
for i, t in enumerate(paddings):
|
|
||||||
attention_mask[i, :t] = 0
|
|
||||||
attention_mask[i, t:] = 1
|
|
||||||
|
|
||||||
output_audio_codes = []
|
output_audio_codes = []
|
||||||
past_key_values = []
|
past_key_values = []
|
||||||
@ -135,13 +131,11 @@ def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=102
|
|||||||
pos_pad = (len(negative) - len(positive))
|
pos_pad = (len(negative) - len(positive))
|
||||||
positive = [model.special_tokens["pad"]] * pos_pad + positive
|
positive = [model.special_tokens["pad"]] * pos_pad + positive
|
||||||
|
|
||||||
paddings = [pos_pad, neg_pad]
|
|
||||||
ids = [positive, negative]
|
ids = [positive, negative]
|
||||||
else:
|
else:
|
||||||
paddings = []
|
|
||||||
ids = [positive]
|
ids = [positive]
|
||||||
|
|
||||||
return sample_manual_loop_no_classes(model, ids, paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
|
return sample_manual_loop_no_classes(model, ids, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
|
||||||
|
|
||||||
|
|
||||||
class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
|
class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
|
||||||
|
|||||||
@ -33,6 +33,8 @@ class AnimaTokenizer:
|
|||||||
def state_dict(self):
|
def state_dict(self):
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
def decode(self, token_ids, **kwargs):
|
||||||
|
return self.qwen3_06b.decode(token_ids, **kwargs)
|
||||||
|
|
||||||
class Qwen3_06BModel(sd1_clip.SDClipModel):
|
class Qwen3_06BModel(sd1_clip.SDClipModel):
|
||||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
|
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
|
||||||
|
|||||||
@ -3,6 +3,8 @@ import torch.nn as nn
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional, Any, Tuple
|
from typing import Optional, Any, Tuple
|
||||||
import math
|
import math
|
||||||
|
from tqdm import tqdm
|
||||||
|
import comfy.utils
|
||||||
|
|
||||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
@ -103,6 +105,7 @@ class Qwen3_06BConfig:
|
|||||||
rope_scale = None
|
rope_scale = None
|
||||||
final_norm: bool = True
|
final_norm: bool = True
|
||||||
lm_head: bool = False
|
lm_head: bool = False
|
||||||
|
stop_tokens = [151643, 151645]
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Qwen3_06B_ACE15_Config:
|
class Qwen3_06B_ACE15_Config:
|
||||||
@ -126,6 +129,7 @@ class Qwen3_06B_ACE15_Config:
|
|||||||
rope_scale = None
|
rope_scale = None
|
||||||
final_norm: bool = True
|
final_norm: bool = True
|
||||||
lm_head: bool = False
|
lm_head: bool = False
|
||||||
|
stop_tokens = [151643, 151645]
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Qwen3_2B_ACE15_lm_Config:
|
class Qwen3_2B_ACE15_lm_Config:
|
||||||
@ -149,6 +153,7 @@ class Qwen3_2B_ACE15_lm_Config:
|
|||||||
rope_scale = None
|
rope_scale = None
|
||||||
final_norm: bool = True
|
final_norm: bool = True
|
||||||
lm_head: bool = False
|
lm_head: bool = False
|
||||||
|
stop_tokens = [151643, 151645]
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Qwen3_4B_ACE15_lm_Config:
|
class Qwen3_4B_ACE15_lm_Config:
|
||||||
@ -172,6 +177,7 @@ class Qwen3_4B_ACE15_lm_Config:
|
|||||||
rope_scale = None
|
rope_scale = None
|
||||||
final_norm: bool = True
|
final_norm: bool = True
|
||||||
lm_head: bool = False
|
lm_head: bool = False
|
||||||
|
stop_tokens = [151643, 151645]
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Qwen3_4BConfig:
|
class Qwen3_4BConfig:
|
||||||
@ -195,6 +201,7 @@ class Qwen3_4BConfig:
|
|||||||
rope_scale = None
|
rope_scale = None
|
||||||
final_norm: bool = True
|
final_norm: bool = True
|
||||||
lm_head: bool = False
|
lm_head: bool = False
|
||||||
|
stop_tokens = [151643, 151645]
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Qwen3_8BConfig:
|
class Qwen3_8BConfig:
|
||||||
@ -218,6 +225,7 @@ class Qwen3_8BConfig:
|
|||||||
rope_scale = None
|
rope_scale = None
|
||||||
final_norm: bool = True
|
final_norm: bool = True
|
||||||
lm_head: bool = False
|
lm_head: bool = False
|
||||||
|
stop_tokens = [151643, 151645]
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Ovis25_2BConfig:
|
class Ovis25_2BConfig:
|
||||||
@ -288,6 +296,7 @@ class Gemma2_2B_Config:
|
|||||||
rope_scale = None
|
rope_scale = None
|
||||||
final_norm: bool = True
|
final_norm: bool = True
|
||||||
lm_head: bool = False
|
lm_head: bool = False
|
||||||
|
stop_tokens = [1]
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Gemma3_4B_Config:
|
class Gemma3_4B_Config:
|
||||||
@ -312,6 +321,14 @@ class Gemma3_4B_Config:
|
|||||||
rope_scale = [8.0, 1.0]
|
rope_scale = [8.0, 1.0]
|
||||||
final_norm: bool = True
|
final_norm: bool = True
|
||||||
lm_head: bool = False
|
lm_head: bool = False
|
||||||
|
stop_tokens = [1, 106]
|
||||||
|
|
||||||
|
GEMMA3_VISION_CONFIG = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14}
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Gemma3_4B_Vision_Config(Gemma3_4B_Config):
|
||||||
|
vision_config = GEMMA3_VISION_CONFIG
|
||||||
|
mm_tokens_per_image = 256
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Gemma3_12B_Config:
|
class Gemma3_12B_Config:
|
||||||
@ -336,8 +353,9 @@ class Gemma3_12B_Config:
|
|||||||
rope_scale = [8.0, 1.0]
|
rope_scale = [8.0, 1.0]
|
||||||
final_norm: bool = True
|
final_norm: bool = True
|
||||||
lm_head: bool = False
|
lm_head: bool = False
|
||||||
vision_config = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14}
|
vision_config = GEMMA3_VISION_CONFIG
|
||||||
mm_tokens_per_image = 256
|
mm_tokens_per_image = 256
|
||||||
|
stop_tokens = [1, 106]
|
||||||
|
|
||||||
class RMSNorm(nn.Module):
|
class RMSNorm(nn.Module):
|
||||||
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
|
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
|
||||||
@ -355,13 +373,6 @@ class RMSNorm(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
def rotate_half(x):
|
|
||||||
"""Rotates half the hidden dims of the input."""
|
|
||||||
x1 = x[..., : x.shape[-1] // 2]
|
|
||||||
x2 = x[..., x.shape[-1] // 2 :]
|
|
||||||
return torch.cat((-x2, x1), dim=-1)
|
|
||||||
|
|
||||||
|
|
||||||
def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None):
|
def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None):
|
||||||
if not isinstance(theta, list):
|
if not isinstance(theta, list):
|
||||||
theta = [theta]
|
theta = [theta]
|
||||||
@ -390,20 +401,30 @@ def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_di
|
|||||||
else:
|
else:
|
||||||
cos = cos.unsqueeze(1)
|
cos = cos.unsqueeze(1)
|
||||||
sin = sin.unsqueeze(1)
|
sin = sin.unsqueeze(1)
|
||||||
out.append((cos, sin))
|
sin_split = sin.shape[-1] // 2
|
||||||
|
out.append((cos, sin[..., : sin_split], -sin[..., sin_split :]))
|
||||||
|
|
||||||
if len(out) == 1:
|
if len(out) == 1:
|
||||||
return out[0]
|
return out[0]
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def apply_rope(xq, xk, freqs_cis):
|
def apply_rope(xq, xk, freqs_cis):
|
||||||
org_dtype = xq.dtype
|
org_dtype = xq.dtype
|
||||||
cos = freqs_cis[0]
|
cos = freqs_cis[0]
|
||||||
sin = freqs_cis[1]
|
sin = freqs_cis[1]
|
||||||
q_embed = (xq * cos) + (rotate_half(xq) * sin)
|
nsin = freqs_cis[2]
|
||||||
k_embed = (xk * cos) + (rotate_half(xk) * sin)
|
|
||||||
|
q_embed = (xq * cos)
|
||||||
|
q_split = q_embed.shape[-1] // 2
|
||||||
|
q_embed[..., : q_split].addcmul_(xq[..., q_split :], nsin)
|
||||||
|
q_embed[..., q_split :].addcmul_(xq[..., : q_split], sin)
|
||||||
|
|
||||||
|
k_embed = (xk * cos)
|
||||||
|
k_split = k_embed.shape[-1] // 2
|
||||||
|
k_embed[..., : k_split].addcmul_(xk[..., k_split :], nsin)
|
||||||
|
k_embed[..., k_split :].addcmul_(xk[..., : k_split], sin)
|
||||||
|
|
||||||
return q_embed.to(org_dtype), k_embed.to(org_dtype)
|
return q_embed.to(org_dtype), k_embed.to(org_dtype)
|
||||||
|
|
||||||
|
|
||||||
@ -438,8 +459,10 @@ class Attention(nn.Module):
|
|||||||
freqs_cis: Optional[torch.Tensor] = None,
|
freqs_cis: Optional[torch.Tensor] = None,
|
||||||
optimized_attention=None,
|
optimized_attention=None,
|
||||||
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||||
|
sliding_window: Optional[int] = None,
|
||||||
):
|
):
|
||||||
batch_size, seq_length, _ = hidden_states.shape
|
batch_size, seq_length, _ = hidden_states.shape
|
||||||
|
|
||||||
xq = self.q_proj(hidden_states)
|
xq = self.q_proj(hidden_states)
|
||||||
xk = self.k_proj(hidden_states)
|
xk = self.k_proj(hidden_states)
|
||||||
xv = self.v_proj(hidden_states)
|
xv = self.v_proj(hidden_states)
|
||||||
@ -474,6 +497,11 @@ class Attention(nn.Module):
|
|||||||
else:
|
else:
|
||||||
present_key_value = (xk, xv, index + num_tokens)
|
present_key_value = (xk, xv, index + num_tokens)
|
||||||
|
|
||||||
|
if sliding_window is not None and xk.shape[2] > sliding_window:
|
||||||
|
xk = xk[:, :, -sliding_window:]
|
||||||
|
xv = xv[:, :, -sliding_window:]
|
||||||
|
attention_mask = attention_mask[..., -sliding_window:] if attention_mask is not None else None
|
||||||
|
|
||||||
xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
|
xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
|
||||||
xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
|
xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
|
||||||
|
|
||||||
@ -556,10 +584,12 @@ class TransformerBlockGemma2(nn.Module):
|
|||||||
optimized_attention=None,
|
optimized_attention=None,
|
||||||
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||||
):
|
):
|
||||||
|
sliding_window = None
|
||||||
if self.transformer_type == 'gemma3':
|
if self.transformer_type == 'gemma3':
|
||||||
if self.sliding_attention:
|
if self.sliding_attention:
|
||||||
|
sliding_window = self.sliding_attention
|
||||||
if x.shape[1] > self.sliding_attention:
|
if x.shape[1] > self.sliding_attention:
|
||||||
sliding_mask = torch.full((x.shape[1], x.shape[1]), float("-inf"), device=x.device, dtype=x.dtype)
|
sliding_mask = torch.full((x.shape[1], x.shape[1]), torch.finfo(x.dtype).min, device=x.device, dtype=x.dtype)
|
||||||
sliding_mask.tril_(diagonal=-self.sliding_attention)
|
sliding_mask.tril_(diagonal=-self.sliding_attention)
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
attention_mask = attention_mask + sliding_mask
|
attention_mask = attention_mask + sliding_mask
|
||||||
@ -578,6 +608,7 @@ class TransformerBlockGemma2(nn.Module):
|
|||||||
freqs_cis=freqs_cis,
|
freqs_cis=freqs_cis,
|
||||||
optimized_attention=optimized_attention,
|
optimized_attention=optimized_attention,
|
||||||
past_key_value=past_key_value,
|
past_key_value=past_key_value,
|
||||||
|
sliding_window=sliding_window,
|
||||||
)
|
)
|
||||||
|
|
||||||
x = self.post_attention_layernorm(x)
|
x = self.post_attention_layernorm(x)
|
||||||
@ -762,6 +793,107 @@ class BaseLlama:
|
|||||||
def forward(self, input_ids, *args, **kwargs):
|
def forward(self, input_ids, *args, **kwargs):
|
||||||
return self.model(input_ids, *args, **kwargs)
|
return self.model(input_ids, *args, **kwargs)
|
||||||
|
|
||||||
|
class BaseGenerate:
|
||||||
|
def logits(self, x):
|
||||||
|
input = x[:, -1:]
|
||||||
|
if hasattr(self.model, "lm_head"):
|
||||||
|
module = self.model.lm_head
|
||||||
|
else:
|
||||||
|
module = self.model.embed_tokens
|
||||||
|
|
||||||
|
offload_stream = None
|
||||||
|
if module.comfy_cast_weights:
|
||||||
|
weight, _, offload_stream = comfy.ops.cast_bias_weight(module, input, offloadable=True)
|
||||||
|
else:
|
||||||
|
weight = self.model.embed_tokens.weight.to(x)
|
||||||
|
|
||||||
|
x = torch.nn.functional.linear(input, weight, None)
|
||||||
|
|
||||||
|
comfy.ops.uncast_bias_weight(module, weight, None, offload_stream)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0):
|
||||||
|
device = embeds.device
|
||||||
|
model_config = self.model.config
|
||||||
|
|
||||||
|
if stop_tokens is None:
|
||||||
|
stop_tokens = self.model.config.stop_tokens
|
||||||
|
|
||||||
|
if execution_dtype is None:
|
||||||
|
if comfy.model_management.should_use_bf16(device):
|
||||||
|
execution_dtype = torch.bfloat16
|
||||||
|
else:
|
||||||
|
execution_dtype = torch.float32
|
||||||
|
embeds = embeds.to(execution_dtype)
|
||||||
|
|
||||||
|
if embeds.ndim == 2:
|
||||||
|
embeds = embeds.unsqueeze(0)
|
||||||
|
|
||||||
|
past_key_values = [] #kv_cache init
|
||||||
|
max_cache_len = embeds.shape[1] + max_length
|
||||||
|
for x in range(model_config.num_hidden_layers):
|
||||||
|
past_key_values.append((torch.empty([embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
|
||||||
|
torch.empty([embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
|
||||||
|
|
||||||
|
generator = torch.Generator(device=device).manual_seed(seed) if do_sample else None
|
||||||
|
|
||||||
|
generated_token_ids = []
|
||||||
|
pbar = comfy.utils.ProgressBar(max_length)
|
||||||
|
|
||||||
|
# Generation loop
|
||||||
|
for step in tqdm(range(max_length), desc="Generating tokens"):
|
||||||
|
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values)
|
||||||
|
logits = self.logits(x)[:, -1]
|
||||||
|
next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample)
|
||||||
|
token_id = next_token[0].item()
|
||||||
|
generated_token_ids.append(token_id)
|
||||||
|
|
||||||
|
embeds = self.model.embed_tokens(next_token).to(execution_dtype)
|
||||||
|
pbar.update(1)
|
||||||
|
|
||||||
|
if token_id in stop_tokens:
|
||||||
|
break
|
||||||
|
|
||||||
|
return generated_token_ids
|
||||||
|
|
||||||
|
def sample_token(self, logits, temperature, top_k, top_p, min_p, repetition_penalty, token_history, generator, do_sample=True):
|
||||||
|
|
||||||
|
if not do_sample or temperature == 0.0:
|
||||||
|
return torch.argmax(logits, dim=-1, keepdim=True)
|
||||||
|
|
||||||
|
# Sampling mode
|
||||||
|
if repetition_penalty != 1.0:
|
||||||
|
for i in range(logits.shape[0]):
|
||||||
|
for token_id in set(token_history):
|
||||||
|
logits[i, token_id] *= repetition_penalty if logits[i, token_id] < 0 else 1/repetition_penalty
|
||||||
|
|
||||||
|
if temperature != 1.0:
|
||||||
|
logits = logits / temperature
|
||||||
|
|
||||||
|
if top_k > 0:
|
||||||
|
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
||||||
|
logits[indices_to_remove] = torch.finfo(logits.dtype).min
|
||||||
|
|
||||||
|
if min_p > 0.0:
|
||||||
|
probs_before_filter = torch.nn.functional.softmax(logits, dim=-1)
|
||||||
|
top_probs, _ = probs_before_filter.max(dim=-1, keepdim=True)
|
||||||
|
min_threshold = min_p * top_probs
|
||||||
|
indices_to_remove = probs_before_filter < min_threshold
|
||||||
|
logits[indices_to_remove] = torch.finfo(logits.dtype).min
|
||||||
|
|
||||||
|
if top_p < 1.0:
|
||||||
|
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
||||||
|
cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
|
||||||
|
sorted_indices_to_remove = cumulative_probs > top_p
|
||||||
|
sorted_indices_to_remove[..., 0] = False
|
||||||
|
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
|
||||||
|
indices_to_remove.scatter_(1, sorted_indices, sorted_indices_to_remove)
|
||||||
|
logits[indices_to_remove] = torch.finfo(logits.dtype).min
|
||||||
|
|
||||||
|
probs = torch.nn.functional.softmax(logits, dim=-1)
|
||||||
|
|
||||||
|
return torch.multinomial(probs, num_samples=1, generator=generator)
|
||||||
|
|
||||||
class BaseQwen3:
|
class BaseQwen3:
|
||||||
def logits(self, x):
|
def logits(self, x):
|
||||||
input = x[:, -1:]
|
input = x[:, -1:]
|
||||||
@ -805,7 +937,7 @@ class Qwen25_3B(BaseLlama, torch.nn.Module):
|
|||||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
|
|
||||||
class Qwen3_06B(BaseLlama, BaseQwen3, torch.nn.Module):
|
class Qwen3_06B(BaseLlama, BaseQwen3, BaseGenerate, torch.nn.Module):
|
||||||
def __init__(self, config_dict, dtype, device, operations):
|
def __init__(self, config_dict, dtype, device, operations):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
config = Qwen3_06BConfig(**config_dict)
|
config = Qwen3_06BConfig(**config_dict)
|
||||||
@ -832,7 +964,7 @@ class Qwen3_2B_ACE15_lm(BaseLlama, BaseQwen3, torch.nn.Module):
|
|||||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
|
|
||||||
class Qwen3_4B(BaseLlama, BaseQwen3, torch.nn.Module):
|
class Qwen3_4B(BaseLlama, BaseQwen3, BaseGenerate, torch.nn.Module):
|
||||||
def __init__(self, config_dict, dtype, device, operations):
|
def __init__(self, config_dict, dtype, device, operations):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
config = Qwen3_4BConfig(**config_dict)
|
config = Qwen3_4BConfig(**config_dict)
|
||||||
@ -850,7 +982,7 @@ class Qwen3_4B_ACE15_lm(BaseLlama, BaseQwen3, torch.nn.Module):
|
|||||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
|
|
||||||
class Qwen3_8B(BaseLlama, BaseQwen3, torch.nn.Module):
|
class Qwen3_8B(BaseLlama, BaseQwen3, BaseGenerate, torch.nn.Module):
|
||||||
def __init__(self, config_dict, dtype, device, operations):
|
def __init__(self, config_dict, dtype, device, operations):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
config = Qwen3_8BConfig(**config_dict)
|
config = Qwen3_8BConfig(**config_dict)
|
||||||
@ -868,7 +1000,7 @@ class Ovis25_2B(BaseLlama, torch.nn.Module):
|
|||||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
|
|
||||||
class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
|
class Qwen25_7BVLI(BaseLlama, BaseGenerate, torch.nn.Module):
|
||||||
def __init__(self, config_dict, dtype, device, operations):
|
def __init__(self, config_dict, dtype, device, operations):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
config = Qwen25_7BVLI_Config(**config_dict)
|
config = Qwen25_7BVLI_Config(**config_dict)
|
||||||
@ -878,6 +1010,9 @@ class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
|
|||||||
self.visual = qwen_vl.Qwen2VLVisionTransformer(hidden_size=1280, output_hidden_size=config.hidden_size, device=device, dtype=dtype, ops=operations)
|
self.visual = qwen_vl.Qwen2VLVisionTransformer(hidden_size=1280, output_hidden_size=config.hidden_size, device=device, dtype=dtype, ops=operations)
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
|
|
||||||
|
# todo: should this be tied or not?
|
||||||
|
#self.lm_head = operations.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
|
||||||
|
|
||||||
def preprocess_embed(self, embed, device):
|
def preprocess_embed(self, embed, device):
|
||||||
if embed["type"] == "image":
|
if embed["type"] == "image":
|
||||||
image, grid = qwen_vl.process_qwen2vl_images(embed["data"])
|
image, grid = qwen_vl.process_qwen2vl_images(embed["data"])
|
||||||
@ -911,7 +1046,7 @@ class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
|
|||||||
|
|
||||||
return super().forward(x, attention_mask=attention_mask, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=final_layer_norm_intermediate, dtype=dtype, position_ids=position_ids)
|
return super().forward(x, attention_mask=attention_mask, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=final_layer_norm_intermediate, dtype=dtype, position_ids=position_ids)
|
||||||
|
|
||||||
class Gemma2_2B(BaseLlama, torch.nn.Module):
|
class Gemma2_2B(BaseLlama, BaseGenerate, torch.nn.Module):
|
||||||
def __init__(self, config_dict, dtype, device, operations):
|
def __init__(self, config_dict, dtype, device, operations):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
config = Gemma2_2B_Config(**config_dict)
|
config = Gemma2_2B_Config(**config_dict)
|
||||||
@ -920,7 +1055,7 @@ class Gemma2_2B(BaseLlama, torch.nn.Module):
|
|||||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
|
|
||||||
class Gemma3_4B(BaseLlama, torch.nn.Module):
|
class Gemma3_4B(BaseLlama, BaseGenerate, torch.nn.Module):
|
||||||
def __init__(self, config_dict, dtype, device, operations):
|
def __init__(self, config_dict, dtype, device, operations):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
config = Gemma3_4B_Config(**config_dict)
|
config = Gemma3_4B_Config(**config_dict)
|
||||||
@ -929,7 +1064,25 @@ class Gemma3_4B(BaseLlama, torch.nn.Module):
|
|||||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
|
|
||||||
class Gemma3_12B(BaseLlama, torch.nn.Module):
|
class Gemma3_4B_Vision(BaseLlama, BaseGenerate, torch.nn.Module):
|
||||||
|
def __init__(self, config_dict, dtype, device, operations):
|
||||||
|
super().__init__()
|
||||||
|
config = Gemma3_4B_Vision_Config(**config_dict)
|
||||||
|
self.num_layers = config.num_hidden_layers
|
||||||
|
|
||||||
|
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||||
|
self.dtype = dtype
|
||||||
|
self.multi_modal_projector = Gemma3MultiModalProjector(config, dtype, device, operations)
|
||||||
|
self.vision_model = comfy.clip_model.CLIPVision(config.vision_config, dtype, device, operations)
|
||||||
|
self.image_size = config.vision_config["image_size"]
|
||||||
|
|
||||||
|
def preprocess_embed(self, embed, device):
|
||||||
|
if embed["type"] == "image":
|
||||||
|
image = comfy.clip_model.clip_preprocess(embed["data"], size=self.image_size, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], crop=True)
|
||||||
|
return self.multi_modal_projector(self.vision_model(image.to(device, dtype=torch.float32))[0]), None
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
class Gemma3_12B(BaseLlama, BaseGenerate, torch.nn.Module):
|
||||||
def __init__(self, config_dict, dtype, device, operations):
|
def __init__(self, config_dict, dtype, device, operations):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
config = Gemma3_12B_Config(**config_dict)
|
config = Gemma3_12B_Config(**config_dict)
|
||||||
|
|||||||
@ -3,9 +3,9 @@ import os
|
|||||||
from transformers import T5TokenizerFast
|
from transformers import T5TokenizerFast
|
||||||
from .spiece_tokenizer import SPieceTokenizer
|
from .spiece_tokenizer import SPieceTokenizer
|
||||||
import comfy.text_encoders.genmo
|
import comfy.text_encoders.genmo
|
||||||
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
|
|
||||||
import torch
|
import torch
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
|
import math
|
||||||
|
|
||||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
@ -22,46 +22,86 @@ def ltxv_te(*args, **kwargs):
|
|||||||
return comfy.text_encoders.genmo.mochi_te(*args, **kwargs)
|
return comfy.text_encoders.genmo.mochi_te(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class Gemma3_12BTokenizer(sd1_clip.SDTokenizer):
|
class Gemma3_Tokenizer():
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
|
||||||
tokenizer = tokenizer_data.get("spiece_model", None)
|
|
||||||
super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
|
|
||||||
|
|
||||||
def state_dict(self):
|
def state_dict(self):
|
||||||
return {"spiece_model": self.tokenizer.serialize_model()}
|
return {"spiece_model": self.tokenizer.serialize_model()}
|
||||||
|
|
||||||
|
def tokenize_with_weights(self, text, return_word_ids=False, image=None, llama_template=None, skip_template=True, **kwargs):
|
||||||
|
self.llama_template = "<start_of_turn>system\nYou are a helpful assistant.<end_of_turn>\n<start_of_turn>user\n{}<end_of_turn>\n<start_of_turn>model\n"
|
||||||
|
self.llama_template_images = "<start_of_turn>system\nYou are a helpful assistant.<end_of_turn>\n<start_of_turn>user\n\n<image_soft_token>{}<end_of_turn>\n\n<start_of_turn>model\n"
|
||||||
|
|
||||||
|
if image is None:
|
||||||
|
images = []
|
||||||
|
else:
|
||||||
|
samples = image.movedim(-1, 1)
|
||||||
|
total = int(896 * 896)
|
||||||
|
|
||||||
|
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
|
||||||
|
width = round(samples.shape[3] * scale_by)
|
||||||
|
height = round(samples.shape[2] * scale_by)
|
||||||
|
|
||||||
|
s = comfy.utils.common_upscale(samples, width, height, "area", "disabled").movedim(1, -1)
|
||||||
|
images = [s[:, :, :, :3]]
|
||||||
|
|
||||||
|
if text.startswith('<start_of_turn>'):
|
||||||
|
skip_template = True
|
||||||
|
|
||||||
|
if skip_template:
|
||||||
|
llama_text = text
|
||||||
|
else:
|
||||||
|
if llama_template is None:
|
||||||
|
if len(images) > 0:
|
||||||
|
llama_text = self.llama_template_images.format(text)
|
||||||
|
else:
|
||||||
|
llama_text = self.llama_template.format(text)
|
||||||
|
else:
|
||||||
|
llama_text = llama_template.format(text)
|
||||||
|
|
||||||
|
text_tokens = super().tokenize_with_weights(llama_text, return_word_ids)
|
||||||
|
|
||||||
|
if len(images) > 0:
|
||||||
|
embed_count = 0
|
||||||
|
for r in text_tokens:
|
||||||
|
for i, token in enumerate(r):
|
||||||
|
if token[0] == 262144 and embed_count < len(images):
|
||||||
|
r[i] = ({"type": "image", "data": images[embed_count]},) + token[1:]
|
||||||
|
embed_count += 1
|
||||||
|
return text_tokens
|
||||||
|
|
||||||
|
class Gemma3_12BTokenizer(Gemma3_Tokenizer, sd1_clip.SDTokenizer):
|
||||||
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
|
tokenizer = tokenizer_data.get("spiece_model", None)
|
||||||
|
special_tokens = {"<image_soft_token>": 262144, "<end_of_turn>": 106}
|
||||||
|
super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_left=True, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, tokenizer_data=tokenizer_data)
|
||||||
|
|
||||||
|
|
||||||
class LTXAVGemmaTokenizer(sd1_clip.SD1Tokenizer):
|
class LTXAVGemmaTokenizer(sd1_clip.SD1Tokenizer):
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma3_12b", tokenizer=Gemma3_12BTokenizer)
|
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma3_12b", tokenizer=Gemma3_12BTokenizer)
|
||||||
|
|
||||||
|
|
||||||
class Gemma3_12BModel(sd1_clip.SDClipModel):
|
class Gemma3_12BModel(sd1_clip.SDClipModel):
|
||||||
def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
|
def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
|
||||||
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
|
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
|
||||||
if llama_quantization_metadata is not None:
|
if llama_quantization_metadata is not None:
|
||||||
model_options = model_options.copy()
|
model_options = model_options.copy()
|
||||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||||
|
self.dtypes = set()
|
||||||
|
self.dtypes.add(dtype)
|
||||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_12B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_12B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||||
|
|
||||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template="{}", image_embeds=None, **kwargs):
|
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
|
||||||
text = llama_template.format(text)
|
tokens_only = [[t[0] for t in b] for b in tokens]
|
||||||
text_tokens = super().tokenize_with_weights(text, return_word_ids)
|
embeds, _, _, embeds_info = self.process_tokens(tokens_only, self.execution_device)
|
||||||
embed_count = 0
|
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
|
||||||
for k in text_tokens:
|
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106]) # 106 is <end_of_turn>
|
||||||
tt = text_tokens[k]
|
|
||||||
for r in tt:
|
|
||||||
for i in range(len(r)):
|
|
||||||
if r[i][0] == 262144:
|
|
||||||
if image_embeds is not None and embed_count < image_embeds.shape[0]:
|
|
||||||
r[i] = ({"type": "embedding", "data": image_embeds[embed_count], "original_type": "image"},) + r[i][1:]
|
|
||||||
embed_count += 1
|
|
||||||
return text_tokens
|
|
||||||
|
|
||||||
class LTXAVTEModel(torch.nn.Module):
|
class LTXAVTEModel(torch.nn.Module):
|
||||||
def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
|
def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dtypes = set()
|
self.dtypes = set()
|
||||||
self.dtypes.add(dtype)
|
self.dtypes.add(dtype)
|
||||||
|
self.compat_mode = False
|
||||||
|
|
||||||
self.gemma3_12b = Gemma3_12BModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None)
|
self.gemma3_12b = Gemma3_12BModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None)
|
||||||
self.dtypes.add(dtype_llama)
|
self.dtypes.add(dtype_llama)
|
||||||
@ -69,6 +109,11 @@ class LTXAVTEModel(torch.nn.Module):
|
|||||||
operations = self.gemma3_12b.operations # TODO
|
operations = self.gemma3_12b.operations # TODO
|
||||||
self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device)
|
self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
def enable_compat_mode(self): # TODO: remove
|
||||||
|
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
|
||||||
|
operations = self.gemma3_12b.operations
|
||||||
|
dtype = self.text_embedding_projection.weight.dtype
|
||||||
|
device = self.text_embedding_projection.weight.device
|
||||||
self.audio_embeddings_connector = Embeddings1DConnector(
|
self.audio_embeddings_connector = Embeddings1DConnector(
|
||||||
split_rope=True,
|
split_rope=True,
|
||||||
double_precision_rope=True,
|
double_precision_rope=True,
|
||||||
@ -84,6 +129,7 @@ class LTXAVTEModel(torch.nn.Module):
|
|||||||
device=device,
|
device=device,
|
||||||
operations=operations,
|
operations=operations,
|
||||||
)
|
)
|
||||||
|
self.compat_mode = True
|
||||||
|
|
||||||
def set_clip_options(self, options):
|
def set_clip_options(self, options):
|
||||||
self.execution_device = options.get("execution_device", self.execution_device)
|
self.execution_device = options.get("execution_device", self.execution_device)
|
||||||
@ -97,6 +143,7 @@ class LTXAVTEModel(torch.nn.Module):
|
|||||||
token_weight_pairs = token_weight_pairs["gemma3_12b"]
|
token_weight_pairs = token_weight_pairs["gemma3_12b"]
|
||||||
|
|
||||||
out, pooled, extra = self.gemma3_12b.encode_token_weights(token_weight_pairs)
|
out, pooled, extra = self.gemma3_12b.encode_token_weights(token_weight_pairs)
|
||||||
|
out = out[:, :, -torch.sum(extra["attention_mask"]).item():]
|
||||||
out_device = out.device
|
out_device = out.device
|
||||||
if comfy.model_management.should_use_bf16(self.execution_device):
|
if comfy.model_management.should_use_bf16(self.execution_device):
|
||||||
out = out.to(device=self.execution_device, dtype=torch.bfloat16)
|
out = out.to(device=self.execution_device, dtype=torch.bfloat16)
|
||||||
@ -105,30 +152,45 @@ class LTXAVTEModel(torch.nn.Module):
|
|||||||
out = out.reshape((out.shape[0], out.shape[1], -1))
|
out = out.reshape((out.shape[0], out.shape[1], -1))
|
||||||
out = self.text_embedding_projection(out)
|
out = self.text_embedding_projection(out)
|
||||||
out = out.float()
|
out = out.float()
|
||||||
out_vid = self.video_embeddings_connector(out)[0]
|
|
||||||
out_audio = self.audio_embeddings_connector(out)[0]
|
if self.compat_mode:
|
||||||
out = torch.concat((out_vid, out_audio), dim=-1)
|
out_vid = self.video_embeddings_connector(out)[0]
|
||||||
|
out_audio = self.audio_embeddings_connector(out)[0]
|
||||||
|
out = torch.concat((out_vid, out_audio), dim=-1)
|
||||||
|
|
||||||
return out.to(out_device), pooled
|
return out.to(out_device), pooled
|
||||||
|
|
||||||
|
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
|
||||||
|
return self.gemma3_12b.generate(tokens["gemma3_12b"], do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed)
|
||||||
|
|
||||||
def load_sd(self, sd):
|
def load_sd(self, sd):
|
||||||
if "model.layers.47.self_attn.q_norm.weight" in sd:
|
if "model.layers.47.self_attn.q_norm.weight" in sd:
|
||||||
return self.gemma3_12b.load_sd(sd)
|
return self.gemma3_12b.load_sd(sd)
|
||||||
else:
|
else:
|
||||||
sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "model.diffusion_model.video_embeddings_connector.": "video_embeddings_connector.", "model.diffusion_model.audio_embeddings_connector.": "audio_embeddings_connector."}, filter_keys=True)
|
sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight"}, filter_keys=True)
|
||||||
if len(sdo) == 0:
|
if len(sdo) == 0:
|
||||||
sdo = sd
|
sdo = sd
|
||||||
|
|
||||||
missing_all = []
|
missing_all = []
|
||||||
unexpected_all = []
|
unexpected_all = []
|
||||||
|
|
||||||
for prefix, component in [("text_embedding_projection.", self.text_embedding_projection), ("video_embeddings_connector.", self.video_embeddings_connector), ("audio_embeddings_connector.", self.audio_embeddings_connector)]:
|
for prefix, component in [("text_embedding_projection.", self.text_embedding_projection)]:
|
||||||
component_sd = {k.replace(prefix, ""): v for k, v in sdo.items() if k.startswith(prefix)}
|
component_sd = {k.replace(prefix, ""): v for k, v in sdo.items() if k.startswith(prefix)}
|
||||||
if component_sd:
|
if component_sd:
|
||||||
missing, unexpected = component.load_state_dict(component_sd, strict=False, assign=getattr(self, "can_assign_sd", False))
|
missing, unexpected = component.load_state_dict(component_sd, strict=False, assign=getattr(self, "can_assign_sd", False))
|
||||||
missing_all.extend([f"{prefix}{k}" for k in missing])
|
missing_all.extend([f"{prefix}{k}" for k in missing])
|
||||||
unexpected_all.extend([f"{prefix}{k}" for k in unexpected])
|
unexpected_all.extend([f"{prefix}{k}" for k in unexpected])
|
||||||
|
|
||||||
|
if "model.diffusion_model.audio_embeddings_connector.transformer_1d_blocks.2.attn1.to_q.bias" not in sd: # TODO: remove
|
||||||
|
ww = sd.get("model.diffusion_model.audio_embeddings_connector.transformer_1d_blocks.0.attn1.to_q.bias", None)
|
||||||
|
if ww is not None:
|
||||||
|
if ww.shape[0] == 3840:
|
||||||
|
self.enable_compat_mode()
|
||||||
|
sdv = comfy.utils.state_dict_prefix_replace(sd, {"model.diffusion_model.video_embeddings_connector.": ""}, filter_keys=True)
|
||||||
|
self.video_embeddings_connector.load_state_dict(sdv, strict=False, assign=getattr(self, "can_assign_sd", False))
|
||||||
|
sda = comfy.utils.state_dict_prefix_replace(sd, {"model.diffusion_model.audio_embeddings_connector.": ""}, filter_keys=True)
|
||||||
|
self.audio_embeddings_connector.load_state_dict(sda, strict=False, assign=getattr(self, "can_assign_sd", False))
|
||||||
|
|
||||||
return (missing_all, unexpected_all)
|
return (missing_all, unexpected_all)
|
||||||
|
|
||||||
def memory_estimation_function(self, token_weight_pairs, device=None):
|
def memory_estimation_function(self, token_weight_pairs, device=None):
|
||||||
@ -138,6 +200,7 @@ class LTXAVTEModel(torch.nn.Module):
|
|||||||
|
|
||||||
token_weight_pairs = token_weight_pairs.get("gemma3_12b", [])
|
token_weight_pairs = token_weight_pairs.get("gemma3_12b", [])
|
||||||
num_tokens = sum(map(lambda a: len(a), token_weight_pairs))
|
num_tokens = sum(map(lambda a: len(a), token_weight_pairs))
|
||||||
|
num_tokens = max(num_tokens, 64)
|
||||||
return num_tokens * constant * 1024 * 1024
|
return num_tokens * constant * 1024 * 1024
|
||||||
|
|
||||||
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
|
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
|
||||||
@ -150,3 +213,14 @@ def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
|
|||||||
dtype = dtype_llama
|
dtype = dtype_llama
|
||||||
super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
|
super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
|
||||||
return LTXAVTEModel_
|
return LTXAVTEModel_
|
||||||
|
|
||||||
|
def gemma3_te(dtype_llama=None, llama_quantization_metadata=None):
|
||||||
|
class Gemma3_12BModel_(Gemma3_12BModel):
|
||||||
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
|
if llama_quantization_metadata is not None:
|
||||||
|
model_options = model_options.copy()
|
||||||
|
model_options["llama_quantization_metadata"] = llama_quantization_metadata
|
||||||
|
if dtype_llama is not None:
|
||||||
|
dtype = dtype_llama
|
||||||
|
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||||
|
return Gemma3_12BModel_
|
||||||
|
|||||||
@ -1,23 +1,23 @@
|
|||||||
from comfy import sd1_clip
|
from comfy import sd1_clip
|
||||||
from .spiece_tokenizer import SPieceTokenizer
|
from .spiece_tokenizer import SPieceTokenizer
|
||||||
import comfy.text_encoders.llama
|
import comfy.text_encoders.llama
|
||||||
|
from comfy.text_encoders.lt import Gemma3_Tokenizer
|
||||||
|
import comfy.utils
|
||||||
|
|
||||||
class Gemma2BTokenizer(sd1_clip.SDTokenizer):
|
class Gemma2BTokenizer(sd1_clip.SDTokenizer):
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
tokenizer = tokenizer_data.get("spiece_model", None)
|
tokenizer = tokenizer_data.get("spiece_model", None)
|
||||||
super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
|
special_tokens = {"<end_of_turn>": 107}
|
||||||
|
super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, tokenizer_data=tokenizer_data)
|
||||||
|
|
||||||
def state_dict(self):
|
def state_dict(self):
|
||||||
return {"spiece_model": self.tokenizer.serialize_model()}
|
return {"spiece_model": self.tokenizer.serialize_model()}
|
||||||
|
|
||||||
class Gemma3_4BTokenizer(sd1_clip.SDTokenizer):
|
class Gemma3_4BTokenizer(Gemma3_Tokenizer, sd1_clip.SDTokenizer):
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
tokenizer = tokenizer_data.get("spiece_model", None)
|
tokenizer = tokenizer_data.get("spiece_model", None)
|
||||||
super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, disable_weights=True, tokenizer_data=tokenizer_data)
|
special_tokens = {"<image_soft_token>": 262144, "<end_of_turn>": 106}
|
||||||
|
super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, disable_weights=True, tokenizer_data=tokenizer_data)
|
||||||
def state_dict(self):
|
|
||||||
return {"spiece_model": self.tokenizer.serialize_model()}
|
|
||||||
|
|
||||||
class LuminaTokenizer(sd1_clip.SD1Tokenizer):
|
class LuminaTokenizer(sd1_clip.SD1Tokenizer):
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
@ -40,6 +40,20 @@ class Gemma3_4BModel(sd1_clip.SDClipModel):
|
|||||||
|
|
||||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||||
|
|
||||||
|
class Gemma3_4B_Vision_Model(sd1_clip.SDClipModel):
|
||||||
|
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
|
||||||
|
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
|
||||||
|
if llama_quantization_metadata is not None:
|
||||||
|
model_options = model_options.copy()
|
||||||
|
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||||
|
|
||||||
|
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B_Vision, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||||
|
|
||||||
|
def process_tokens(self, tokens, device):
|
||||||
|
embeds, _, _, embeds_info = super().process_tokens(tokens, device)
|
||||||
|
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
|
||||||
|
return embeds
|
||||||
|
|
||||||
class LuminaModel(sd1_clip.SD1ClipModel):
|
class LuminaModel(sd1_clip.SD1ClipModel):
|
||||||
def __init__(self, device="cpu", dtype=None, model_options={}, name="gemma2_2b", clip_model=Gemma2_2BModel):
|
def __init__(self, device="cpu", dtype=None, model_options={}, name="gemma2_2b", clip_model=Gemma2_2BModel):
|
||||||
super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options)
|
super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options)
|
||||||
@ -50,6 +64,8 @@ def te(dtype_llama=None, llama_quantization_metadata=None, model_type="gemma2_2b
|
|||||||
model = Gemma2_2BModel
|
model = Gemma2_2BModel
|
||||||
elif model_type == "gemma3_4b":
|
elif model_type == "gemma3_4b":
|
||||||
model = Gemma3_4BModel
|
model = Gemma3_4BModel
|
||||||
|
elif model_type == "gemma3_4b_vision":
|
||||||
|
model = Gemma3_4B_Vision_Model
|
||||||
|
|
||||||
class LuminaTEModel_(LuminaModel):
|
class LuminaTEModel_(LuminaModel):
|
||||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
|
|||||||
@ -6,9 +6,10 @@ class SPieceTokenizer:
|
|||||||
def from_pretrained(path, **kwargs):
|
def from_pretrained(path, **kwargs):
|
||||||
return SPieceTokenizer(path, **kwargs)
|
return SPieceTokenizer(path, **kwargs)
|
||||||
|
|
||||||
def __init__(self, tokenizer_path, add_bos=False, add_eos=True):
|
def __init__(self, tokenizer_path, add_bos=False, add_eos=True, special_tokens=None):
|
||||||
self.add_bos = add_bos
|
self.add_bos = add_bos
|
||||||
self.add_eos = add_eos
|
self.add_eos = add_eos
|
||||||
|
self.special_tokens = special_tokens
|
||||||
import sentencepiece
|
import sentencepiece
|
||||||
if torch.is_tensor(tokenizer_path):
|
if torch.is_tensor(tokenizer_path):
|
||||||
tokenizer_path = tokenizer_path.numpy().tobytes()
|
tokenizer_path = tokenizer_path.numpy().tobytes()
|
||||||
@ -27,8 +28,32 @@ class SPieceTokenizer:
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
def __call__(self, string):
|
def __call__(self, string):
|
||||||
|
if self.special_tokens is not None:
|
||||||
|
import re
|
||||||
|
special_tokens_pattern = '|'.join(re.escape(token) for token in self.special_tokens.keys())
|
||||||
|
if special_tokens_pattern and re.search(special_tokens_pattern, string):
|
||||||
|
parts = re.split(f'({special_tokens_pattern})', string)
|
||||||
|
result = []
|
||||||
|
for part in parts:
|
||||||
|
if not part:
|
||||||
|
continue
|
||||||
|
if part in self.special_tokens:
|
||||||
|
result.append(self.special_tokens[part])
|
||||||
|
else:
|
||||||
|
encoded = self.tokenizer.encode(part, add_bos=False, add_eos=False)
|
||||||
|
result.extend(encoded)
|
||||||
|
return {"input_ids": result}
|
||||||
|
|
||||||
out = self.tokenizer.encode(string)
|
out = self.tokenizer.encode(string)
|
||||||
return {"input_ids": out}
|
return {"input_ids": out}
|
||||||
|
|
||||||
|
def decode(self, token_ids, skip_special_tokens=False):
|
||||||
|
|
||||||
|
if skip_special_tokens and self.special_tokens:
|
||||||
|
special_token_ids = set(self.special_tokens.values())
|
||||||
|
token_ids = [tid for tid in token_ids if tid not in special_token_ids]
|
||||||
|
|
||||||
|
return self.tokenizer.decode(token_ids)
|
||||||
|
|
||||||
def serialize_model(self):
|
def serialize_model(self):
|
||||||
return torch.ByteTensor(list(self.tokenizer.serialized_model_proto()))
|
return torch.ByteTensor(list(self.tokenizer.serialized_model_proto()))
|
||||||
|
|||||||
@ -20,7 +20,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import math
|
import math
|
||||||
import struct
|
import struct
|
||||||
import comfy.checkpoint_pickle
|
import comfy.memory_management
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
@ -38,26 +38,26 @@ import warnings
|
|||||||
MMAP_TORCH_FILES = args.mmap_torch_files
|
MMAP_TORCH_FILES = args.mmap_torch_files
|
||||||
DISABLE_MMAP = args.disable_mmap
|
DISABLE_MMAP = args.disable_mmap
|
||||||
|
|
||||||
ALWAYS_SAFE_LOAD = False
|
|
||||||
if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated
|
if True: # ckpt/pt file whitelist for safe loading of old sd files
|
||||||
class ModelCheckpoint:
|
class ModelCheckpoint:
|
||||||
pass
|
pass
|
||||||
ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint"
|
ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint"
|
||||||
|
|
||||||
def scalar(*args, **kwargs):
|
def scalar(*args, **kwargs):
|
||||||
from numpy.core.multiarray import scalar as sc
|
return None
|
||||||
return sc(*args, **kwargs)
|
|
||||||
scalar.__module__ = "numpy.core.multiarray"
|
scalar.__module__ = "numpy.core.multiarray"
|
||||||
|
|
||||||
from numpy import dtype
|
from numpy import dtype
|
||||||
from numpy.dtypes import Float64DType
|
from numpy.dtypes import Float64DType
|
||||||
from _codecs import encode
|
|
||||||
|
def encode(*args, **kwargs): # no longer necessary on newer torch
|
||||||
|
return None
|
||||||
|
encode.__module__ = "_codecs"
|
||||||
|
|
||||||
torch.serialization.add_safe_globals([ModelCheckpoint, scalar, dtype, Float64DType, encode])
|
torch.serialization.add_safe_globals([ModelCheckpoint, scalar, dtype, Float64DType, encode])
|
||||||
ALWAYS_SAFE_LOAD = True
|
|
||||||
logging.info("Checkpoint files will always be loaded safely.")
|
logging.info("Checkpoint files will always be loaded safely.")
|
||||||
else:
|
|
||||||
logging.warning("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended as older versions of pytorch are no longer supported.")
|
|
||||||
|
|
||||||
# Current as of safetensors 0.7.0
|
# Current as of safetensors 0.7.0
|
||||||
_TYPES = {
|
_TYPES = {
|
||||||
@ -140,11 +140,8 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
|||||||
if MMAP_TORCH_FILES:
|
if MMAP_TORCH_FILES:
|
||||||
torch_args["mmap"] = True
|
torch_args["mmap"] = True
|
||||||
|
|
||||||
if safe_load or ALWAYS_SAFE_LOAD:
|
pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
|
||||||
pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
|
|
||||||
else:
|
|
||||||
logging.warning("WARNING: loading {} unsafely, upgrade your pytorch to 2.4 or newer to load this file safely.".format(ckpt))
|
|
||||||
pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
|
|
||||||
if "state_dict" in pl_sd:
|
if "state_dict" in pl_sd:
|
||||||
sd = pl_sd["state_dict"]
|
sd = pl_sd["state_dict"]
|
||||||
else:
|
else:
|
||||||
@ -675,10 +672,10 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
|
|||||||
"ff_context.linear_in.bias": "txt_mlp.0.bias",
|
"ff_context.linear_in.bias": "txt_mlp.0.bias",
|
||||||
"ff_context.linear_out.weight": "txt_mlp.2.weight",
|
"ff_context.linear_out.weight": "txt_mlp.2.weight",
|
||||||
"ff_context.linear_out.bias": "txt_mlp.2.bias",
|
"ff_context.linear_out.bias": "txt_mlp.2.bias",
|
||||||
"attn.norm_q.weight": "img_attn.norm.query_norm.scale",
|
"attn.norm_q.weight": "img_attn.norm.query_norm.weight",
|
||||||
"attn.norm_k.weight": "img_attn.norm.key_norm.scale",
|
"attn.norm_k.weight": "img_attn.norm.key_norm.weight",
|
||||||
"attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",
|
"attn.norm_added_q.weight": "txt_attn.norm.query_norm.weight",
|
||||||
"attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale",
|
"attn.norm_added_k.weight": "txt_attn.norm.key_norm.weight",
|
||||||
}
|
}
|
||||||
|
|
||||||
for k in block_map:
|
for k in block_map:
|
||||||
@ -701,8 +698,8 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
|
|||||||
"norm.linear.bias": "modulation.lin.bias",
|
"norm.linear.bias": "modulation.lin.bias",
|
||||||
"proj_out.weight": "linear2.weight",
|
"proj_out.weight": "linear2.weight",
|
||||||
"proj_out.bias": "linear2.bias",
|
"proj_out.bias": "linear2.bias",
|
||||||
"attn.norm_q.weight": "norm.query_norm.scale",
|
"attn.norm_q.weight": "norm.query_norm.weight",
|
||||||
"attn.norm_k.weight": "norm.key_norm.scale",
|
"attn.norm_k.weight": "norm.key_norm.weight",
|
||||||
"attn.to_qkv_mlp_proj.weight": "linear1.weight", # Flux 2
|
"attn.to_qkv_mlp_proj.weight": "linear1.weight", # Flux 2
|
||||||
"attn.to_out.weight": "linear2.weight", # Flux 2
|
"attn.to_out.weight": "linear2.weight", # Flux 2
|
||||||
}
|
}
|
||||||
@ -1157,7 +1154,7 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am
|
|||||||
return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar)
|
return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar)
|
||||||
|
|
||||||
def model_trange(*args, **kwargs):
|
def model_trange(*args, **kwargs):
|
||||||
if comfy.memory_management.aimdo_allocator is None:
|
if not comfy.memory_management.aimdo_enabled:
|
||||||
return trange(*args, **kwargs)
|
return trange(*args, **kwargs)
|
||||||
|
|
||||||
pbar = trange(*args, **kwargs, smoothing=1.0)
|
pbar = trange(*args, **kwargs, smoothing=1.0)
|
||||||
@ -1421,3 +1418,11 @@ def deepcopy_list_dict(obj, memo=None):
|
|||||||
|
|
||||||
memo[obj_id] = res
|
memo[obj_id] = res
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
def normalize_image_embeddings(embeds, embeds_info, scale_factor):
|
||||||
|
"""Normalize image embeddings to match text embedding scale"""
|
||||||
|
for info in embeds_info:
|
||||||
|
if info.get("type") == "image":
|
||||||
|
start_idx = info["index"]
|
||||||
|
end_idx = start_idx + info["size"]
|
||||||
|
embeds[:, start_idx:end_idx, :] /= scale_factor
|
||||||
|
|||||||
@ -49,6 +49,12 @@ class WeightAdapterBase:
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def calculate_shape(
|
||||||
|
self,
|
||||||
|
key
|
||||||
|
):
|
||||||
|
return None
|
||||||
|
|
||||||
def calculate_weight(
|
def calculate_weight(
|
||||||
self,
|
self,
|
||||||
weight,
|
weight,
|
||||||
|
|||||||
@ -214,6 +214,13 @@ class LoRAAdapter(WeightAdapterBase):
|
|||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def calculate_shape(
|
||||||
|
self,
|
||||||
|
key
|
||||||
|
):
|
||||||
|
reshape = self.weights[5]
|
||||||
|
return tuple(reshape) if reshape is not None else None
|
||||||
|
|
||||||
def calculate_weight(
|
def calculate_weight(
|
||||||
self,
|
self,
|
||||||
weight,
|
weight,
|
||||||
|
|||||||
@ -14,6 +14,7 @@ SERVER_FEATURE_FLAGS: dict[str, Any] = {
|
|||||||
"supports_preview_metadata": True,
|
"supports_preview_metadata": True,
|
||||||
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
|
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
|
||||||
"extension": {"manager": {"supports_v4": True}},
|
"extension": {"manager": {"supports_v4": True}},
|
||||||
|
"node_replacements": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -21,6 +21,17 @@ class ComfyAPI_latest(ComfyAPIBase):
|
|||||||
VERSION = "latest"
|
VERSION = "latest"
|
||||||
STABLE = False
|
STABLE = False
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.node_replacement = self.NodeReplacement()
|
||||||
|
self.execution = self.Execution()
|
||||||
|
|
||||||
|
class NodeReplacement(ProxiedSingleton):
|
||||||
|
async def register(self, node_replace: io.NodeReplace) -> None:
|
||||||
|
"""Register a node replacement mapping."""
|
||||||
|
from server import PromptServer
|
||||||
|
PromptServer.instance.node_replace_manager.register(node_replace)
|
||||||
|
|
||||||
class Execution(ProxiedSingleton):
|
class Execution(ProxiedSingleton):
|
||||||
async def set_progress(
|
async def set_progress(
|
||||||
self,
|
self,
|
||||||
@ -73,8 +84,6 @@ class ComfyAPI_latest(ComfyAPIBase):
|
|||||||
image=to_display,
|
image=to_display,
|
||||||
)
|
)
|
||||||
|
|
||||||
execution: Execution
|
|
||||||
|
|
||||||
class ComfyExtension(ABC):
|
class ComfyExtension(ABC):
|
||||||
async def on_load(self) -> None:
|
async def on_load(self) -> None:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -444,7 +444,7 @@ class VideoFromComponents(VideoInput):
|
|||||||
output.mux(packet)
|
output.mux(packet)
|
||||||
|
|
||||||
if audio_stream and self.__components.audio:
|
if audio_stream and self.__components.audio:
|
||||||
frame = av.AudioFrame.from_ndarray(waveform.float().cpu().numpy(), format='fltp', layout=layout)
|
frame = av.AudioFrame.from_ndarray(waveform.float().cpu().contiguous().numpy(), format='fltp', layout=layout)
|
||||||
frame.sample_rate = audio_sample_rate
|
frame.sample_rate = audio_sample_rate
|
||||||
frame.pts = 0
|
frame.pts = 0
|
||||||
output.mux(audio_stream.encode(frame))
|
output.mux(audio_stream.encode(frame))
|
||||||
|
|||||||
@ -73,8 +73,15 @@ class RemoteOptions:
|
|||||||
class NumberDisplay(str, Enum):
|
class NumberDisplay(str, Enum):
|
||||||
number = "number"
|
number = "number"
|
||||||
slider = "slider"
|
slider = "slider"
|
||||||
|
gradient_slider = "gradientslider"
|
||||||
|
|
||||||
|
|
||||||
|
class ControlAfterGenerate(str, Enum):
|
||||||
|
fixed = "fixed"
|
||||||
|
increment = "increment"
|
||||||
|
decrement = "decrement"
|
||||||
|
randomize = "randomize"
|
||||||
|
|
||||||
class _ComfyType(ABC):
|
class _ComfyType(ABC):
|
||||||
Type = Any
|
Type = Any
|
||||||
io_type: str = None
|
io_type: str = None
|
||||||
@ -263,7 +270,7 @@ class Int(ComfyTypeIO):
|
|||||||
class Input(WidgetInput):
|
class Input(WidgetInput):
|
||||||
'''Integer input.'''
|
'''Integer input.'''
|
||||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||||
default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None,
|
default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool | ControlAfterGenerate=None,
|
||||||
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
|
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
|
||||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced)
|
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced)
|
||||||
self.min = min
|
self.min = min
|
||||||
@ -290,13 +297,15 @@ class Float(ComfyTypeIO):
|
|||||||
'''Float input.'''
|
'''Float input.'''
|
||||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||||
default: float=None, min: float=None, max: float=None, step: float=None, round: float=None,
|
default: float=None, min: float=None, max: float=None, step: float=None, round: float=None,
|
||||||
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
|
display_mode: NumberDisplay=None, gradient_stops: list[list[float]]=None,
|
||||||
|
socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
|
||||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced)
|
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced)
|
||||||
self.min = min
|
self.min = min
|
||||||
self.max = max
|
self.max = max
|
||||||
self.step = step
|
self.step = step
|
||||||
self.round = round
|
self.round = round
|
||||||
self.display_mode = display_mode
|
self.display_mode = display_mode
|
||||||
|
self.gradient_stops = gradient_stops
|
||||||
self.default: float
|
self.default: float
|
||||||
|
|
||||||
def as_dict(self):
|
def as_dict(self):
|
||||||
@ -306,6 +315,7 @@ class Float(ComfyTypeIO):
|
|||||||
"step": self.step,
|
"step": self.step,
|
||||||
"round": self.round,
|
"round": self.round,
|
||||||
"display": self.display_mode,
|
"display": self.display_mode,
|
||||||
|
"gradient_stops": self.gradient_stops,
|
||||||
})
|
})
|
||||||
|
|
||||||
@comfytype(io_type="STRING")
|
@comfytype(io_type="STRING")
|
||||||
@ -345,7 +355,7 @@ class Combo(ComfyTypeIO):
|
|||||||
tooltip: str=None,
|
tooltip: str=None,
|
||||||
lazy: bool=None,
|
lazy: bool=None,
|
||||||
default: str | int | Enum = None,
|
default: str | int | Enum = None,
|
||||||
control_after_generate: bool=None,
|
control_after_generate: bool | ControlAfterGenerate=None,
|
||||||
upload: UploadType=None,
|
upload: UploadType=None,
|
||||||
image_folder: FolderType=None,
|
image_folder: FolderType=None,
|
||||||
remote: RemoteOptions=None,
|
remote: RemoteOptions=None,
|
||||||
@ -389,7 +399,7 @@ class MultiCombo(ComfyTypeI):
|
|||||||
Type = list[str]
|
Type = list[str]
|
||||||
class Input(Combo.Input):
|
class Input(Combo.Input):
|
||||||
def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||||
default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None,
|
default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool | ControlAfterGenerate=None,
|
||||||
socketless: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
|
socketless: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
|
||||||
super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless, extra_dict=extra_dict, raw_link=raw_link, advanced=advanced)
|
super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless, extra_dict=extra_dict, raw_link=raw_link, advanced=advanced)
|
||||||
self.multiselect = True
|
self.multiselect = True
|
||||||
@ -1203,6 +1213,30 @@ class Color(ComfyTypeIO):
|
|||||||
def as_dict(self):
|
def as_dict(self):
|
||||||
return super().as_dict()
|
return super().as_dict()
|
||||||
|
|
||||||
|
@comfytype(io_type="BOUNDING_BOX")
|
||||||
|
class BoundingBox(ComfyTypeIO):
|
||||||
|
class BoundingBoxDict(TypedDict):
|
||||||
|
x: int
|
||||||
|
y: int
|
||||||
|
width: int
|
||||||
|
height: int
|
||||||
|
Type = BoundingBoxDict
|
||||||
|
|
||||||
|
class Input(WidgetInput):
|
||||||
|
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
|
||||||
|
socketless: bool=True, default: dict=None, component: str=None):
|
||||||
|
super().__init__(id, display_name, optional, tooltip, None, default, socketless)
|
||||||
|
self.component = component
|
||||||
|
if default is None:
|
||||||
|
self.default = {"x": 0, "y": 0, "width": 512, "height": 512}
|
||||||
|
|
||||||
|
def as_dict(self):
|
||||||
|
d = super().as_dict()
|
||||||
|
if self.component:
|
||||||
|
d["component"] = self.component
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
|
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
|
||||||
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
|
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
|
||||||
DYNAMIC_INPUT_LOOKUP[io_type] = func
|
DYNAMIC_INPUT_LOOKUP[io_type] = func
|
||||||
@ -1309,6 +1343,7 @@ class NodeInfoV1:
|
|||||||
api_node: bool=None
|
api_node: bool=None
|
||||||
price_badge: dict | None = None
|
price_badge: dict | None = None
|
||||||
search_aliases: list[str]=None
|
search_aliases: list[str]=None
|
||||||
|
essentials_category: str=None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -1430,6 +1465,8 @@ class Schema:
|
|||||||
"""Flags a node as expandable, allowing NodeOutput to include 'expand' property."""
|
"""Flags a node as expandable, allowing NodeOutput to include 'expand' property."""
|
||||||
accept_all_inputs: bool=False
|
accept_all_inputs: bool=False
|
||||||
"""When True, all inputs from the prompt will be passed to the node as kwargs, even if not defined in the schema."""
|
"""When True, all inputs from the prompt will be passed to the node as kwargs, even if not defined in the schema."""
|
||||||
|
essentials_category: str | None = None
|
||||||
|
"""Optional category for the Essentials tab. Path-based like category field (e.g., 'Basic', 'Image Tools/Editing')."""
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
'''Validate the schema:
|
'''Validate the schema:
|
||||||
@ -1536,6 +1573,7 @@ class Schema:
|
|||||||
python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes"),
|
python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes"),
|
||||||
price_badge=self.price_badge.as_dict(self.inputs) if self.price_badge is not None else None,
|
price_badge=self.price_badge.as_dict(self.inputs) if self.price_badge is not None else None,
|
||||||
search_aliases=self.search_aliases if self.search_aliases else None,
|
search_aliases=self.search_aliases if self.search_aliases else None,
|
||||||
|
essentials_category=self.essentials_category,
|
||||||
)
|
)
|
||||||
return info
|
return info
|
||||||
|
|
||||||
@ -2030,11 +2068,74 @@ class _UIOutput(ABC):
|
|||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class InputMapOldId(TypedDict):
|
||||||
|
"""Map an old node input to a new node input by ID."""
|
||||||
|
new_id: str
|
||||||
|
old_id: str
|
||||||
|
|
||||||
|
class InputMapSetValue(TypedDict):
|
||||||
|
"""Set a specific value for a new node input."""
|
||||||
|
new_id: str
|
||||||
|
set_value: Any
|
||||||
|
|
||||||
|
InputMap = InputMapOldId | InputMapSetValue
|
||||||
|
"""
|
||||||
|
Input mapping for node replacement. Type is inferred by dictionary keys:
|
||||||
|
- {"new_id": str, "old_id": str} - maps old input to new input
|
||||||
|
- {"new_id": str, "set_value": Any} - sets a specific value for new input
|
||||||
|
"""
|
||||||
|
|
||||||
|
class OutputMap(TypedDict):
|
||||||
|
"""Map outputs of node replacement via indexes."""
|
||||||
|
new_idx: int
|
||||||
|
old_idx: int
|
||||||
|
|
||||||
|
class NodeReplace:
|
||||||
|
"""
|
||||||
|
Defines a possible node replacement, mapping inputs and outputs of the old node to the new node.
|
||||||
|
|
||||||
|
Also supports assigning specific values to the input widgets of the new node.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
new_node_id: The class name of the new replacement node.
|
||||||
|
old_node_id: The class name of the deprecated node.
|
||||||
|
old_widget_ids: Ordered list of input IDs for widgets that may not have an input slot
|
||||||
|
connected. The workflow JSON stores widget values by their relative position index,
|
||||||
|
not by ID. This list maps those positional indexes to input IDs, enabling the
|
||||||
|
replacement system to correctly identify widget values during node migration.
|
||||||
|
input_mapping: List of input mappings from old node to new node.
|
||||||
|
output_mapping: List of output mappings from old node to new node.
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
new_node_id: str,
|
||||||
|
old_node_id: str,
|
||||||
|
old_widget_ids: list[str] | None=None,
|
||||||
|
input_mapping: list[InputMap] | None=None,
|
||||||
|
output_mapping: list[OutputMap] | None=None,
|
||||||
|
):
|
||||||
|
self.new_node_id = new_node_id
|
||||||
|
self.old_node_id = old_node_id
|
||||||
|
self.old_widget_ids = old_widget_ids
|
||||||
|
self.input_mapping = input_mapping
|
||||||
|
self.output_mapping = output_mapping
|
||||||
|
|
||||||
|
def as_dict(self):
|
||||||
|
"""Create serializable representation of the node replacement."""
|
||||||
|
return {
|
||||||
|
"new_node_id": self.new_node_id,
|
||||||
|
"old_node_id": self.old_node_id,
|
||||||
|
"old_widget_ids": self.old_widget_ids,
|
||||||
|
"input_mapping": list(self.input_mapping) if self.input_mapping else None,
|
||||||
|
"output_mapping": list(self.output_mapping) if self.output_mapping else None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"FolderType",
|
"FolderType",
|
||||||
"UploadType",
|
"UploadType",
|
||||||
"RemoteOptions",
|
"RemoteOptions",
|
||||||
"NumberDisplay",
|
"NumberDisplay",
|
||||||
|
"ControlAfterGenerate",
|
||||||
|
|
||||||
"comfytype",
|
"comfytype",
|
||||||
"Custom",
|
"Custom",
|
||||||
@ -2121,4 +2222,6 @@ __all__ = [
|
|||||||
"ImageCompare",
|
"ImageCompare",
|
||||||
"PriceBadgeDepends",
|
"PriceBadgeDepends",
|
||||||
"PriceBadge",
|
"PriceBadge",
|
||||||
|
"BoundingBox",
|
||||||
|
"NodeReplace",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -45,17 +45,55 @@ class BriaEditImageRequest(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BriaRemoveBackgroundRequest(BaseModel):
|
||||||
|
image: str = Field(...)
|
||||||
|
sync: bool = Field(False)
|
||||||
|
visual_input_content_moderation: bool = Field(
|
||||||
|
False, description="If true, returns 422 on input image moderation failure."
|
||||||
|
)
|
||||||
|
visual_output_content_moderation: bool = Field(
|
||||||
|
False, description="If true, returns 422 on visual output moderation failure."
|
||||||
|
)
|
||||||
|
seed: int = Field(...)
|
||||||
|
|
||||||
|
|
||||||
class BriaStatusResponse(BaseModel):
|
class BriaStatusResponse(BaseModel):
|
||||||
request_id: str = Field(...)
|
request_id: str = Field(...)
|
||||||
status_url: str = Field(...)
|
status_url: str = Field(...)
|
||||||
warning: str | None = Field(None)
|
warning: str | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
class BriaResult(BaseModel):
|
class BriaRemoveBackgroundResult(BaseModel):
|
||||||
|
image_url: str = Field(...)
|
||||||
|
|
||||||
|
|
||||||
|
class BriaRemoveBackgroundResponse(BaseModel):
|
||||||
|
status: str = Field(...)
|
||||||
|
result: BriaRemoveBackgroundResult | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class BriaImageEditResult(BaseModel):
|
||||||
structured_prompt: str = Field(...)
|
structured_prompt: str = Field(...)
|
||||||
image_url: str = Field(...)
|
image_url: str = Field(...)
|
||||||
|
|
||||||
|
|
||||||
class BriaResponse(BaseModel):
|
class BriaImageEditResponse(BaseModel):
|
||||||
status: str = Field(...)
|
status: str = Field(...)
|
||||||
result: BriaResult | None = Field(None)
|
result: BriaImageEditResult | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class BriaRemoveVideoBackgroundRequest(BaseModel):
|
||||||
|
video: str = Field(...)
|
||||||
|
background_color: str = Field(default="transparent", description="Background color for the output video.")
|
||||||
|
output_container_and_codec: str = Field(...)
|
||||||
|
preserve_audio: bool = Field(True)
|
||||||
|
seed: int = Field(...)
|
||||||
|
|
||||||
|
|
||||||
|
class BriaRemoveVideoBackgroundResult(BaseModel):
|
||||||
|
video_url: str = Field(...)
|
||||||
|
|
||||||
|
|
||||||
|
class BriaRemoveVideoBackgroundResponse(BaseModel):
|
||||||
|
status: str = Field(...)
|
||||||
|
result: BriaRemoveVideoBackgroundResult | None = Field(None)
|
||||||
|
|||||||
@ -27,6 +27,7 @@ class Seedream4TaskCreationRequest(BaseModel):
|
|||||||
sequential_image_generation: str = Field("disabled")
|
sequential_image_generation: str = Field("disabled")
|
||||||
sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15))
|
sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15))
|
||||||
watermark: bool = Field(False)
|
watermark: bool = Field(False)
|
||||||
|
output_format: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class ImageTaskCreationResponse(BaseModel):
|
class ImageTaskCreationResponse(BaseModel):
|
||||||
@ -106,6 +107,7 @@ RECOMMENDED_PRESETS_SEEDREAM_4 = [
|
|||||||
("2496x1664 (3:2)", 2496, 1664),
|
("2496x1664 (3:2)", 2496, 1664),
|
||||||
("1664x2496 (2:3)", 1664, 2496),
|
("1664x2496 (2:3)", 1664, 2496),
|
||||||
("3024x1296 (21:9)", 3024, 1296),
|
("3024x1296 (21:9)", 3024, 1296),
|
||||||
|
("3072x3072 (1:1)", 3072, 3072),
|
||||||
("4096x4096 (1:1)", 4096, 4096),
|
("4096x4096 (1:1)", 4096, 4096),
|
||||||
("Custom", None, None),
|
("Custom", None, None),
|
||||||
]
|
]
|
||||||
|
|||||||
88
comfy_api_nodes/apis/elevenlabs.py
Normal file
88
comfy_api_nodes/apis/elevenlabs.py
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class SpeechToTextRequest(BaseModel):
|
||||||
|
model_id: str = Field(...)
|
||||||
|
cloud_storage_url: str = Field(...)
|
||||||
|
language_code: str | None = Field(None, description="ISO-639-1 or ISO-639-3 language code")
|
||||||
|
tag_audio_events: bool | None = Field(None, description="Annotate sounds like (laughter) in transcript")
|
||||||
|
num_speakers: int | None = Field(None, description="Max speakers predicted")
|
||||||
|
timestamps_granularity: str = Field(default="word", description="Timing precision: none, word, or character")
|
||||||
|
diarize: bool | None = Field(None, description="Annotate which speaker is talking")
|
||||||
|
diarization_threshold: float | None = Field(None, description="Speaker separation sensitivity")
|
||||||
|
temperature: float | None = Field(None, description="Randomness control")
|
||||||
|
seed: int = Field(..., description="Seed for deterministic sampling")
|
||||||
|
|
||||||
|
|
||||||
|
class SpeechToTextWord(BaseModel):
|
||||||
|
text: str = Field(..., description="The word text")
|
||||||
|
type: str = Field(default="word", description="Type of text element (word, spacing, etc.)")
|
||||||
|
start: float | None = Field(None, description="Start time in seconds (when timestamps enabled)")
|
||||||
|
end: float | None = Field(None, description="End time in seconds (when timestamps enabled)")
|
||||||
|
speaker_id: str | None = Field(None, description="Speaker identifier when diarization is enabled")
|
||||||
|
logprob: float | None = Field(None, description="Log probability of the word")
|
||||||
|
|
||||||
|
|
||||||
|
class SpeechToTextResponse(BaseModel):
|
||||||
|
language_code: str = Field(..., description="Detected or specified language code")
|
||||||
|
language_probability: float | None = Field(None, description="Confidence of language detection")
|
||||||
|
text: str = Field(..., description="Full transcript text")
|
||||||
|
words: list[SpeechToTextWord] | None = Field(None, description="Word-level timing information")
|
||||||
|
|
||||||
|
|
||||||
|
class TextToSpeechVoiceSettings(BaseModel):
|
||||||
|
stability: float | None = Field(None, description="Voice stability")
|
||||||
|
similarity_boost: float | None = Field(None, description="Similarity boost")
|
||||||
|
style: float | None = Field(None, description="Style exaggeration")
|
||||||
|
use_speaker_boost: bool | None = Field(None, description="Boost similarity to original speaker")
|
||||||
|
speed: float | None = Field(None, description="Speech speed")
|
||||||
|
|
||||||
|
|
||||||
|
class TextToSpeechRequest(BaseModel):
|
||||||
|
text: str = Field(..., description="Text to convert to speech")
|
||||||
|
model_id: str = Field(..., description="Model ID for TTS")
|
||||||
|
language_code: str | None = Field(None, description="ISO-639-1 or ISO-639-3 language code")
|
||||||
|
voice_settings: TextToSpeechVoiceSettings | None = Field(None, description="Voice settings")
|
||||||
|
seed: int = Field(..., description="Seed for deterministic sampling")
|
||||||
|
apply_text_normalization: str | None = Field(None, description="Text normalization mode: auto, on, off")
|
||||||
|
|
||||||
|
|
||||||
|
class TextToSoundEffectsRequest(BaseModel):
|
||||||
|
text: str = Field(..., description="Text prompt to convert into a sound effect")
|
||||||
|
duration_seconds: float = Field(..., description="Duration of generated sound in seconds")
|
||||||
|
prompt_influence: float = Field(..., description="How closely generation follows the prompt")
|
||||||
|
loop: bool | None = Field(None, description="Whether to create a smoothly looping sound effect")
|
||||||
|
|
||||||
|
|
||||||
|
class AddVoiceRequest(BaseModel):
|
||||||
|
name: str = Field(..., description="Name that identifies the voice")
|
||||||
|
remove_background_noise: bool = Field(..., description="Remove background noise from voice samples")
|
||||||
|
|
||||||
|
|
||||||
|
class AddVoiceResponse(BaseModel):
|
||||||
|
voice_id: str = Field(..., description="The newly created voice's unique identifier")
|
||||||
|
|
||||||
|
|
||||||
|
class SpeechToSpeechRequest(BaseModel):
|
||||||
|
model_id: str = Field(..., description="Model ID for speech-to-speech")
|
||||||
|
voice_settings: str = Field(..., description="JSON string of voice settings")
|
||||||
|
seed: int = Field(..., description="Seed for deterministic sampling")
|
||||||
|
remove_background_noise: bool = Field(..., description="Remove background noise from input audio")
|
||||||
|
|
||||||
|
|
||||||
|
class DialogueInput(BaseModel):
|
||||||
|
text: str = Field(..., description="Text content to convert to speech")
|
||||||
|
voice_id: str = Field(..., description="Voice identifier for this dialogue segment")
|
||||||
|
|
||||||
|
|
||||||
|
class DialogueSettings(BaseModel):
|
||||||
|
stability: float | None = Field(None, description="Voice stability (0-1)")
|
||||||
|
|
||||||
|
|
||||||
|
class TextToDialogueRequest(BaseModel):
|
||||||
|
inputs: list[DialogueInput] = Field(..., description="List of dialogue segments")
|
||||||
|
model_id: str = Field(..., description="Model ID for dialogue generation")
|
||||||
|
language_code: str | None = Field(None, description="ISO-639-1 language code")
|
||||||
|
settings: DialogueSettings | None = Field(None, description="Voice settings")
|
||||||
|
seed: int | None = Field(None, description="Seed for deterministic sampling")
|
||||||
|
apply_text_normalization: str | None = Field(None, description="Text normalization mode: auto, on, off")
|
||||||
@ -116,9 +116,15 @@ class GeminiGenerationConfig(BaseModel):
|
|||||||
topP: float | None = Field(None, ge=0.0, le=1.0)
|
topP: float | None = Field(None, ge=0.0, le=1.0)
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiImageOutputOptions(BaseModel):
|
||||||
|
mimeType: str = Field("image/png")
|
||||||
|
compressionQuality: int | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
class GeminiImageConfig(BaseModel):
|
class GeminiImageConfig(BaseModel):
|
||||||
aspectRatio: str | None = Field(None)
|
aspectRatio: str | None = Field(None)
|
||||||
imageSize: str | None = Field(None)
|
imageSize: str | None = Field(None)
|
||||||
|
imageOutputOptions: GeminiImageOutputOptions = Field(default_factory=GeminiImageOutputOptions)
|
||||||
|
|
||||||
|
|
||||||
class GeminiImageGenerationConfig(GeminiGenerationConfig):
|
class GeminiImageGenerationConfig(GeminiGenerationConfig):
|
||||||
|
|||||||
@ -64,3 +64,23 @@ class To3DProTaskResultResponse(BaseModel):
|
|||||||
|
|
||||||
class To3DProTaskQueryRequest(BaseModel):
|
class To3DProTaskQueryRequest(BaseModel):
|
||||||
JobId: str = Field(...)
|
JobId: str = Field(...)
|
||||||
|
|
||||||
|
|
||||||
|
class To3DUVFileInput(BaseModel):
|
||||||
|
Type: str = Field(..., description="File type: GLB, OBJ, or FBX")
|
||||||
|
Url: str = Field(...)
|
||||||
|
|
||||||
|
|
||||||
|
class To3DUVTaskRequest(BaseModel):
|
||||||
|
File: To3DUVFileInput = Field(...)
|
||||||
|
|
||||||
|
|
||||||
|
class TextureEditImageInfo(BaseModel):
|
||||||
|
Url: str = Field(...)
|
||||||
|
|
||||||
|
|
||||||
|
class TextureEditTaskRequest(BaseModel):
|
||||||
|
File3D: To3DUVFileInput = Field(...)
|
||||||
|
Image: TextureEditImageInfo | None = Field(None)
|
||||||
|
Prompt: str | None = Field(None)
|
||||||
|
EnablePBR: bool | None = Field(None)
|
||||||
|
|||||||
@ -134,6 +134,13 @@ class ImageToVideoWithAudioRequest(BaseModel):
|
|||||||
shot_type: str | None = Field(None)
|
shot_type: str | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class KlingAvatarRequest(BaseModel):
|
||||||
|
image: str = Field(...)
|
||||||
|
sound_file: str = Field(...)
|
||||||
|
prompt: str | None = Field(None)
|
||||||
|
mode: str = Field(...)
|
||||||
|
|
||||||
|
|
||||||
class MotionControlRequest(BaseModel):
|
class MotionControlRequest(BaseModel):
|
||||||
prompt: str = Field(...)
|
prompt: str = Field(...)
|
||||||
image_url: str = Field(...)
|
image_url: str = Field(...)
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user