diff --git a/.github/workflows/release-webhook.yml b/.github/workflows/release-webhook.yml
index 6fceb7560..737e4c488 100644
--- a/.github/workflows/release-webhook.yml
+++ b/.github/workflows/release-webhook.yml
@@ -7,6 +7,8 @@ on:
jobs:
send-webhook:
runs-on: ubuntu-latest
+ env:
+ DESKTOP_REPO_DISPATCH_TOKEN: ${{ secrets.DESKTOP_REPO_DISPATCH_TOKEN }}
steps:
- name: Send release webhook
env:
@@ -106,3 +108,37 @@ jobs:
--fail --silent --show-error
echo "✅ Release webhook sent successfully"
+      # Emit a `comfyui_release_published` repository_dispatch event to Comfy-Org/desktop with the release tag and URL.
+      - name: Send repository dispatch to desktop
+        env:
+          DISPATCH_TOKEN: ${{ env.DESKTOP_REPO_DISPATCH_TOKEN }}
+          RELEASE_TAG: ${{ github.event.release.tag_name }}
+          RELEASE_URL: ${{ github.event.release.html_url }}
+        run: |
+          set -euo pipefail
+
+          if [ -z "${DISPATCH_TOKEN:-}" ]; then
+            echo "::error::DESKTOP_REPO_DISPATCH_TOKEN is required but not set."
+            exit 1
+          fi
+
+          PAYLOAD="$(jq -n \
+            --arg release_tag "$RELEASE_TAG" \
+            --arg release_url "$RELEASE_URL" \
+            '{
+              event_type: "comfyui_release_published",
+              client_payload: {
+                release_tag: $release_tag,
+                release_url: $release_url
+              }
+            }')"
+
+          curl -fsSL \
+            -X POST \
+            -H "Accept: application/vnd.github+json" \
+            -H "Content-Type: application/json" \
+            -H "Authorization: Bearer ${DISPATCH_TOKEN}" \
+            https://api.github.com/repos/Comfy-Org/desktop/dispatches \
+            -d "$PAYLOAD"
+
+          echo "✅ Dispatched ComfyUI release ${RELEASE_TAG} to Comfy-Org/desktop"
diff --git a/.gitignore b/.gitignore
index 4e8cea71e..2700ad5c2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,7 +11,7 @@ extra_model_paths.yaml
/.vs
.vscode/
.idea/
-venv/
+venv*/
.venv/
/web/extensions/*
!/web/extensions/logging.js.example
diff --git a/README.md b/README.md
index 96dc2904b..3ccdc9c19 100644
--- a/README.md
+++ b/README.md
@@ -227,7 +227,7 @@ Put your VAE in: models/vae
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
-```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4```
+```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm7.1```
This is the command to install the nightly with ROCm 7.1 which might have some performance improvements:
diff --git a/app/node_replace_manager.py b/app/node_replace_manager.py
new file mode 100644
index 000000000..03b603c70
--- /dev/null
+++ b/app/node_replace_manager.py
@@ -0,0 +1,105 @@
+from __future__ import annotations
+
+from aiohttp import web
+
+from typing import TYPE_CHECKING, TypedDict
+if TYPE_CHECKING:
+ from comfy_api.latest._io_public import NodeReplace
+
+from comfy_execution.graph_utils import is_link
+import nodes
+
+class NodeStruct(TypedDict):  # one node entry of the prompt dict handled by NodeReplaceManager.apply_replacements
+    inputs: dict[str, str | int | float | bool | tuple[str, int]]  # input id -> literal, or link (source node number, output index)
+    class_type: str  # node type name; checked against nodes.NODE_CLASS_MAPPINGS
+    _meta: dict[str, str]  # copied as-is by copy_node_struct; contents are not inspected in this module
+
+def copy_node_struct(node_struct: NodeStruct, empty_inputs: bool = False) -> NodeStruct:
+    """Copy a node struct one level deep; ``inputs`` starts empty when ``empty_inputs`` is set."""
+    new_node_struct = node_struct.copy()
+    new_node_struct["inputs"] = {} if empty_inputs else node_struct["inputs"].copy()
+    # "_meta" may be absent (e.g. hand-written API prompts): copy it only when present instead of raising KeyError.
+    if "_meta" in node_struct:
+        new_node_struct["_meta"] = node_struct["_meta"].copy()
+    return new_node_struct
+
+
+class NodeReplaceManager:
+    """Registry of node replacements; can also rewrite a prompt in place to use them."""
+
+    def __init__(self):
+        self._replacements: dict[str, list[NodeReplace]] = {}  # old node id -> replacements, in registration order
+
+    def register(self, node_replace: NodeReplace):
+        """Register a node replacement mapping."""
+        self._replacements.setdefault(node_replace.old_node_id, []).append(node_replace)
+
+    def get_replacement(self, old_node_id: str) -> list[NodeReplace] | None:
+        """Get replacements for an old node ID."""
+        return self._replacements.get(old_node_id)
+
+    def has_replacement(self, old_node_id: str) -> bool:
+        """Check if a replacement exists for an old node ID."""
+        return old_node_id in self._replacements
+
+    def apply_replacements(self, prompt: dict[str, NodeStruct]):
+        """Replace, in place, each node of ``prompt`` whose class is unregistered but has a usable replacement.
+
+        Only the first registered replacement is tried, and only if its target class is registered.
+        Links reading a replaced node's outputs are re-indexed through its ``output_mapping``.
+        """
+        # source node number -> [(link object, output index it read when scanned)]
+        links_from: dict[str, list[tuple[list, int]]] = {}
+        need_replacement: list[str] = []
+        for node_number, node_struct in prompt.items():
+            class_type = node_struct["class_type"]
+            # need replacement if not in NODE_CLASS_MAPPINGS and has replacement
+            if class_type not in nodes.NODE_CLASS_MAPPINGS and self.has_replacement(class_type):
+                need_replacement.append(node_number)
+            # keep track of connections
+            for input_value in node_struct["inputs"].values():
+                if is_link(input_value):
+                    links_from.setdefault(input_value[0], []).append((input_value, input_value[1]))
+        for node_number in need_replacement:
+            node_struct = prompt[node_number]
+            replacements = self.get_replacement(node_struct["class_type"])
+            if not replacements:
+                continue
+            # just use the first replacement
+            replacement = replacements[0]
+            new_node_id = replacement.new_node_id
+            # if replacement is not a valid node, skip trying to replace it as will only cause confusion
+            if new_node_id not in nodes.NODE_CLASS_MAPPINGS:
+                continue
+            # first, replace node id (class_type)
+            new_node_struct = copy_node_struct(node_struct, empty_inputs=True)
+            new_node_struct["class_type"] = new_node_id
+            # TODO: consider replacing display_name in _meta as well for error reporting purposes; would need to query node schema
+            # second, replace inputs
+            old_inputs = node_struct["inputs"]
+            for input_map in replacement.input_mapping or ():
+                if "set_value" in input_map:
+                    new_node_struct["inputs"][input_map["new_id"]] = input_map["set_value"]
+                elif "old_id" in input_map and input_map["old_id"] in old_inputs:
+                    # an input the old node never received stays unset instead of raising KeyError
+                    new_node_struct["inputs"][input_map["new_id"]] = old_inputs[input_map["old_id"]]
+            # finalize input replacement
+            prompt[node_number] = new_node_struct
+            # third, replace outputs
+            if replacement.output_mapping is not None:
+                remap = {m["old_idx"]: m["new_idx"] for m in replacement.output_mapping}
+                # Mutate the link objects captured during the scan rather than looking them up again by
+                # (node, input id): a consumer that was itself replaced may have renamed or dropped that input.
+                for link, old_output_idx in links_from.get(node_number, ()):
+                    if old_output_idx in remap:
+                        link[1] = remap[old_output_idx]
+
+    def as_dict(self):
+        """Serialize all replacements to dict."""
+        return {k: [v.as_dict() for v in v_list] for k, v_list in self._replacements.items()}
+
+    def add_routes(self, routes):
+        """Register ``GET /node_replacements``, which responds with ``as_dict()`` as JSON."""
+        @routes.get("/node_replacements")
+        async def get_node_replacements(request):
+            return web.json_response(self.as_dict())
diff --git a/comfy/checkpoint_pickle.py b/comfy/checkpoint_pickle.py
deleted file mode 100644
index 206551d3c..000000000
--- a/comfy/checkpoint_pickle.py
+++ /dev/null
@@ -1,13 +0,0 @@
-import pickle
-
-load = pickle.load
-
-class Empty:
- pass
-
-class Unpickler(pickle.Unpickler):
- def find_class(self, module, name):
- #TODO: safe unpickle
- if module.startswith("pytorch_lightning"):
- return Empty
- return super().find_class(module, name)
diff --git a/comfy/controlnet.py b/comfy/controlnet.py
index 9e1e704e0..ba670b16d 100644
--- a/comfy/controlnet.py
+++ b/comfy/controlnet.py
@@ -297,6 +297,30 @@ class ControlNet(ControlBase):
self.model_sampling_current = None
super().cleanup()
+
+class QwenFunControlNet(ControlNet):
+    def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
+        # Fun checkpoints are more sensitive to high strengths in the generic ControlNet
+        # merge path, so the parent runs with sqrt(strength): 1.0 stays unchanged, >1 grows
+        # more gently, values in (0, 1) are raised, and negative strengths are clamped to 0.
+        original_strength = self.strength
+        self.strength = math.sqrt(max(self.strength, 0.0))
+        try:
+            return super().get_control(x_noisy, t, cond, batched_number, transformer_options)
+        finally:
+            self.strength = original_strength  # restored even if the parent call raises
+
+    def pre_run(self, model, percent_to_timestep_function):
+        super().pre_run(model, percent_to_timestep_function)
+        self.set_extra_arg("base_model", model.diffusion_model)  # stored under the "base_model" extra arg
+
+    def copy(self):
+        c = QwenFunControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
+        c.control_model = self.control_model  # the copy references the same model objects, not clones
+        c.control_model_wrapped = self.control_model_wrapped
+        self.copy_to(c)
+        return c
+
class ControlLoraOps:
class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
def __init__(self, in_features: int, out_features: int, bias: bool = True,
@@ -560,6 +584,7 @@ def load_controlnet_hunyuandit(controlnet_data, model_options={}):
def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}):
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
+ sd = model_config.process_unet_state_dict(sd)
control_model = controlnet_load_state_dict(control_model, sd)
extra_conds = ['y', 'guidance']
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
@@ -605,6 +630,53 @@ def load_controlnet_qwen_instantx(sd, model_options={}):
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
return control
+
+def load_controlnet_qwen_fun(sd, model_options={}):
+    """Build a QwenFunControlNet from state dict ``sd``, inferring layer sizes from its weight shapes."""
+    load_device = comfy.model_management.get_torch_device()
+    weight_dtype = comfy.utils.weight_dtype(sd)
+    unet_dtype = model_options.get("dtype", weight_dtype)
+    manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
+
+    operations = model_options.get("custom_operations", None)
+    if operations is None:
+        operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
+    # control_img_in.weight is read as (inner_dim, control_in_features)
+    in_features = sd["control_img_in.weight"].shape[1]
+    inner_dim = sd["control_img_in.weight"].shape[0]
+    # head count = to_q rows // norm_q length; the max(1, ...) guards rule out a zero divisor and zero heads
+    block_weight = sd["control_blocks.0.attn.to_q.weight"]
+    attention_head_dim = sd["control_blocks.0.attn.norm_q.weight"].shape[0]
+    num_attention_heads = max(1, block_weight.shape[0] // max(1, attention_head_dim))
+    # block count, main-model depth and injection layers are fixed constants here, not read from ``sd``
+    model = comfy.ldm.qwen_image.controlnet.QwenImageFunControlNetModel(
+        control_in_features=in_features,
+        inner_dim=inner_dim,
+        num_attention_heads=num_attention_heads,
+        attention_head_dim=attention_head_dim,
+        num_control_blocks=5,
+        main_model_double=60,
+        injection_layers=(0, 12, 24, 36, 48),
+        operations=operations,
+        device=comfy.model_management.unet_offload_device(),
+        dtype=unet_dtype,
+    )
+    model = controlnet_load_state_dict(model, sd)
+    latent_format = comfy.latent_formats.Wan21()
+    control = QwenFunControlNet(
+        model,
+        compression_ratio=1,
+        latent_format=latent_format,
+        # Fun checkpoints already expect their own 33-channel context handling.
+        # Enabling generic concat_mask injects an extra mask channel at apply-time
+        # and breaks the intended fallback packing path.
+        concat_mask=False,
+        load_device=load_device,
+        manual_cast_dtype=manual_cast_dtype,
+        extra_conds=[],
+    )
+    return control
+
def convert_mistoline(sd):
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
@@ -682,6 +754,8 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
return load_controlnet_qwen_instantx(controlnet_data, model_options=model_options)
elif "controlnet_x_embedder.weight" in controlnet_data:
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
+ elif "control_blocks.0.after_proj.weight" in controlnet_data and "control_img_in.weight" in controlnet_data:
+ return load_controlnet_qwen_fun(controlnet_data, model_options=model_options)
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)
diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py
index c0c51d51a..6978eb717 100644
--- a/comfy/k_diffusion/sampling.py
+++ b/comfy/k_diffusion/sampling.py
@@ -1,12 +1,11 @@
import math
-import time
from functools import partial
from scipy import integrate
import torch
from torch import nn
import torchsde
-from tqdm.auto import trange as trange_, tqdm
+from tqdm.auto import tqdm
from . import utils
from . import deis
@@ -15,34 +14,7 @@ import comfy.model_patcher
import comfy.model_sampling
import comfy.memory_management
-
-
-def trange(*args, **kwargs):
- if comfy.memory_management.aimdo_allocator is None:
- return trange_(*args, **kwargs)
-
- pbar = trange_(*args, **kwargs, smoothing=1.0)
- pbar._i = 0
- pbar.set_postfix_str(" Model Initializing ... ")
-
- _update = pbar.update
-
- def warmup_update(n=1):
- pbar._i += 1
- if pbar._i == 1:
- pbar.i1_time = time.time()
- pbar.set_postfix_str(" Model Initialization complete! ")
- elif pbar._i == 2:
- #bring forward the effective start time based the the diff between first and second iteration
- #to attempt to remove load overhead from the final step rate estimate.
- pbar.start_t = pbar.i1_time - (time.time() - pbar.i1_time)
- pbar.set_postfix_str("")
-
- _update(n)
-
- pbar.update = warmup_update
- return pbar
-
+from comfy.utils import model_trange as trange
def append_zero(x):
return torch.cat([x, x.new_zeros([1])])
diff --git a/comfy/ldm/ace/ace_step15.py b/comfy/ldm/ace/ace_step15.py
index 69338336d..1d7dc59a8 100644
--- a/comfy/ldm/ace/ace_step15.py
+++ b/comfy/ldm/ace/ace_step15.py
@@ -1110,7 +1110,7 @@ class AceStepConditionGenerationModel(nn.Module):
return encoder_hidden, encoder_mask, context_latents
- def forward(self, x, timestep, context, lyric_embed=None, refer_audio=None, audio_codes=None, is_covers=None, **kwargs):
+ def forward(self, x, timestep, context, lyric_embed=None, refer_audio=None, audio_codes=None, is_covers=None, replace_with_null_embeds=False, **kwargs):
text_attention_mask = None
lyric_attention_mask = None
refer_audio_order_mask = None
@@ -1140,6 +1140,9 @@ class AceStepConditionGenerationModel(nn.Module):
src_latents, chunk_masks, is_covers, precomputed_lm_hints_25Hz=precomputed_lm_hints_25Hz, audio_codes=audio_codes
)
+ if replace_with_null_embeds:
+ enc_hidden[:] = self.null_condition_emb.to(enc_hidden)
+
out = self.decoder(hidden_states=x,
timestep=timestep,
timestep_r=timestep,
diff --git a/comfy/ldm/anima/model.py b/comfy/ldm/anima/model.py
index 2e6ed58fa..6fcf8df90 100644
--- a/comfy/ldm/anima/model.py
+++ b/comfy/ldm/anima/model.py
@@ -179,8 +179,8 @@ class LLMAdapter(nn.Module):
if source_attention_mask.ndim == 2:
source_attention_mask = source_attention_mask.unsqueeze(1).unsqueeze(1)
- x = self.in_proj(self.embed(target_input_ids))
context = source_hidden_states
+ x = self.in_proj(self.embed(target_input_ids, out_dtype=context.dtype))
position_ids = torch.arange(x.shape[1], device=x.device).unsqueeze(0)
position_ids_context = torch.arange(context.shape[1], device=x.device).unsqueeze(0)
position_embeddings = self.rotary_emb(x, position_ids)
@@ -195,8 +195,20 @@ class Anima(MiniTrainDIT):
super().__init__(*args, **kwargs)
self.llm_adapter = LLMAdapter(device=kwargs.get("device"), dtype=kwargs.get("dtype"), operations=kwargs.get("operations"))
- def preprocess_text_embeds(self, text_embeds, text_ids):
+ def preprocess_text_embeds(self, text_embeds, text_ids, t5xxl_weights=None):
if text_ids is not None:
- return self.llm_adapter(text_embeds, text_ids)
+ out = self.llm_adapter(text_embeds, text_ids)
+ if t5xxl_weights is not None:
+ out = out * t5xxl_weights
+
+ if out.shape[1] < 512:
+ out = torch.nn.functional.pad(out, (0, 0, 0, 512 - out.shape[1]))
+ return out
else:
return text_embeds
+
+    def forward(self, x, timesteps, context, **kwargs):
+        t5xxl_ids = kwargs.pop("t5xxl_ids", None)  # popped so it is not passed on to the parent forward()
+        if t5xxl_ids is not None:  # "t5xxl_weights" is only popped (and used) on this branch
+            context = self.preprocess_text_embeds(context, t5xxl_ids, t5xxl_weights=kwargs.pop("t5xxl_weights", None))
+        return super().forward(x, timesteps, context, **kwargs)
diff --git a/comfy/ldm/chroma/layers.py b/comfy/ldm/chroma/layers.py
index 2d5684348..df348a8ed 100644
--- a/comfy/ldm/chroma/layers.py
+++ b/comfy/ldm/chroma/layers.py
@@ -3,7 +3,6 @@ from torch import Tensor, nn
from comfy.ldm.flux.layers import (
MLPEmbedder,
- RMSNorm,
ModulationOut,
)
@@ -29,7 +28,7 @@ class Approximator(nn.Module):
super().__init__()
self.in_proj = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
- self.norms = nn.ModuleList([RMSNorm(hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
+ self.norms = nn.ModuleList([operations.RMSNorm(hidden_dim, dtype=dtype, device=device) for x in range( n_layers)])
self.out_proj = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)
@property
diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py
index 2e8ef0687..9fd865f20 100644
--- a/comfy/ldm/chroma/model.py
+++ b/comfy/ldm/chroma/model.py
@@ -152,6 +152,7 @@ class Chroma(nn.Module):
transformer_options={},
attn_mask: Tensor = None,
) -> Tensor:
+ transformer_options = transformer_options.copy()
patches_replace = transformer_options.get("patches_replace", {})
# running on sequences img
@@ -228,6 +229,7 @@ class Chroma(nn.Module):
transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single"
+ transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
for i, block in enumerate(self.single_blocks):
transformer_options["block_index"] = i
if i not in self.skip_dit:
diff --git a/comfy/ldm/chroma_radiance/layers.py b/comfy/ldm/chroma_radiance/layers.py
index 3c7bc9b6b..08d31e0ba 100644
--- a/comfy/ldm/chroma_radiance/layers.py
+++ b/comfy/ldm/chroma_radiance/layers.py
@@ -4,8 +4,6 @@ from functools import lru_cache
import torch
from torch import nn
-from comfy.ldm.flux.layers import RMSNorm
-
class NerfEmbedder(nn.Module):
"""
@@ -145,7 +143,7 @@ class NerfGLUBlock(nn.Module):
# We now need to generate parameters for 3 matrices.
total_params = 3 * hidden_size_x**2 * mlp_ratio
self.param_generator = operations.Linear(hidden_size_s, total_params, dtype=dtype, device=device)
- self.norm = RMSNorm(hidden_size_x, dtype=dtype, device=device, operations=operations)
+ self.norm = operations.RMSNorm(hidden_size_x, dtype=dtype, device=device)
self.mlp_ratio = mlp_ratio
@@ -178,7 +176,7 @@ class NerfGLUBlock(nn.Module):
class NerfFinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None):
super().__init__()
- self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
+ self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device)
self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -190,7 +188,7 @@ class NerfFinalLayer(nn.Module):
class NerfFinalLayerConv(nn.Module):
def __init__(self, hidden_size: int, out_channels: int, dtype=None, device=None, operations=None):
super().__init__()
- self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
+ self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device)
self.conv = operations.Conv2d(
in_channels=hidden_size,
out_channels=out_channels,
diff --git a/comfy/ldm/cosmos/predict2.py b/comfy/ldm/cosmos/predict2.py
index c270e6333..2268bff38 100644
--- a/comfy/ldm/cosmos/predict2.py
+++ b/comfy/ldm/cosmos/predict2.py
@@ -335,7 +335,7 @@ class FinalLayer(nn.Module):
device=None, dtype=None, operations=None
):
super().__init__()
- self.layer_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.layer_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = operations.Linear(
hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, device=device, dtype=dtype
)
@@ -463,6 +463,8 @@ class Block(nn.Module):
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
transformer_options: Optional[dict] = {},
) -> torch.Tensor:
+ residual_dtype = x_B_T_H_W_D.dtype
+ compute_dtype = emb_B_T_D.dtype
if extra_per_block_pos_emb is not None:
x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb
@@ -512,7 +514,7 @@ class Block(nn.Module):
result_B_T_H_W_D = rearrange(
self.self_attn(
# normalized_x_B_T_HW_D,
- rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
+ rearrange(normalized_x_B_T_H_W_D.to(compute_dtype), "b t h w d -> b (t h w) d"),
None,
rope_emb=rope_emb_L_1_1_D,
transformer_options=transformer_options,
@@ -522,7 +524,7 @@ class Block(nn.Module):
h=H,
w=W,
)
- x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D * result_B_T_H_W_D
+ x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)
def _x_fn(
_x_B_T_H_W_D: torch.Tensor,
@@ -536,7 +538,7 @@ class Block(nn.Module):
)
_result_B_T_H_W_D = rearrange(
self.cross_attn(
- rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
+ rearrange(_normalized_x_B_T_H_W_D.to(compute_dtype), "b t h w d -> b (t h w) d"),
crossattn_emb,
rope_emb=rope_emb_L_1_1_D,
transformer_options=transformer_options,
@@ -555,7 +557,7 @@ class Block(nn.Module):
shift_cross_attn_B_T_1_1_D,
transformer_options=transformer_options,
)
- x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D
+ x_B_T_H_W_D = result_B_T_H_W_D.to(residual_dtype) * gate_cross_attn_B_T_1_1_D.to(residual_dtype) + x_B_T_H_W_D
normalized_x_B_T_H_W_D = _fn(
x_B_T_H_W_D,
@@ -563,8 +565,8 @@ class Block(nn.Module):
scale_mlp_B_T_1_1_D,
shift_mlp_B_T_1_1_D,
)
- result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D)
- x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D * result_B_T_H_W_D
+ result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D.to(compute_dtype))
+ x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)
return x_B_T_H_W_D
@@ -876,6 +878,14 @@ class MiniTrainDIT(nn.Module):
"extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
"transformer_options": kwargs.get("transformer_options", {}),
}
+
+ # The residual stream for this model has large values. To make fp16 compute_dtype work, we keep the residual stream
+ # in fp32, but run attention and MLP modules in fp16.
+ # An alternate method that clamps fp16 values "works" in the sense that it makes coherent images, but there is noticeable
+ # quality degradation and visual artifacts.
+ if x_B_T_H_W_D.dtype == torch.float16:
+ x_B_T_H_W_D = x_B_T_H_W_D.float()
+
for block in self.blocks:
x_B_T_H_W_D = block(
x_B_T_H_W_D,
@@ -884,6 +894,6 @@ class MiniTrainDIT(nn.Module):
**block_kwargs,
)
- x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
+ x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D.to(crossattn_emb.dtype), t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)[:, :, :orig_shape[-3], :orig_shape[-2], :orig_shape[-1]]
return x_B_C_Tt_Hp_Wp
diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py
index 60f2bdae2..8b3f500d7 100644
--- a/comfy/ldm/flux/layers.py
+++ b/comfy/ldm/flux/layers.py
@@ -5,9 +5,9 @@ import torch
from torch import Tensor, nn
from .math import attention, rope
-import comfy.ops
-import comfy.ldm.common_dit
+# Fix import for some custom nodes, TODO: delete eventually.
+RMSNorm = None
class EmbedND(nn.Module):
def __init__(self, dim: int, theta: int, axes_dim: list):
@@ -87,20 +87,12 @@ def build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=False, yak_mlp=False, dt
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
-class RMSNorm(torch.nn.Module):
- def __init__(self, dim: int, dtype=None, device=None, operations=None):
- super().__init__()
- self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
-
- def forward(self, x: Tensor):
- return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6)
-
class QKNorm(torch.nn.Module):
def __init__(self, dim: int, dtype=None, device=None, operations=None):
super().__init__()
- self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
- self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
+ self.query_norm = operations.RMSNorm(dim, dtype=dtype, device=device)
+ self.key_norm = operations.RMSNorm(dim, dtype=dtype, device=device)
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
q = self.query_norm(q)
@@ -169,7 +161,7 @@ class SiLUActivation(nn.Module):
class DoubleStreamBlock(nn.Module):
- def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
+ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
@@ -197,8 +189,6 @@ class DoubleStreamBlock(nn.Module):
self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
- self.flipped_img_txt = flipped_img_txt
-
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
if self.modulation:
img_mod1, img_mod2 = self.img_mod(vec)
@@ -206,6 +196,9 @@ class DoubleStreamBlock(nn.Module):
else:
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
+ transformer_patches = transformer_options.get("patches", {})
+ extra_options = transformer_options.copy()
+
# prepare image for attention
img_modulated = self.img_norm1(img)
img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
@@ -224,32 +217,23 @@ class DoubleStreamBlock(nn.Module):
del txt_qkv
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
- if self.flipped_img_txt:
- q = torch.cat((img_q, txt_q), dim=2)
- del img_q, txt_q
- k = torch.cat((img_k, txt_k), dim=2)
- del img_k, txt_k
- v = torch.cat((img_v, txt_v), dim=2)
- del img_v, txt_v
- # run actual attention
- attn = attention(q, k, v,
- pe=pe, mask=attn_mask, transformer_options=transformer_options)
- del q, k, v
+ q = torch.cat((txt_q, img_q), dim=2)
+ del txt_q, img_q
+ k = torch.cat((txt_k, img_k), dim=2)
+ del txt_k, img_k
+ v = torch.cat((txt_v, img_v), dim=2)
+ del txt_v, img_v
+ # run actual attention
+ attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
+ del q, k, v
- img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
- else:
- q = torch.cat((txt_q, img_q), dim=2)
- del txt_q, img_q
- k = torch.cat((txt_k, img_k), dim=2)
- del txt_k, img_k
- v = torch.cat((txt_v, img_v), dim=2)
- del txt_v, img_v
- # run actual attention
- attn = attention(q, k, v,
- pe=pe, mask=attn_mask, transformer_options=transformer_options)
- del q, k, v
+ if "attn1_output_patch" in transformer_patches:
+ extra_options["img_slice"] = [txt.shape[1], attn.shape[1]]
+ patch = transformer_patches["attn1_output_patch"]
+ for p in patch:
+ attn = p(attn, extra_options)
- txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
+ txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
# calculate the img bloks
img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
@@ -328,6 +312,9 @@ class SingleStreamBlock(nn.Module):
else:
mod = vec
+ transformer_patches = transformer_options.get("patches", {})
+ extra_options = transformer_options.copy()
+
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim_first], dim=-1)
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
@@ -337,6 +324,12 @@ class SingleStreamBlock(nn.Module):
# compute attention
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
+
+ if "attn1_output_patch" in transformer_patches:
+ patch = transformer_patches["attn1_output_patch"]
+ for p in patch:
+ attn = p(attn, extra_options)
+
# compute activation in mlp stream, cat again and run second linear layer
if self.yak_mlp:
mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2]
diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py
index f9597de5b..5e764bb46 100644
--- a/comfy/ldm/flux/math.py
+++ b/comfy/ldm/flux/math.py
@@ -29,19 +29,34 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
return out.to(dtype=torch.float32, device=pos.device)
+def _apply_rope1(x: Tensor, freqs_cis: Tensor):
+    """Pure-torch RoPE for one tensor: computed in ``freqs_cis``'s dtype, returned in ``x``'s dtype and shape."""
+    pairs = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
+    # out = freqs_cis[..., 0] * (first of each pair) + freqs_cis[..., 1] * (second), accumulated in place
+    rotated = freqs_cis[..., 0] * pairs[..., 0]
+    rotated.addcmul_(freqs_cis[..., 1], pairs[..., 1])
+    return rotated.reshape(*x.shape).type_as(x)
+
+# Pure-torch reference path: call _apply_rope1 directly instead of re-entering the public apply_rope1 dispatcher.
+def _apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
+    return _apply_rope1(xq, freqs_cis), _apply_rope1(xk, freqs_cis)
+
+
try:
import comfy.quant_ops
- apply_rope = comfy.quant_ops.ck.apply_rope
- apply_rope1 = comfy.quant_ops.ck.apply_rope1
+ q_apply_rope = comfy.quant_ops.ck.apply_rope
+ q_apply_rope1 = comfy.quant_ops.ck.apply_rope1
+ def apply_rope(xq, xk, freqs_cis):
+ if comfy.model_management.in_training:
+ return _apply_rope(xq, xk, freqs_cis)
+ else:
+ return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
+ def apply_rope1(x, freqs_cis):
+ if comfy.model_management.in_training:
+ return _apply_rope1(x, freqs_cis)
+ else:
+ return q_apply_rope1(x, freqs_cis)
except:
logging.warning("No comfy kitchen, using old apply_rope functions.")
- def apply_rope1(x: Tensor, freqs_cis: Tensor):
- x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
-
- x_out = freqs_cis[..., 0] * x_[..., 0]
- x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
-
- return x_out.reshape(*x.shape).type_as(x)
-
- def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
- return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
+ apply_rope = _apply_rope
+ apply_rope1 = _apply_rope1
diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py
index f40c2a7a9..ef4dcf7c5 100644
--- a/comfy/ldm/flux/model.py
+++ b/comfy/ldm/flux/model.py
@@ -16,7 +16,6 @@ from .layers import (
SingleStreamBlock,
timestep_embedding,
Modulation,
- RMSNorm
)
@dataclass
@@ -81,7 +80,7 @@ class Flux(nn.Module):
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
if params.txt_norm:
- self.txt_norm = RMSNorm(params.context_in_dim, dtype=dtype, device=device, operations=operations)
+ self.txt_norm = operations.RMSNorm(params.context_in_dim, dtype=dtype, device=device)
else:
self.txt_norm = None
@@ -143,6 +142,7 @@ class Flux(nn.Module):
attn_mask: Tensor = None,
) -> Tensor:
+ transformer_options = transformer_options.copy()
patches = transformer_options.get("patches", {})
patches_replace = transformer_options.get("patches_replace", {})
if img.ndim != 3 or txt.ndim != 3:
@@ -232,6 +232,7 @@ class Flux(nn.Module):
transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single"
+ transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
for i, block in enumerate(self.single_blocks):
transformer_options["block_index"] = i
if ("single_block", i) in blocks_replace:
diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py
index 55ab550f8..b94cdfa87 100644
--- a/comfy/ldm/hunyuan_video/model.py
+++ b/comfy/ldm/hunyuan_video/model.py
@@ -241,7 +241,6 @@ class HunyuanVideo(nn.Module):
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
- flipped_img_txt=True,
dtype=dtype, device=device, operations=operations
)
for _ in range(params.depth)
@@ -305,6 +304,7 @@ class HunyuanVideo(nn.Module):
control=None,
transformer_options={},
) -> Tensor:
+ transformer_options = transformer_options.copy()
patches_replace = transformer_options.get("patches_replace", {})
initial_shape = list(img.shape)
@@ -378,14 +378,14 @@ class HunyuanVideo(nn.Module):
extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
txt_ids = torch.cat((txt_ids, extra_txt_ids), dim=1)
- ids = torch.cat((img_ids, txt_ids), dim=1)
+ ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
img_len = img.shape[1]
if txt_mask is not None:
attn_mask_len = img_len + txt.shape[1]
attn_mask = torch.zeros((1, 1, attn_mask_len), dtype=img.dtype, device=img.device)
- attn_mask[:, 0, img_len:] = txt_mask
+ attn_mask[:, 0, :txt.shape[1]] = txt_mask
else:
attn_mask = None
@@ -413,10 +413,11 @@ class HunyuanVideo(nn.Module):
if add is not None:
img += add
- img = torch.cat((img, txt), 1)
+ img = torch.cat((txt, img), 1)
transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single"
+ transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
for i, block in enumerate(self.single_blocks):
transformer_options["block_index"] = i
if ("single_block", i) in blocks_replace:
@@ -435,9 +436,9 @@ class HunyuanVideo(nn.Module):
if i < len(control_o):
add = control_o[i]
if add is not None:
- img[:, : img_len] += add
+ img[:, txt.shape[1]: img_len + txt.shape[1]] += add
- img = img[:, : img_len]
+ img = img[:, txt.shape[1]: img_len + txt.shape[1]]
if ref_latent is not None:
img = img[:, ref_latent.shape[1]:]
diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py
index 5a22ef030..805592aa5 100644
--- a/comfy/ldm/modules/diffusionmodules/model.py
+++ b/comfy/ldm/modules/diffusionmodules/model.py
@@ -102,19 +102,7 @@ class VideoConv3d(nn.Module):
return self.conv(x)
def interpolate_up(x, scale_factor):
- try:
- return torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="nearest")
- except: #operation not implemented for bf16
- orig_shape = list(x.shape)
- out_shape = orig_shape[:2]
- for i in range(len(orig_shape) - 2):
- out_shape.append(round(orig_shape[i + 2] * scale_factor[i]))
- out = torch.empty(out_shape, dtype=x.dtype, layout=x.layout, device=x.device)
- split = 8
- l = out.shape[1] // split
- for i in range(0, out.shape[1], l):
- out[:,i:i+l] = torch.nn.functional.interpolate(x[:,i:i+l].to(torch.float32), scale_factor=scale_factor, mode="nearest").to(x.dtype)
- return out
+ return torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="nearest")
class Upsample(nn.Module):
def __init__(self, in_channels, with_conv, conv_op=ops.Conv2d, scale_factor=2.0):
diff --git a/comfy/ldm/qwen_image/controlnet.py b/comfy/ldm/qwen_image/controlnet.py
index a6d408104..c0aae9240 100644
--- a/comfy/ldm/qwen_image/controlnet.py
+++ b/comfy/ldm/qwen_image/controlnet.py
@@ -2,6 +2,196 @@ import torch
import math
from .model import QwenImageTransformer2DModel
+from .model import QwenImageTransformerBlock
+
+
+class QwenImageFunControlBlock(QwenImageTransformerBlock):
+ def __init__(self, dim, num_attention_heads, attention_head_dim, has_before_proj=False, dtype=None, device=None, operations=None):
+ super().__init__(
+ dim=dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ dtype=dtype,
+ device=device,
+ operations=operations,
+ )
+ self.has_before_proj = has_before_proj
+ if has_before_proj:
+ self.before_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
+ self.after_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
+
+
+class QwenImageFunControlNetModel(torch.nn.Module):
+ def __init__(
+ self,
+ control_in_features=132,
+ inner_dim=3072,
+ num_attention_heads=24,
+ attention_head_dim=128,
+ num_control_blocks=5,
+ main_model_double=60,
+ injection_layers=(0, 12, 24, 36, 48),
+ dtype=None,
+ device=None,
+ operations=None,
+ ):
+ super().__init__()
+ self.dtype = dtype
+ self.main_model_double = main_model_double
+ self.injection_layers = tuple(injection_layers)
+ # Keep base hint scaling at 1.0 so user-facing strength behaves similarly
+ # to the reference Gen2/VideoX implementation around strength=1.
+ self.hint_scale = 1.0
+ self.control_img_in = operations.Linear(control_in_features, inner_dim, device=device, dtype=dtype)
+
+ self.control_blocks = torch.nn.ModuleList([])
+ for i in range(num_control_blocks):
+ self.control_blocks.append(
+ QwenImageFunControlBlock(
+ dim=inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ has_before_proj=(i == 0),
+ dtype=dtype,
+ device=device,
+ operations=operations,
+ )
+ )
+
+ def _process_hint_tokens(self, hint):
+ if hint is None:
+ return None
+ if hint.ndim == 4:
+ hint = hint.unsqueeze(2)
+
+ # Fun checkpoints are trained with 33 latent channels before 2x2 packing:
+ # [control_latent(16), mask(1), inpaint_latent(16)] -> 132 features.
+ # Default behavior (no inpaint input in stock Apply ControlNet) should use
+ # zeros for mask/inpaint branches, matching VideoX fallback semantics.
+ expected_c = self.control_img_in.weight.shape[1] // 4
+ if hint.shape[1] == 16 and expected_c == 33:
+ zeros_mask = torch.zeros_like(hint[:, :1])
+ zeros_inpaint = torch.zeros_like(hint)
+ hint = torch.cat([hint, zeros_mask, zeros_inpaint], dim=1)
+
+ bs, c, t, h, w = hint.shape
+ hidden_states = torch.nn.functional.pad(hint, (0, w % 2, 0, h % 2))
+ orig_shape = hidden_states.shape
+ hidden_states = hidden_states.view(
+ orig_shape[0],
+ orig_shape[1],
+ orig_shape[-3],
+ orig_shape[-2] // 2,
+ 2,
+ orig_shape[-1] // 2,
+ 2,
+ )
+ hidden_states = hidden_states.permute(0, 2, 3, 5, 1, 4, 6)
+ hidden_states = hidden_states.reshape(
+ bs,
+ t * ((h + 1) // 2) * ((w + 1) // 2),
+ c * 4,
+ )
+
+ expected_in = self.control_img_in.weight.shape[1]
+ cur_in = hidden_states.shape[-1]
+ if cur_in < expected_in:
+ pad = torch.zeros(
+ (hidden_states.shape[0], hidden_states.shape[1], expected_in - cur_in),
+ device=hidden_states.device,
+ dtype=hidden_states.dtype,
+ )
+ hidden_states = torch.cat([hidden_states, pad], dim=-1)
+ elif cur_in > expected_in:
+ hidden_states = hidden_states[:, :, :expected_in]
+
+ return hidden_states
+
+ def forward(
+ self,
+ x,
+ timesteps,
+ context,
+ attention_mask=None,
+ guidance: torch.Tensor = None,
+ hint=None,
+ transformer_options={},
+ base_model=None,
+ **kwargs,
+ ):
+ if base_model is None:
+ raise RuntimeError("Qwen Fun ControlNet requires a QwenImage base model at runtime.")
+
+ encoder_hidden_states_mask = attention_mask
+ # Keep attention mask disabled inside Fun control blocks to mirror
+ # VideoX behavior (they rely on seq lengths for RoPE, not masked attention).
+ encoder_hidden_states_mask = None
+
+ hidden_states, img_ids, _ = base_model.process_img(x)
+ hint_tokens = self._process_hint_tokens(hint)
+ if hint_tokens is None:
+ raise RuntimeError("Qwen Fun ControlNet requires a control hint image.")
+
+ if hint_tokens.shape[1] != hidden_states.shape[1]:
+ max_tokens = min(hint_tokens.shape[1], hidden_states.shape[1])
+ hint_tokens = hint_tokens[:, :max_tokens]
+ hidden_states = hidden_states[:, :max_tokens]
+ img_ids = img_ids[:, :max_tokens]
+
+ txt_start = round(
+ max(
+ ((x.shape[-1] + (base_model.patch_size // 2)) // base_model.patch_size) // 2,
+ ((x.shape[-2] + (base_model.patch_size // 2)) // base_model.patch_size) // 2,
+ )
+ )
+ txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
+ ids = torch.cat((txt_ids, img_ids), dim=1)
+ image_rotary_emb = base_model.pe_embedder(ids).to(x.dtype).contiguous()
+
+ hidden_states = base_model.img_in(hidden_states)
+ encoder_hidden_states = base_model.txt_norm(context)
+ encoder_hidden_states = base_model.txt_in(encoder_hidden_states)
+
+ if guidance is not None:
+ guidance = guidance * 1000
+
+ temb = (
+ base_model.time_text_embed(timesteps, hidden_states)
+ if guidance is None
+ else base_model.time_text_embed(timesteps, guidance, hidden_states)
+ )
+
+ c = self.control_img_in(hint_tokens)
+
+ for i, block in enumerate(self.control_blocks):
+ if i == 0:
+ c_in = block.before_proj(c) + hidden_states
+ all_c = []
+ else:
+ all_c = list(torch.unbind(c, dim=0))
+ c_in = all_c.pop(-1)
+
+ encoder_hidden_states, c_out = block(
+ hidden_states=c_in,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_hidden_states_mask=encoder_hidden_states_mask,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ transformer_options=transformer_options,
+ )
+
+ c_skip = block.after_proj(c_out) * self.hint_scale
+ all_c += [c_skip, c_out]
+ c = torch.stack(all_c, dim=0)
+
+ hints = torch.unbind(c, dim=0)[:-1]
+
+ controlnet_block_samples = [None] * self.main_model_double
+ for local_idx, base_idx in enumerate(self.injection_layers):
+ if local_idx < len(hints) and base_idx < len(controlnet_block_samples):
+ controlnet_block_samples[base_idx] = hints[local_idx]
+
+ return {"input": controlnet_block_samples}
class QwenImageControlNetModel(QwenImageTransformer2DModel):
diff --git a/comfy/lora.py b/comfy/lora.py
index 44030bcab..279cf38bb 100644
--- a/comfy/lora.py
+++ b/comfy/lora.py
@@ -374,6 +374,31 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten
return padded_tensor
+def calculate_shape(patches, weight, key, original_weights=None):
+ current_shape = weight.shape
+
+ for p in patches:
+ v = p[1]
+ offset = p[3]
+
+ # Offsets restore the old shape; lists force a diff without metadata
+ if offset is not None or isinstance(v, list):
+ continue
+
+ if isinstance(v, weight_adapter.WeightAdapterBase):
+ adapter_shape = v.calculate_shape(key)
+ if adapter_shape is not None:
+ current_shape = adapter_shape
+ continue
+
+ # Standard diff logic with padding
+ if len(v) == 2:
+ patch_type, patch_data = v[0], v[1]
+ if patch_type == "diff" and len(patch_data) > 1 and patch_data[1]['pad_weight']:
+ current_shape = patch_data[0].shape
+
+ return current_shape
+
def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, original_weights=None):
for p in patches:
strength = p[0]
diff --git a/comfy/lora_convert.py b/comfy/lora_convert.py
index 9d8d21efe..749e81df3 100644
--- a/comfy/lora_convert.py
+++ b/comfy/lora_convert.py
@@ -5,7 +5,7 @@ import comfy.utils
def convert_lora_bfl_control(sd): #BFL loras for Flux
sd_out = {}
for k in sd:
- k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.scale.set_weight"))
+ k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.set_weight"))
sd_out[k_to] = sd[k]
sd_out["diffusion_model.img_in.reshape_weight"] = torch.tensor([sd["img_in.lora_B.weight"].shape[0], sd["img_in.lora_A.weight"].shape[1]])
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 3bb54f59e..9dcef8741 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -178,10 +178,7 @@ class BaseModel(torch.nn.Module):
xc = torch.cat([xc] + [comfy.model_management.cast_to_device(c_concat, xc.device, xc.dtype)], dim=1)
context = c_crossattn
- dtype = self.get_dtype()
-
- if self.manual_cast_dtype is not None:
- dtype = self.manual_cast_dtype
+ dtype = self.get_dtype_inference()
xc = xc.to(dtype)
device = xc.device
@@ -218,6 +215,13 @@ class BaseModel(torch.nn.Module):
def get_dtype(self):
return self.diffusion_model.dtype
+ def get_dtype_inference(self):
+ dtype = self.get_dtype()
+
+ if self.manual_cast_dtype is not None:
+ dtype = self.manual_cast_dtype
+ return dtype
+
def encode_adm(self, **kwargs):
return None
@@ -372,9 +376,7 @@ class BaseModel(torch.nn.Module):
input_shapes += shape
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
- dtype = self.get_dtype()
- if self.manual_cast_dtype is not None:
- dtype = self.manual_cast_dtype
+ dtype = self.get_dtype_inference()
#TODO: this needs to be tweaked
area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes))
return (area * comfy.model_management.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (1024 * 1024)
@@ -1160,12 +1162,16 @@ class Anima(BaseModel):
device = kwargs["device"]
if cross_attn is not None:
if t5xxl_ids is not None:
- cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.unsqueeze(0).to(device=device))
if t5xxl_weights is not None:
- cross_attn *= t5xxl_weights.unsqueeze(0).unsqueeze(-1).to(cross_attn)
+ t5xxl_weights = t5xxl_weights.unsqueeze(0).unsqueeze(-1).to(cross_attn)
+ t5xxl_ids = t5xxl_ids.unsqueeze(0)
+
+ if torch.is_inference_mode_enabled(): # otherwise (not in inference mode) we are training
+ cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype_inference()))
+ else:
+ out['t5xxl_ids'] = comfy.conds.CONDRegular(t5xxl_ids)
+ out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights)
- if cross_attn.shape[1] < 512:
- cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, 0, 512 - cross_attn.shape[1]))
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out
@@ -1552,6 +1558,8 @@ class ACEStep15(BaseModel):
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
+ if torch.count_nonzero(cross_attn) == 0:
+ out['replace_with_null_embeds'] = comfy.conds.CONDConstant(True)
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
conditioning_lyrics = kwargs.get("conditioning_lyrics", None)
@@ -1575,6 +1583,10 @@ class ACEStep15(BaseModel):
else:
out['is_covers'] = comfy.conds.CONDConstant(False)
+ if refer_audio.shape[2] < noise.shape[2]:
+ pad = comfy.ldm.ace.ace_step15.get_silence_latent(noise.shape[2], device)
+ refer_audio = torch.cat([refer_audio.to(pad), pad[:, :, refer_audio.shape[2]:]], dim=2)
+
out['refer_audio'] = comfy.conds.CONDRegular(refer_audio)
return out
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index e8ad725df..30ea03e8e 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -19,6 +19,12 @@ def count_blocks(state_dict_keys, prefix_string):
count += 1
return count
+def any_suffix_in(keys, prefix, main, suffix_list=[]):
+ for x in suffix_list:
+ if "{}{}{}".format(prefix, main, x) in keys:
+ return True
+ return False
+
def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
context_dim = None
use_linear_in_transformer = False
@@ -186,7 +192,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["meanflow_sum"] = False
return dit_config
- if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
+ if any_suffix_in(state_dict_keys, key_prefix, 'double_blocks.0.img_attn.norm.key_norm.', ["weight", "scale"]) and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.norms.0.', ["weight", "scale"])): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
dit_config = {}
if '{}double_stream_modulation_img.lin.weight'.format(key_prefix) in state_dict_keys:
dit_config["image_model"] = "flux2"
@@ -241,7 +247,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
- if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma
+
+ if any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.0.norms.0.', ["weight", "scale"]) or any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.norms.0.', ["weight", "scale"]): #Chroma
dit_config["image_model"] = "chroma"
dit_config["in_channels"] = 64
dit_config["out_channels"] = 64
@@ -249,7 +256,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["out_dim"] = 3072
dit_config["hidden_dim"] = 5120
dit_config["n_layers"] = 5
- if f"{key_prefix}nerf_blocks.0.norm.scale" in state_dict_keys: #Chroma Radiance
+
+ if any_suffix_in(state_dict_keys, key_prefix, 'nerf_blocks.0.norm.', ["weight", "scale"]): #Chroma Radiance
dit_config["image_model"] = "chroma_radiance"
dit_config["in_channels"] = 3
dit_config["out_channels"] = 3
@@ -259,7 +267,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["nerf_depth"] = 4
dit_config["nerf_max_freqs"] = 8
dit_config["nerf_tile_size"] = 512
- dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
+ dit_config["nerf_final_head_type"] = "conv" if any_suffix_in(state_dict_keys, key_prefix, 'nerf_final_layer_conv.norm.', ["weight", "scale"]) else "linear"
dit_config["nerf_embedder_dtype"] = torch.float32
if "{}__x0__".format(key_prefix) in state_dict_keys: # x0 pred
dit_config["use_x0"] = True
@@ -268,7 +276,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
else:
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
- dit_config["txt_norm"] = "{}txt_norm.scale".format(key_prefix) in state_dict_keys
+ dit_config["txt_norm"] = any_suffix_in(state_dict_keys, key_prefix, 'txt_norm.', ["weight", "scale"])
if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model
dit_config["txt_ids_dims"] = [1, 2]
diff --git a/comfy/model_management.py b/comfy/model_management.py
index b6291f340..38c3e482b 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -19,7 +19,7 @@
import psutil
import logging
from enum import Enum
-from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
+from comfy.cli_args import args, PerformanceFeature
import threading
import torch
import sys
@@ -55,6 +55,11 @@ cpu_state = CPUState.GPU
total_vram = 0
+
+# Training Related State
+in_training = False
+
+
def get_supported_float8_types():
float8_types = []
try:
@@ -651,7 +656,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
soft_empty_cache()
return unloaded_models
-def load_models_gpu_orig(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
+def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
cleanup_models_gc()
global vram_state
@@ -747,26 +752,6 @@ def load_models_gpu_orig(models, memory_required=0, force_patch_weights=False, m
current_loaded_models.insert(0, loaded_model)
return
-def load_models_gpu_thread(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load):
- with torch.inference_mode():
- load_models_gpu_orig(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
- soft_empty_cache()
-
-def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
- #Deliberately load models outside of the Aimdo mempool so they can be retained accross
- #nodes. Use a dummy thread to do it as pytorch documents that mempool contexts are
- #thread local. So exploit that to escape context
- if enables_dynamic_vram():
- t = threading.Thread(
- target=load_models_gpu_thread,
- args=(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
- )
- t.start()
- t.join()
- else:
- load_models_gpu_orig(models, memory_required=memory_required, force_patch_weights=force_patch_weights,
- minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
-
def load_model_gpu(model):
return load_models_gpu([model])
@@ -1226,21 +1211,20 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
if dtype is None:
dtype = weight._model_dtype
- r = torch.empty_like(weight, dtype=dtype, device=device)
-
signature = comfy_aimdo.model_vbar.vbar_fault(weight._v)
if signature is not None:
- raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
- v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0]
- if not comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
+ if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
+ v_tensor = weight._v_tensor
+ else:
+ raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
+ v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0]
+ weight._v_tensor = v_tensor
weight._v_signature = signature
#Send it over
v_tensor.copy_(weight, non_blocking=non_blocking)
- #always take a deep copy even if _v is good, as we have no reasonable point to unpin
- #a non comfy weight
- r.copy_(v_tensor)
- comfy_aimdo.model_vbar.vbar_unpin(weight._v)
- return r
+ return v_tensor.to(dtype=dtype)
+
+ r = torch.empty_like(weight, dtype=dtype, device=device)
if weight.dtype != r.dtype and weight.dtype != weight._model_dtype:
#Offloaded casting could skip this, however it would make the quantizations
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index d888dbcfb..21b4ce53e 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -19,7 +19,6 @@
from __future__ import annotations
import collections
-import copy
import inspect
import logging
import math
@@ -317,7 +316,7 @@ class ModelPatcher:
n.object_patches = self.object_patches.copy()
n.weight_wrapper_patches = self.weight_wrapper_patches.copy()
- n.model_options = copy.deepcopy(self.model_options)
+ n.model_options = comfy.utils.deepcopy_list_dict(self.model_options)
n.backup = self.backup
n.object_patches_backup = self.object_patches_backup
n.parent = self
@@ -407,13 +406,16 @@ class ModelPatcher:
def memory_required(self, input_shape):
return self.model.memory_required(input_shape=input_shape)
+ def disable_model_cfg1_optimization(self):
+ self.model_options["disable_cfg1_optimization"] = True
+
def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
else:
self.model_options["sampler_cfg_function"] = sampler_cfg_function
if disable_cfg1_optimization:
- self.model_options["disable_cfg1_optimization"] = True
+ self.disable_model_cfg1_optimization()
def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
self.model_options = set_model_options_post_cfg_function(self.model_options, post_cfg_function, disable_cfg1_optimization)
@@ -680,18 +682,19 @@ class ModelPatcher:
for key in list(self.pinned):
self.unpin_weight(key)
- def _load_list(self, prio_comfy_cast_weights=False):
+ def _load_list(self, prio_comfy_cast_weights=False, default_device=None):
loading = []
for n, m in self.model.named_modules():
- params = []
- skip = False
- for name, param in m.named_parameters(recurse=False):
- params.append(name)
+ default = False
+ params = { name: param for name, param in m.named_parameters(recurse=False) }
for name, param in m.named_parameters(recurse=True):
if name not in params:
- skip = True # skip random weights in non leaf modules
+ default = True # default random weights in non leaf modules
break
- if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
+ if default and default_device is not None:
+ for param in params.values():
+ param.data = param.data.to(device=default_device)
+ if not default and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
module_mem = comfy.model_management.module_size(m)
module_offload_mem = module_mem
if hasattr(m, "comfy_cast_weights"):
@@ -1492,9 +1495,11 @@ class ModelPatcherDynamic(ModelPatcher):
if vbar is not None:
vbar.prioritize()
- #We have way more tools for acceleration on comfy weight offloading, so always
+ #We force reserve VRAM for the non comfy-weight so we don't have to deal
+ #with pin and unpin synchronization which can be expensive for small weights
+ #with a high layer rate (e.g. autoregressive LLMs).
#prioritize the non-comfy weights (note the order reverse).
- loading = self._load_list(prio_comfy_cast_weights=True)
+ loading = self._load_list(prio_comfy_cast_weights=True, default_device=device_to)
loading.sort(reverse=True)
for x in loading:
@@ -1512,8 +1517,10 @@ class ModelPatcherDynamic(ModelPatcher):
weight, _, _ = get_key_weight(self.model, key)
if weight is None:
- return 0
+ return (False, 0)
if key in self.patches:
+ if comfy.lora.calculate_shape(self.patches[key], weight, key) != weight.shape:
+ return (True, 0)
setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches))
num_patches += 1
else:
@@ -1524,10 +1531,16 @@ class ModelPatcherDynamic(ModelPatcher):
setattr(m, param_key + "_function", weight_function)
geometry = weight
if not isinstance(weight, QuantizedTensor):
- model_dtype = getattr(m, param_key + "_comfy_model_dtype", weight.dtype)
+ model_dtype = getattr(m, param_key + "_comfy_model_dtype", None) or weight.dtype
weight._model_dtype = model_dtype
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
- return comfy.memory_management.vram_aligned_size(geometry)
+ return (False, comfy.memory_management.vram_aligned_size(geometry))
+
+ def force_load_param(self, param_key, device_to):
+ key = key_param_name_to_key(n, param_key)
+ if key in self.backup:
+ comfy.utils.set_attr_param(self.model, key, self.backup[key].weight)
+ self.patch_weight_to_device(key, device_to=device_to)
if hasattr(m, "comfy_cast_weights"):
m.comfy_cast_weights = True
@@ -1535,13 +1548,19 @@ class ModelPatcherDynamic(ModelPatcher):
m.seed_key = n
set_dirty(m, dirty)
- v_weight_size = 0
- v_weight_size += setup_param(self, m, n, "weight")
- v_weight_size += setup_param(self, m, n, "bias")
+ force_load, v_weight_size = setup_param(self, m, n, "weight")
+ force_load_bias, v_weight_bias = setup_param(self, m, n, "bias")
+ force_load = force_load or force_load_bias
+ v_weight_size += v_weight_bias
- if vbar is not None and not hasattr(m, "_v"):
- m._v = vbar.alloc(v_weight_size)
- allocated_size += v_weight_size
+ if force_load:
+ logging.info(f"Module {n} has resizing Lora - force loading")
+ force_load_param(self, "weight", device_to)
+ force_load_param(self, "bias", device_to)
+ else:
+ if vbar is not None and not hasattr(m, "_v"):
+ m._v = vbar.alloc(v_weight_size)
+ allocated_size += v_weight_size
else:
for param in params:
@@ -1550,13 +1569,16 @@ class ModelPatcherDynamic(ModelPatcher):
weight.seed_key = key
set_dirty(weight, dirty)
geometry = weight
- model_dtype = getattr(m, param + "_comfy_model_dtype", weight.dtype)
+ model_dtype = getattr(m, param + "_comfy_model_dtype", None) or weight.dtype
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
weight_size = geometry.numel() * geometry.element_size()
if vbar is not None and not hasattr(weight, "_v"):
weight._v = vbar.alloc(weight_size)
weight._model_dtype = model_dtype
allocated_size += weight_size
+ vbar.set_watermark_limit(allocated_size)
+
+ move_weight_functions(m, device_to)
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.")
@@ -1577,7 +1599,7 @@ class ModelPatcherDynamic(ModelPatcher):
return 0 if vbar is None else vbar.free_memory(memory_to_free)
def partially_unload_ram(self, ram_to_unload):
- loading = self._load_list(prio_comfy_cast_weights=True)
+ loading = self._load_list(prio_comfy_cast_weights=True, default_device=self.offload_device)
for x in loading:
_, _, _, _, m, _ = x
ram_to_unload -= comfy.pinned_memory.unpin_memory(m)
@@ -1598,6 +1620,13 @@ class ModelPatcherDynamic(ModelPatcher):
if unpatch_weights:
self.partially_unload_ram(1e32)
self.partially_unload(None, 1e32)
+ for m in self.model.modules():
+ move_weight_functions(m, device_to)
+
+ keys = list(self.backup.keys())
+ for k in keys:
+ bk = self.backup[k]
+ comfy.utils.set_attr_param(self.model, k, bk.weight)
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
assert not force_patch_weights #See above
diff --git a/comfy/ops.py b/comfy/ops.py
index 0f4eca7c7..a6c642795 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -21,7 +21,6 @@ import logging
import comfy.model_management
from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
import comfy.float
-import comfy.rmsnorm
import json
import comfy.memory_management
import comfy.pinned_memory
@@ -80,17 +79,21 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
-def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype):
+def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
offload_stream = None
xfer_dest = None
- cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
- if signature is not None:
- xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
+ if signature is not None:
+ if resident:
+ weight = s._v_weight
+ bias = s._v_bias
+ else:
+ xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
if not resident:
+ cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
cast_dest = None
xfer_source = [ s.weight, s.bias ]
@@ -140,9 +143,13 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
post_cast.copy_(pre_cast)
xfer_dest = cast_dest
- params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
- weight = params[0]
- bias = params[1]
+ params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
+ weight = params[0]
+ bias = params[1]
+ if signature is not None:
+ s._v_weight = weight
+ s._v_bias = bias
+ s._v_signature=signature
def post_cast(s, param_key, x, dtype, resident, update_weight):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
@@ -163,14 +170,14 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
#FIXME: this is not accurate, we need to be sensitive to the compute dtype
x = lowvram_fn(x)
if (isinstance(orig, QuantizedTensor) and
- (orig.dtype == dtype and len(fns) == 0 or update_weight)):
+ (want_requant and len(fns) == 0 or update_weight)):
seed = comfy.utils.string_to_seed(s.seed_key)
y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
- if orig.dtype == dtype and len(fns) == 0:
+ if want_requant and len(fns) == 0:
#The layer actually wants our freshly saved QT
x = y
- else:
- y = x
+ elif update_weight:
+ y = comfy.float.stochastic_rounding(x, orig.dtype, seed = comfy.utils.string_to_seed(s.seed_key))
if update_weight:
orig.copy_(y)
for f in fns:
@@ -182,13 +189,12 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
weight = post_cast(s, "weight", weight, dtype, resident, update_weight)
if s.bias is not None:
bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight)
- s._v_signature=signature
#FIXME: weird offload return protocol
return weight, bias, (offload_stream, device if signature is not None else None, None)
-def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None):
+def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None, want_requant=False):
# NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass
# offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
# will add async-offload support to your cast and improve performance.
@@ -206,7 +212,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
non_blocking = comfy.model_management.device_supports_non_blocking(device)
if hasattr(s, "_v"):
- return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype)
+ return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant)
if offloadable and (device != s.weight.device or
(s.bias is not None and device != s.bias.device)):
@@ -456,7 +462,7 @@ class disable_weight_init:
else:
return super().forward(*args, **kwargs)
- class RMSNorm(comfy.rmsnorm.RMSNorm, CastWeightBiasOp):
+ class RMSNorm(torch.nn.RMSNorm, CastWeightBiasOp):
def reset_parameters(self):
self.bias = None
return None
@@ -468,8 +474,7 @@ class disable_weight_init:
weight = None
bias = None
offload_stream = None
- x = comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
- # x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
+ x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
@@ -845,8 +850,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def _forward(self, input, weight, bias):
return torch.nn.functional.linear(input, weight, bias)
- def forward_comfy_cast_weights(self, input, compute_dtype=None):
- weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype)
+ def forward_comfy_cast_weights(self, input, compute_dtype=None, want_requant=False):
+ weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype, want_requant=want_requant)
x = self._forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
@@ -876,8 +881,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
scale = comfy.model_management.cast_to_device(scale, input.device, None)
input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)
-
- output = self.forward_comfy_cast_weights(input, compute_dtype)
+ output = self.forward_comfy_cast_weights(input, compute_dtype, want_requant=isinstance(input, QuantizedTensor))
# Reshape output back to 3D if input was 3D
if reshaped_3d:
diff --git a/comfy/rmsnorm.py b/comfy/rmsnorm.py
index 555542a46..ab7cf14fa 100644
--- a/comfy/rmsnorm.py
+++ b/comfy/rmsnorm.py
@@ -1,57 +1,10 @@
import torch
import comfy.model_management
-import numbers
-import logging
-
-RMSNorm = None
-
-try:
- rms_norm_torch = torch.nn.functional.rms_norm
- RMSNorm = torch.nn.RMSNorm
-except:
- rms_norm_torch = None
- logging.warning("Please update pytorch to use native RMSNorm")
+RMSNorm = torch.nn.RMSNorm
def rms_norm(x, weight=None, eps=1e-6):
- if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
- if weight is None:
- return rms_norm_torch(x, (x.shape[-1],), eps=eps)
- else:
- return rms_norm_torch(x, weight.shape, weight=comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
+ if weight is None:
+ return torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=eps)
else:
- r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
- if weight is None:
- return r
- else:
- return r * comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device)
-
-
-if RMSNorm is None:
- class RMSNorm(torch.nn.Module):
- def __init__(
- self,
- normalized_shape,
- eps=1e-6,
- elementwise_affine=True,
- device=None,
- dtype=None,
- ):
- factory_kwargs = {"device": device, "dtype": dtype}
- super().__init__()
- if isinstance(normalized_shape, numbers.Integral):
- # mypy error: incompatible types in assignment
- normalized_shape = (normalized_shape,) # type: ignore[assignment]
- self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
- self.eps = eps
- self.elementwise_affine = elementwise_affine
- if self.elementwise_affine:
- self.weight = torch.nn.Parameter(
- torch.empty(self.normalized_shape, **factory_kwargs)
- )
- else:
- self.register_parameter("weight", None)
- self.bias = None
-
- def forward(self, x):
- return rms_norm(x, self.weight, self.eps)
+ return torch.nn.functional.rms_norm(x, weight.shape, weight=comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py
index 9134e6d71..1f75f2ba7 100644
--- a/comfy/sampler_helpers.py
+++ b/comfy/sampler_helpers.py
@@ -122,20 +122,26 @@ def estimate_memory(model, noise_shape, conds):
minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min)
return memory_required, minimum_memory_required
-def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
+def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
_prepare_sampling,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True)
)
- return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load)
+ return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load, force_offload=force_offload)
-def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
+def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
real_model: BaseModel = None
models, inference_memory = get_additional_models(conds, model.model_dtype())
models += get_additional_models_from_model_options(model_options)
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
- memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
- comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory, force_full_load=force_full_load)
+ if force_offload: # In training + offload enabled, we want to force prepare sampling to trigger partial load
+ memory_required = 1e20
+ minimum_memory_required = None
+ else:
+ memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
+ memory_required += inference_memory
+ minimum_memory_required += inference_memory
+ comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
real_model = model.model
return real_model, conds, models
diff --git a/comfy/sd.py b/comfy/sd.py
index bc9407405..164f30803 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -423,6 +423,19 @@ class CLIP:
def get_key_patches(self):
return self.patcher.get_key_patches()
+ def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
+ self.cond_stage_model.reset_clip_options()
+
+ if self.layer_idx is not None:
+ self.cond_stage_model.set_clip_options({"layer": self.layer_idx})
+
+ self.load_model()
+ self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
+ return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)
+
+ def decode(self, token_ids, skip_special_tokens=True):
+ return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
+
class VAE:
def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
@@ -793,8 +806,6 @@ class VAE:
self.first_stage_model = AutoencoderKL(**(config['params']))
self.first_stage_model = self.first_stage_model.eval()
- model_management.archive_model_dtypes(self.first_stage_model)
-
if device is None:
device = model_management.vae_device()
self.device = device
@@ -803,6 +814,7 @@ class VAE:
dtype = model_management.vae_dtype(self.device, self.working_dtypes)
self.vae_dtype = dtype
self.first_stage_model.to(self.vae_dtype)
+ model_management.archive_model_dtypes(self.first_stage_model)
self.output_device = model_management.intermediate_device()
mp = comfy.model_patcher.CoreModelPatcher
@@ -1183,6 +1195,7 @@ class TEModel(Enum):
JINA_CLIP_2 = 19
QWEN3_8B = 20
QWEN3_06B = 21
+ GEMMA_3_4B_VISION = 22
def detect_te_model(sd):
@@ -1211,7 +1224,10 @@ def detect_te_model(sd):
if 'model.layers.47.self_attn.q_norm.weight' in sd:
return TEModel.GEMMA_3_12B
if 'model.layers.0.self_attn.q_norm.weight' in sd:
- return TEModel.GEMMA_3_4B
+ if 'vision_model.embeddings.patch_embedding.weight' in sd:
+ return TEModel.GEMMA_3_4B_VISION
+ else:
+ return TEModel.GEMMA_3_4B
return TEModel.GEMMA_2_2B
if 'model.layers.0.self_attn.k_proj.bias' in sd:
weight = sd['model.layers.0.self_attn.k_proj.bias']
@@ -1271,6 +1287,8 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
else:
if "text_projection" in clip_data[i]:
clip_data[i]["text_projection.weight"] = clip_data[i]["text_projection"].transpose(0, 1) #old models saved with the CLIPSave node
+ if "lm_head.weight" in clip_data[i]:
+ clip_data[i]["model.lm_head.weight"] = clip_data[i].pop("lm_head.weight") # prefix missing in some models
tokenizer_data = {}
clip_target = EmptyClass()
@@ -1336,6 +1354,14 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b")
clip_target.tokenizer = comfy.text_encoders.lumina2.NTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
+ elif te_model == TEModel.GEMMA_3_4B_VISION:
+ clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b_vision")
+ clip_target.tokenizer = comfy.text_encoders.lumina2.NTokenizer
+ tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
+ elif te_model == TEModel.GEMMA_3_12B:
+ clip_target.clip = comfy.text_encoders.lt.gemma3_te(**llama_detect(clip_data))
+ clip_target.tokenizer = comfy.text_encoders.lt.Gemma3_12BTokenizer
+ tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif te_model == TEModel.LLAMA3_8:
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None)
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index 4c817d468..d9d014055 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -171,8 +171,9 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def process_tokens(self, tokens, device):
end_token = self.special_tokens.get("end", None)
+ pad_token = self.special_tokens.get("pad", -1)
if end_token is None:
- cmp_token = self.special_tokens.get("pad", -1)
+ cmp_token = pad_token
else:
cmp_token = end_token
@@ -186,15 +187,21 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
other_embeds = []
eos = False
index = 0
+ left_pad = False
for y in x:
if isinstance(y, numbers.Integral):
- if eos:
+ token = int(y)
+ if index == 0 and token == pad_token:
+ left_pad = True
+
+ if eos or (left_pad and token == pad_token):
attention_mask.append(0)
else:
attention_mask.append(1)
- token = int(y)
+ left_pad = False
+
tokens_temp += [token]
- if not eos and token == cmp_token:
+ if not eos and token == cmp_token and not left_pad:
if end_token is None:
attention_mask[-1] = 0
eos = True
@@ -301,6 +308,15 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def load_sd(self, sd):
return self.transformer.load_state_dict(sd, strict=False, assign=getattr(self, "can_assign_sd", False))
+ def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[]):
+ if isinstance(tokens, dict):
+ tokens_only = next(iter(tokens.values())) # todo: get this better?
+ else:
+ tokens_only = tokens
+ tokens_only = [[t[0] for t in b] for b in tokens_only]
+ embeds = self.process_tokens(tokens_only, device=self.execution_device)[0]
+ return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens)
+
def parse_parentheses(string):
result = []
current_item = ""
@@ -656,6 +672,9 @@ class SDTokenizer:
def state_dict(self):
return {}
+ def decode(self, token_ids, skip_special_tokens=True):
+ return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
+
class SD1Tokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer, name=None):
if name is not None:
@@ -679,6 +698,9 @@ class SD1Tokenizer:
def state_dict(self):
return getattr(self, self.clip).state_dict()
+ def decode(self, token_ids, skip_special_tokens=True):
+ return getattr(self, self.clip).decode(token_ids, skip_special_tokens=skip_special_tokens)
+
class SD1CheckpointClipModel(SDClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, return_projected_pooled=False, dtype=dtype, model_options=model_options)
@@ -715,3 +737,6 @@ class SD1ClipModel(torch.nn.Module):
def load_sd(self, sd):
return getattr(self, self.clip).load_sd(sd)
+
+ def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
+ return getattr(self, self.clip).generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index 77264ed28..c28be1716 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -710,6 +710,15 @@ class Flux(supported_models_base.BASE):
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
+ def process_unet_state_dict(self, state_dict):
+ out_sd = {}
+ for k in list(state_dict.keys()):
+ key_out = k
+ if key_out.endswith("_norm.scale"):
+ key_out = "{}.weight".format(key_out[:-len(".scale")])
+ out_sd[key_out] = state_dict[k]
+ return out_sd
+
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
@@ -898,11 +907,13 @@ class HunyuanVideo(supported_models_base.BASE):
key_out = key_out.replace("txt_in.c_embedder.linear_1.", "txt_in.c_embedder.in_layer.").replace("txt_in.c_embedder.linear_2.", "txt_in.c_embedder.out_layer.")
key_out = key_out.replace("_mod.linear.", "_mod.lin.").replace("_attn_qkv.", "_attn.qkv.")
key_out = key_out.replace("mlp.fc1.", "mlp.0.").replace("mlp.fc2.", "mlp.2.")
- key_out = key_out.replace("_attn_q_norm.weight", "_attn.norm.query_norm.scale").replace("_attn_k_norm.weight", "_attn.norm.key_norm.scale")
- key_out = key_out.replace(".q_norm.weight", ".norm.query_norm.scale").replace(".k_norm.weight", ".norm.key_norm.scale")
+ key_out = key_out.replace("_attn_q_norm.weight", "_attn.norm.query_norm.weight").replace("_attn_k_norm.weight", "_attn.norm.key_norm.weight")
+ key_out = key_out.replace(".q_norm.weight", ".norm.query_norm.weight").replace(".k_norm.weight", ".norm.key_norm.weight")
key_out = key_out.replace("_attn_proj.", "_attn.proj.")
key_out = key_out.replace(".modulation.linear.", ".modulation.lin.")
key_out = key_out.replace("_in.mlp.2.", "_in.out_layer.").replace("_in.mlp.0.", "_in.in_layer.")
+ if key_out.endswith(".scale"):
+ key_out = "{}.weight".format(key_out[:-len(".scale")])
out_sd[key_out] = state_dict[k]
return out_sd
@@ -993,7 +1004,7 @@ class CosmosT2IPredict2(supported_models_base.BASE):
memory_usage_factor = 1.0
- supported_inference_dtypes = [torch.bfloat16, torch.float32]
+ supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
def __init__(self, unet_config):
super().__init__(unet_config)
@@ -1023,11 +1034,7 @@ class Anima(supported_models_base.BASE):
memory_usage_factor = 1.0
- supported_inference_dtypes = [torch.bfloat16, torch.float32]
-
- def __init__(self, unet_config):
- super().__init__(unet_config)
- self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.95
+ supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Anima(self, device=device)
@@ -1038,6 +1045,12 @@ class Anima(supported_models_base.BASE):
detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_06b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.anima.AnimaTokenizer, comfy.text_encoders.anima.te(**detect))
+ def set_inference_dtype(self, dtype, manual_cast_dtype, **kwargs):
+ self.memory_usage_factor = (self.unet_config.get("model_channels", 2048) / 2048) * 0.95
+ if dtype is torch.float16:
+ self.memory_usage_factor *= 1.4
+ return super().set_inference_dtype(dtype, manual_cast_dtype, **kwargs)
+
class CosmosI2VPredict2(CosmosT2IPredict2):
unet_config = {
"image_model": "cosmos_predict2",
@@ -1262,6 +1275,15 @@ class Hunyuan3Dv2(supported_models_base.BASE):
latent_format = latent_formats.Hunyuan3Dv2
+ def process_unet_state_dict(self, state_dict):
+ out_sd = {}
+ for k in list(state_dict.keys()):
+ key_out = k
+ if key_out.endswith(".scale"):
+ key_out = "{}.weight".format(key_out[:-len(".scale")])
+ out_sd[key_out] = state_dict[k]
+ return out_sd
+
def process_unet_state_dict_for_saving(self, state_dict):
replace_prefix = {"": "model."}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
@@ -1339,6 +1361,14 @@ class Chroma(supported_models_base.BASE):
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
+ def process_unet_state_dict(self, state_dict):
+ out_sd = {}
+ for k in list(state_dict.keys()):
+ key_out = k
+ if key_out.endswith(".scale"):
+ key_out = "{}.weight".format(key_out[:-len(".scale")])
+ out_sd[key_out] = state_dict[k]
+ return out_sd
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Chroma(self, device=device)
diff --git a/comfy/text_encoders/ace15.py b/comfy/text_encoders/ace15.py
index 00dd5ba90..f135d74c1 100644
--- a/comfy/text_encoders/ace15.py
+++ b/comfy/text_encoders/ace15.py
@@ -10,12 +10,12 @@ import comfy.utils
def sample_manual_loop_no_classes(
model,
ids=None,
- paddings=[],
execution_dtype=None,
cfg_scale: float = 2.0,
temperature: float = 0.85,
top_p: float = 0.9,
top_k: int = None,
+ min_p: float = 0.000,
seed: int = 1,
min_tokens: int = 1,
max_new_tokens: int = 2048,
@@ -23,6 +23,8 @@ def sample_manual_loop_no_classes(
audio_end_id: int = 215669,
eos_token_id: int = 151645,
):
+ if ids is None:
+ return []
device = model.execution_device
if execution_dtype is None:
@@ -32,31 +34,34 @@ def sample_manual_loop_no_classes(
execution_dtype = torch.float32
embeds, attention_mask, num_tokens, embeds_info = model.process_tokens(ids, device)
- for i, t in enumerate(paddings):
- attention_mask[i, :t] = 0
- attention_mask[i, t:] = 1
+ embeds_batch = embeds.shape[0]
output_audio_codes = []
past_key_values = []
generator = torch.Generator(device=device)
generator.manual_seed(seed)
model_config = model.transformer.model.config
+ past_kv_shape = [embeds_batch, model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim]
for x in range(model_config.num_hidden_layers):
- past_key_values.append((torch.empty([embeds.shape[0], model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim], device=device, dtype=execution_dtype), torch.empty([embeds.shape[0], model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim], device=device, dtype=execution_dtype), 0))
+ past_key_values.append((torch.empty(past_kv_shape, device=device, dtype=execution_dtype), torch.empty(past_kv_shape, device=device, dtype=execution_dtype), 0))
progress_bar = comfy.utils.ProgressBar(max_new_tokens)
- for step in range(max_new_tokens):
+ for step in comfy.utils.model_trange(max_new_tokens, desc="LM sampling"):
outputs = model.transformer(None, attention_mask, embeds=embeds.to(execution_dtype), num_tokens=num_tokens, intermediate_output=None, dtype=execution_dtype, embeds_info=embeds_info, past_key_values=past_key_values)
next_token_logits = model.transformer.logits(outputs[0])[:, -1]
past_key_values = outputs[2]
- cond_logits = next_token_logits[0:1]
- uncond_logits = next_token_logits[1:2]
- cfg_logits = uncond_logits + cfg_scale * (cond_logits - uncond_logits)
+ if cfg_scale != 1.0:
+ cond_logits = next_token_logits[0:1]
+ uncond_logits = next_token_logits[1:2]
+ cfg_logits = uncond_logits + cfg_scale * (cond_logits - uncond_logits)
+ else:
+ cfg_logits = next_token_logits[0:1]
- if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step:
+ use_eos_score = eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step
+ if use_eos_score:
eos_score = cfg_logits[:, eos_token_id].clone()
remove_logit_value = torch.finfo(cfg_logits.dtype).min
@@ -64,7 +69,7 @@ def sample_manual_loop_no_classes(
cfg_logits[:, :audio_start_id] = remove_logit_value
cfg_logits[:, audio_end_id:] = remove_logit_value
- if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step:
+ if use_eos_score:
cfg_logits[:, eos_token_id] = eos_score
if top_k is not None and top_k > 0:
@@ -72,6 +77,12 @@ def sample_manual_loop_no_classes(
min_val = top_k_vals[..., -1, None]
cfg_logits[cfg_logits < min_val] = remove_logit_value
+ if min_p is not None and min_p > 0:
+ probs = torch.softmax(cfg_logits, dim=-1)
+ p_max = probs.max(dim=-1, keepdim=True).values
+ indices_to_remove = probs < (min_p * p_max)
+ cfg_logits[indices_to_remove] = remove_logit_value
+
if top_p is not None and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(cfg_logits, descending=True)
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
@@ -93,8 +104,8 @@ def sample_manual_loop_no_classes(
break
embed, _, _, _ = model.process_tokens([[token]], device)
- embeds = embed.repeat(2, 1, 1)
- attention_mask = torch.cat([attention_mask, torch.ones((2, 1), device=device, dtype=attention_mask.dtype)], dim=1)
+ embeds = embed.repeat(embeds_batch, 1, 1)
+ attention_mask = torch.cat([attention_mask, torch.ones((embeds_batch, 1), device=device, dtype=attention_mask.dtype)], dim=1)
output_audio_codes.append(token - audio_start_id)
progress_bar.update_absolute(step)
@@ -102,24 +113,29 @@ def sample_manual_loop_no_classes(
return output_audio_codes
-def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0, cfg_scale=2.0, temperature=0.85, top_p=0.9, top_k=0):
+def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0, cfg_scale=2.0, temperature=0.85, top_p=0.9, top_k=0, min_p=0.000):
positive = [[token for token, _ in inner_list] for inner_list in positive]
- negative = [[token for token, _ in inner_list] for inner_list in negative]
positive = positive[0]
- negative = negative[0]
- neg_pad = 0
- if len(negative) < len(positive):
- neg_pad = (len(positive) - len(negative))
- negative = [model.special_tokens["pad"]] * neg_pad + negative
+ if cfg_scale != 1.0:
+ negative = [[token for token, _ in inner_list] for inner_list in negative]
+ negative = negative[0]
- pos_pad = 0
- if len(negative) > len(positive):
- pos_pad = (len(negative) - len(positive))
- positive = [model.special_tokens["pad"]] * pos_pad + positive
+ neg_pad = 0
+ if len(negative) < len(positive):
+ neg_pad = (len(positive) - len(negative))
+ negative = [model.special_tokens["pad"]] * neg_pad + negative
- paddings = [pos_pad, neg_pad]
- return sample_manual_loop_no_classes(model, [positive, negative], paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
+ pos_pad = 0
+ if len(negative) > len(positive):
+ pos_pad = (len(negative) - len(positive))
+ positive = [model.special_tokens["pad"]] * pos_pad + positive
+
+ ids = [positive, negative]
+ else:
+ ids = [positive]
+
+ return sample_manual_loop_no_classes(model, ids, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
@@ -129,12 +145,12 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
def _metas_to_cot(self, *, return_yaml: bool = False, **kwargs) -> str:
user_metas = {
k: kwargs.pop(k)
- for k in ("bpm", "duration", "keyscale", "timesignature", "language", "caption")
+ for k in ("bpm", "duration", "keyscale", "timesignature")
if k in kwargs
}
timesignature = user_metas.get("timesignature")
if isinstance(timesignature, str) and timesignature.endswith("/4"):
- user_metas["timesignature"] = timesignature.rsplit("/", 1)[0]
+ user_metas["timesignature"] = timesignature[:-2]
user_metas = {
k: v if not isinstance(v, str) or not v.isdigit() else int(v)
for k, v in user_metas.items()
@@ -147,8 +163,11 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
return f"\n{meta_yaml}\n" if not return_yaml else meta_yaml
def _metas_to_cap(self, **kwargs) -> str:
- use_keys = ("bpm", "duration", "keyscale", "timesignature")
+ use_keys = ("bpm", "timesignature", "keyscale", "duration")
user_metas = { k: kwargs.pop(k, "N/A") for k in use_keys }
+ timesignature = user_metas.get("timesignature")
+ if isinstance(timesignature, str) and timesignature.endswith("/4"):
+ user_metas["timesignature"] = timesignature[:-2]
duration = user_metas["duration"]
if duration == "N/A":
user_metas["duration"] = "30 seconds"
@@ -159,9 +178,13 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
return "\n".join(f"- {k}: {user_metas[k]}" for k in use_keys)
def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
- out = {}
+ text = text.strip()
+ text_negative = kwargs.get("caption_negative", text).strip()
lyrics = kwargs.get("lyrics", "")
+ lyrics_negative = kwargs.get("lyrics_negative", lyrics)
duration = kwargs.get("duration", 120)
+ if isinstance(duration, str):
+ duration = float(duration.split(None, 1)[0])
language = kwargs.get("language")
seed = kwargs.get("seed", 0)
@@ -170,28 +193,55 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
temperature = kwargs.get("temperature", 0.85)
top_p = kwargs.get("top_p", 0.9)
top_k = kwargs.get("top_k", 0.0)
-
+ min_p = kwargs.get("min_p", 0.000)
duration = math.ceil(duration)
kwargs["duration"] = duration
+ tokens_duration = duration * 5
+ min_tokens = int(kwargs.get("min_tokens", tokens_duration))
+ max_tokens = int(kwargs.get("max_tokens", tokens_duration))
- cot_text = self._metas_to_cot(caption = text, **kwargs)
+ metas_negative = {
+ k.rsplit("_", 1)[0]: kwargs.pop(k)
+ for k in ("bpm_negative", "duration_negative", "keyscale_negative", "timesignature_negative", "language_negative", "caption_negative")
+ if k in kwargs
+ }
+ if not kwargs.get("use_negative_caption"):
+ _ = metas_negative.pop("caption", None)
+
+ cot_text = self._metas_to_cot(caption=text, **kwargs)
+ cot_text_negative = "\n\n" if not metas_negative else self._metas_to_cot(**metas_negative)
meta_cap = self._metas_to_cap(**kwargs)
- lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n# Lyric\n{}\n<|im_end|>\n<|im_start|>assistant\n{}\n<|im_end|>\n"
+ lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n\n# Lyric\n{}\n<|im_end|>\n<|im_start|>assistant\n{}\n\n<|im_end|>\n"
+ lyrics_template = "# Languages\n{}\n\n# Lyric\n{}<|endoftext|><|endoftext|>"
+ qwen3_06b_template = "# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}\n\n# Metas\n{}\n<|endoftext|>\n<|endoftext|>"
- out["lm_prompt"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, cot_text), disable_weights=True)
- out["lm_prompt_negative"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, "\n"), disable_weights=True)
+ llm_prompts = {
+ "lm_prompt": lm_template.format(text, lyrics.strip(), cot_text),
+ "lm_prompt_negative": lm_template.format(text_negative, lyrics_negative.strip(), cot_text_negative),
+ "lyrics": lyrics_template.format(language if language is not None else "", lyrics),
+ "qwen3_06b": qwen3_06b_template.format(text, meta_cap),
+ }
- out["lyrics"] = self.qwen3_06b.tokenize_with_weights("# Languages\n{}\n\n# Lyric\n{}<|endoftext|><|endoftext|>".format(language if language is not None else "", lyrics), return_word_ids, disable_weights=True, **kwargs)
- out["qwen3_06b"] = self.qwen3_06b.tokenize_with_weights("# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}\n# Metas\n{}\n<|endoftext|>\n<|endoftext|>".format(text, meta_cap), return_word_ids, **kwargs)
- out["lm_metadata"] = {"min_tokens": duration * 5,
+ out = {
+ prompt_key: self.qwen3_06b.tokenize_with_weights(
+ prompt,
+ prompt_key == "qwen3_06b" and return_word_ids,
+ disable_weights = True,
+ **kwargs,
+ )
+ for prompt_key, prompt in llm_prompts.items()
+ }
+ out["lm_metadata"] = {"min_tokens": min_tokens,
+ "max_tokens": max_tokens,
"seed": seed,
"generate_audio_codes": generate_audio_codes,
"cfg_scale": cfg_scale,
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
+ "min_p": min_p,
}
return out
@@ -252,7 +302,7 @@ class ACE15TEModel(torch.nn.Module):
lm_metadata = token_weight_pairs["lm_metadata"]
if lm_metadata["generate_audio_codes"]:
- audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"], cfg_scale=lm_metadata["cfg_scale"], temperature=lm_metadata["temperature"], top_p=lm_metadata["top_p"], top_k=lm_metadata["top_k"])
+ audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["max_tokens"], seed=lm_metadata["seed"], cfg_scale=lm_metadata["cfg_scale"], temperature=lm_metadata["temperature"], top_p=lm_metadata["top_p"], top_k=lm_metadata["top_k"], min_p=lm_metadata["min_p"])
out["audio_codes"] = [audio_codes]
return base_out, None, out
diff --git a/comfy/text_encoders/anima.py b/comfy/text_encoders/anima.py
index b6f58cb25..d8c5a6f92 100644
--- a/comfy/text_encoders/anima.py
+++ b/comfy/text_encoders/anima.py
@@ -23,7 +23,7 @@ class AnimaTokenizer:
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = {}
qwen_ids = self.qwen3_06b.tokenize_with_weights(text, return_word_ids, **kwargs)
- out["qwen3_06b"] = [[(token, 1.0) for token, _ in inner_list] for inner_list in qwen_ids] # Set weights to 1.0
+ out["qwen3_06b"] = [[(k[0], 1.0, k[2]) if return_word_ids else (k[0], 1.0) for k in inner_list] for inner_list in qwen_ids] # Set weights to 1.0
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs)
return out
diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py
index b6735d210..e5d21fa74 100644
--- a/comfy/text_encoders/llama.py
+++ b/comfy/text_encoders/llama.py
@@ -3,6 +3,8 @@ import torch.nn as nn
from dataclasses import dataclass
from typing import Optional, Any, Tuple
import math
+from tqdm import tqdm
+import comfy.utils
from comfy.ldm.modules.attention import optimized_attention_for_device
import comfy.model_management
@@ -313,6 +315,13 @@ class Gemma3_4B_Config:
final_norm: bool = True
lm_head: bool = False
+GEMMA3_VISION_CONFIG = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14}
+
+@dataclass
+class Gemma3_4B_Vision_Config(Gemma3_4B_Config):
+ vision_config = GEMMA3_VISION_CONFIG
+ mm_tokens_per_image = 256
+
@dataclass
class Gemma3_12B_Config:
vocab_size: int = 262208
@@ -336,7 +345,7 @@ class Gemma3_12B_Config:
rope_scale = [8.0, 1.0]
final_norm: bool = True
lm_head: bool = False
- vision_config = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14}
+ vision_config = GEMMA3_VISION_CONFIG
mm_tokens_per_image = 256
class RMSNorm(nn.Module):
@@ -355,13 +364,6 @@ class RMSNorm(nn.Module):
-def rotate_half(x):
- """Rotates half the hidden dims of the input."""
- x1 = x[..., : x.shape[-1] // 2]
- x2 = x[..., x.shape[-1] // 2 :]
- return torch.cat((-x2, x1), dim=-1)
-
-
def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None):
if not isinstance(theta, list):
theta = [theta]
@@ -390,20 +392,30 @@ def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_di
else:
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
- out.append((cos, sin))
+ sin_split = sin.shape[-1] // 2
+ out.append((cos, sin[..., : sin_split], -sin[..., sin_split :]))
if len(out) == 1:
return out[0]
return out
-
def apply_rope(xq, xk, freqs_cis):
org_dtype = xq.dtype
cos = freqs_cis[0]
sin = freqs_cis[1]
- q_embed = (xq * cos) + (rotate_half(xq) * sin)
- k_embed = (xk * cos) + (rotate_half(xk) * sin)
+ nsin = freqs_cis[2]
+
+ q_embed = (xq * cos)
+ q_split = q_embed.shape[-1] // 2
+ q_embed[..., : q_split].addcmul_(xq[..., q_split :], nsin)
+ q_embed[..., q_split :].addcmul_(xq[..., : q_split], sin)
+
+ k_embed = (xk * cos)
+ k_split = k_embed.shape[-1] // 2
+ k_embed[..., : k_split].addcmul_(xk[..., k_split :], nsin)
+ k_embed[..., k_split :].addcmul_(xk[..., : k_split], sin)
+
return q_embed.to(org_dtype), k_embed.to(org_dtype)
@@ -438,8 +450,10 @@ class Attention(nn.Module):
freqs_cis: Optional[torch.Tensor] = None,
optimized_attention=None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ sliding_window: Optional[int] = None,
):
batch_size, seq_length, _ = hidden_states.shape
+
xq = self.q_proj(hidden_states)
xk = self.k_proj(hidden_states)
xv = self.v_proj(hidden_states)
@@ -474,6 +488,11 @@ class Attention(nn.Module):
else:
present_key_value = (xk, xv, index + num_tokens)
+ if sliding_window is not None and xk.shape[2] > sliding_window:
+ xk = xk[:, :, -sliding_window:]
+ xv = xv[:, :, -sliding_window:]
+ attention_mask = attention_mask[..., -sliding_window:] if attention_mask is not None else None
+
xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
@@ -556,10 +575,12 @@ class TransformerBlockGemma2(nn.Module):
optimized_attention=None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
):
+ sliding_window = None
if self.transformer_type == 'gemma3':
if self.sliding_attention:
+ sliding_window = self.sliding_attention
if x.shape[1] > self.sliding_attention:
- sliding_mask = torch.full((x.shape[1], x.shape[1]), float("-inf"), device=x.device, dtype=x.dtype)
+ sliding_mask = torch.full((x.shape[1], x.shape[1]), torch.finfo(x.dtype).min, device=x.device, dtype=x.dtype)
sliding_mask.tril_(diagonal=-self.sliding_attention)
if attention_mask is not None:
attention_mask = attention_mask + sliding_mask
@@ -578,6 +599,7 @@ class TransformerBlockGemma2(nn.Module):
freqs_cis=freqs_cis,
optimized_attention=optimized_attention,
past_key_value=past_key_value,
+ sliding_window=sliding_window,
)
x = self.post_attention_layernorm(x)
@@ -762,6 +784,104 @@ class BaseLlama:
def forward(self, input_ids, *args, **kwargs):
return self.model(input_ids, *args, **kwargs)
+class BaseGenerate:
+ def logits(self, x):
+ input = x[:, -1:]
+ if hasattr(self.model, "lm_head"):
+ module = self.model.lm_head
+ else:
+ module = self.model.embed_tokens
+
+ offload_stream = None
+ if module.comfy_cast_weights:
+ weight, _, offload_stream = comfy.ops.cast_bias_weight(module, input, offloadable=True)
+ else:
+ weight = self.model.embed_tokens.weight.to(x)
+
+ x = torch.nn.functional.linear(input, weight, None)
+
+ comfy.ops.uncast_bias_weight(module, weight, None, offload_stream)
+ return x
+
+ def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=[], initial_tokens=[], execution_dtype=None, min_tokens=0):
+ device = embeds.device
+ model_config = self.model.config
+
+ if execution_dtype is None:
+ if comfy.model_management.should_use_bf16(device):
+ execution_dtype = torch.bfloat16
+ else:
+ execution_dtype = torch.float32
+ embeds = embeds.to(execution_dtype)
+
+ if embeds.ndim == 2:
+ embeds = embeds.unsqueeze(0)
+
+ past_key_values = [] #kv_cache init
+ max_cache_len = embeds.shape[1] + max_length
+ for x in range(model_config.num_hidden_layers):
+ past_key_values.append((torch.empty([embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
+ torch.empty([embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
+
+ generator = torch.Generator(device=device).manual_seed(seed) if do_sample else None
+
+ generated_token_ids = []
+ pbar = comfy.utils.ProgressBar(max_length)
+
+ # Generation loop
+ for step in tqdm(range(max_length), desc="Generating tokens"):
+ x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values)
+ logits = self.logits(x)[:, -1]
+ next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample)
+ token_id = next_token[0].item()
+ generated_token_ids.append(token_id)
+
+ embeds = self.model.embed_tokens(next_token).to(execution_dtype)
+ pbar.update(1)
+
+ if token_id in stop_tokens:
+ break
+
+ return generated_token_ids
+
+ def sample_token(self, logits, temperature, top_k, top_p, min_p, repetition_penalty, token_history, generator, do_sample=True):
+
+ if not do_sample or temperature == 0.0:
+ return torch.argmax(logits, dim=-1, keepdim=True)
+
+ # Sampling mode
+ if repetition_penalty != 1.0:
+ for i in range(logits.shape[0]):
+ for token_id in set(token_history):
+ logits[i, token_id] *= repetition_penalty if logits[i, token_id] < 0 else 1/repetition_penalty
+
+ if temperature != 1.0:
+ logits = logits / temperature
+
+ if top_k > 0:
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+ logits[indices_to_remove] = torch.finfo(logits.dtype).min
+
+ if min_p > 0.0:
+ probs_before_filter = torch.nn.functional.softmax(logits, dim=-1)
+ top_probs, _ = probs_before_filter.max(dim=-1, keepdim=True)
+ min_threshold = min_p * top_probs
+ indices_to_remove = probs_before_filter < min_threshold
+ logits[indices_to_remove] = torch.finfo(logits.dtype).min
+
+ if top_p < 1.0:
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+ cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
+ sorted_indices_to_remove = cumulative_probs > top_p
+ sorted_indices_to_remove[..., 0] = False
+ indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
+ indices_to_remove.scatter_(1, sorted_indices, sorted_indices_to_remove)
+ logits[indices_to_remove] = torch.finfo(logits.dtype).min
+
+ probs = torch.nn.functional.softmax(logits, dim=-1)
+
+ return torch.multinomial(probs, num_samples=1, generator=generator)
+
class BaseQwen3:
def logits(self, x):
input = x[:, -1:]
@@ -868,7 +988,7 @@ class Ovis25_2B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
-class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
+class Qwen25_7BVLI(BaseLlama, BaseGenerate, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen25_7BVLI_Config(**config_dict)
@@ -878,6 +998,9 @@ class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
self.visual = qwen_vl.Qwen2VLVisionTransformer(hidden_size=1280, output_hidden_size=config.hidden_size, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
+ # todo: should this be tied or not?
+ #self.lm_head = operations.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
+
def preprocess_embed(self, embed, device):
if embed["type"] == "image":
image, grid = qwen_vl.process_qwen2vl_images(embed["data"])
@@ -920,7 +1043,7 @@ class Gemma2_2B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
-class Gemma3_4B(BaseLlama, torch.nn.Module):
+class Gemma3_4B(BaseLlama, BaseGenerate, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Gemma3_4B_Config(**config_dict)
@@ -929,7 +1052,25 @@ class Gemma3_4B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
-class Gemma3_12B(BaseLlama, torch.nn.Module):
+class Gemma3_4B_Vision(BaseLlama, BaseGenerate, torch.nn.Module):
+ def __init__(self, config_dict, dtype, device, operations):
+ super().__init__()
+ config = Gemma3_4B_Vision_Config(**config_dict)
+ self.num_layers = config.num_hidden_layers
+
+ self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
+ self.dtype = dtype
+ self.multi_modal_projector = Gemma3MultiModalProjector(config, dtype, device, operations)
+ self.vision_model = comfy.clip_model.CLIPVision(config.vision_config, dtype, device, operations)
+ self.image_size = config.vision_config["image_size"]
+
+ def preprocess_embed(self, embed, device):
+ if embed["type"] == "image":
+ image = comfy.clip_model.clip_preprocess(embed["data"], size=self.image_size, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], crop=True)
+ return self.multi_modal_projector(self.vision_model(image.to(device, dtype=torch.float32))[0]), None
+ return None, None
+
+class Gemma3_12B(BaseLlama, BaseGenerate, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Gemma3_12B_Config(**config_dict)
diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py
index 26573fb12..82fbacf59 100644
--- a/comfy/text_encoders/lt.py
+++ b/comfy/text_encoders/lt.py
@@ -6,6 +6,7 @@ import comfy.text_encoders.genmo
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
import torch
import comfy.utils
+import math
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -22,40 +23,79 @@ def ltxv_te(*args, **kwargs):
return comfy.text_encoders.genmo.mochi_te(*args, **kwargs)
-class Gemma3_12BTokenizer(sd1_clip.SDTokenizer):
- def __init__(self, embedding_directory=None, tokenizer_data={}):
- tokenizer = tokenizer_data.get("spiece_model", None)
- super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
-
+class Gemma3_Tokenizer():
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
+ def tokenize_with_weights(self, text, return_word_ids=False, image=None, llama_template=None, skip_template=True, **kwargs):
+ self.llama_template = "system\nYou are a helpful assistant.\nuser\n{}\nmodel\n"
+ self.llama_template_images = "system\nYou are a helpful assistant.\nuser\n\n{}\n\nmodel\n"
+
+ if image is None:
+ images = []
+ else:
+ samples = image.movedim(-1, 1)
+ total = int(896 * 896)
+
+ scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
+ width = round(samples.shape[3] * scale_by)
+ height = round(samples.shape[2] * scale_by)
+
+ s = comfy.utils.common_upscale(samples, width, height, "area", "disabled").movedim(1, -1)
+ images = [s[:, :, :, :3]]
+
+ if text.startswith(''):
+ skip_template = True
+
+ if skip_template:
+ llama_text = text
+ else:
+ if llama_template is None:
+ if len(images) > 0:
+ llama_text = self.llama_template_images.format(text)
+ else:
+ llama_text = self.llama_template.format(text)
+ else:
+ llama_text = llama_template.format(text)
+
+ text_tokens = super().tokenize_with_weights(llama_text, return_word_ids)
+
+ if len(images) > 0:
+ embed_count = 0
+ for r in text_tokens:
+ for i, token in enumerate(r):
+ if token[0] == 262144 and embed_count < len(images):
+ r[i] = ({"type": "image", "data": images[embed_count]},) + token[1:]
+ embed_count += 1
+ return text_tokens
+
+class Gemma3_12BTokenizer(Gemma3_Tokenizer, sd1_clip.SDTokenizer):
+ def __init__(self, embedding_directory=None, tokenizer_data={}):
+ tokenizer = tokenizer_data.get("spiece_model", None)
+ special_tokens = {"": 262144, "": 106}
+ super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_left=True, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, tokenizer_data=tokenizer_data)
+
+
class LTXAVGemmaTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma3_12b", tokenizer=Gemma3_12BTokenizer)
+
class Gemma3_12BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
-
+ self.dtypes = set()
+ self.dtypes.add(dtype)
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_12B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
- def tokenize_with_weights(self, text, return_word_ids=False, llama_template="{}", image_embeds=None, **kwargs):
- text = llama_template.format(text)
- text_tokens = super().tokenize_with_weights(text, return_word_ids)
- embed_count = 0
- for k in text_tokens:
- tt = text_tokens[k]
- for r in tt:
- for i in range(len(r)):
- if r[i][0] == 262144:
- if image_embeds is not None and embed_count < image_embeds.shape[0]:
- r[i] = ({"type": "embedding", "data": image_embeds[embed_count], "original_type": "image"},) + r[i][1:]
- embed_count += 1
- return text_tokens
+ def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
+ tokens_only = [[t[0] for t in b] for b in tokens]
+ embeds, _, _, embeds_info = self.process_tokens(tokens_only, self.execution_device)
+ comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
+ return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106]) # 106 is
class LTXAVTEModel(torch.nn.Module):
def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
@@ -97,6 +137,7 @@ class LTXAVTEModel(torch.nn.Module):
token_weight_pairs = token_weight_pairs["gemma3_12b"]
out, pooled, extra = self.gemma3_12b.encode_token_weights(token_weight_pairs)
+ out = out[:, :, -torch.sum(extra["attention_mask"]).item():]
out_device = out.device
if comfy.model_management.should_use_bf16(self.execution_device):
out = out.to(device=self.execution_device, dtype=torch.bfloat16)
@@ -111,6 +152,9 @@ class LTXAVTEModel(torch.nn.Module):
return out.to(out_device), pooled
+ def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
+ return self.gemma3_12b.generate(tokens["gemma3_12b"], do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed)
+
def load_sd(self, sd):
if "model.layers.47.self_attn.q_norm.weight" in sd:
return self.gemma3_12b.load_sd(sd)
@@ -138,6 +182,7 @@ class LTXAVTEModel(torch.nn.Module):
token_weight_pairs = token_weight_pairs.get("gemma3_12b", [])
num_tokens = sum(map(lambda a: len(a), token_weight_pairs))
+ num_tokens = max(num_tokens, 64)
return num_tokens * constant * 1024 * 1024
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
@@ -150,3 +195,14 @@ def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
dtype = dtype_llama
super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
return LTXAVTEModel_
+
+def gemma3_te(dtype_llama=None, llama_quantization_metadata=None):
+ class Gemma3_12BModel_(Gemma3_12BModel):
+ def __init__(self, device="cpu", dtype=None, model_options={}):
+ if llama_quantization_metadata is not None:
+ model_options = model_options.copy()
+ model_options["llama_quantization_metadata"] = llama_quantization_metadata
+ if dtype_llama is not None:
+ dtype = dtype_llama
+ super().__init__(device=device, dtype=dtype, model_options=model_options)
+ return Gemma3_12BModel_
diff --git a/comfy/text_encoders/lumina2.py b/comfy/text_encoders/lumina2.py
index b29a7cc87..1b731e094 100644
--- a/comfy/text_encoders/lumina2.py
+++ b/comfy/text_encoders/lumina2.py
@@ -1,23 +1,23 @@
from comfy import sd1_clip
from .spiece_tokenizer import SPieceTokenizer
import comfy.text_encoders.llama
-
+from comfy.text_encoders.lt import Gemma3_Tokenizer
+import comfy.utils
class Gemma2BTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
- super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
+ special_tokens = {"": 107}
+ super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, tokenizer_data=tokenizer_data)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
-class Gemma3_4BTokenizer(sd1_clip.SDTokenizer):
+class Gemma3_4BTokenizer(Gemma3_Tokenizer, sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
- super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, disable_weights=True, tokenizer_data=tokenizer_data)
-
- def state_dict(self):
- return {"spiece_model": self.tokenizer.serialize_model()}
+ special_tokens = {"": 262144, "": 106}
+ super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, disable_weights=True, tokenizer_data=tokenizer_data)
class LuminaTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -31,6 +31,9 @@ class Gemma2_2BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma2_2B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
+ def generate(self, embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
+ return super().generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[107])
+
class Gemma3_4BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
@@ -40,6 +43,23 @@ class Gemma3_4BModel(sd1_clip.SDClipModel):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
+ def generate(self, embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
+ return super().generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106])
+
+class Gemma3_4B_Vision_Model(sd1_clip.SDClipModel):
+ def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
+ llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
+ if llama_quantization_metadata is not None:
+ model_options = model_options.copy()
+ model_options["quantization_metadata"] = llama_quantization_metadata
+
+ super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B_Vision, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
+
+ def process_tokens(self, tokens, device):
+ embeds, _, _, embeds_info = super().process_tokens(tokens, device)
+ comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
+ return embeds
+
class LuminaModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}, name="gemma2_2b", clip_model=Gemma2_2BModel):
super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options)
@@ -50,6 +70,8 @@ def te(dtype_llama=None, llama_quantization_metadata=None, model_type="gemma2_2b
model = Gemma2_2BModel
elif model_type == "gemma3_4b":
model = Gemma3_4BModel
+ elif model_type == "gemma3_4b_vision":
+ model = Gemma3_4B_Vision_Model
class LuminaTEModel_(LuminaModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
diff --git a/comfy/text_encoders/spiece_tokenizer.py b/comfy/text_encoders/spiece_tokenizer.py
index caccb3ca2..099d8d2d9 100644
--- a/comfy/text_encoders/spiece_tokenizer.py
+++ b/comfy/text_encoders/spiece_tokenizer.py
@@ -6,9 +6,10 @@ class SPieceTokenizer:
def from_pretrained(path, **kwargs):
return SPieceTokenizer(path, **kwargs)
- def __init__(self, tokenizer_path, add_bos=False, add_eos=True):
+ def __init__(self, tokenizer_path, add_bos=False, add_eos=True, special_tokens=None):
self.add_bos = add_bos
self.add_eos = add_eos
+ self.special_tokens = special_tokens
import sentencepiece
if torch.is_tensor(tokenizer_path):
tokenizer_path = tokenizer_path.numpy().tobytes()
@@ -27,8 +28,32 @@ class SPieceTokenizer:
return out
def __call__(self, string):
+ if self.special_tokens is not None:
+ import re
+ special_tokens_pattern = '|'.join(re.escape(token) for token in self.special_tokens.keys())
+ if special_tokens_pattern and re.search(special_tokens_pattern, string):
+ parts = re.split(f'({special_tokens_pattern})', string)
+ result = []
+ for part in parts:
+ if not part:
+ continue
+ if part in self.special_tokens:
+ result.append(self.special_tokens[part])
+ else:
+ encoded = self.tokenizer.encode(part, add_bos=False, add_eos=False)
+ result.extend(encoded)
+ return {"input_ids": result}
+
out = self.tokenizer.encode(string)
return {"input_ids": out}
+ def decode(self, token_ids, skip_special_tokens=False):
+
+ if skip_special_tokens and self.special_tokens:
+ special_token_ids = set(self.special_tokens.values())
+ token_ids = [tid for tid in token_ids if tid not in special_token_ids]
+
+ return self.tokenizer.decode(token_ids)
+
def serialize_model(self):
return torch.ByteTensor(list(self.tokenizer.serialized_model_proto()))
diff --git a/comfy/utils.py b/comfy/utils.py
index 1337e2205..17443b4cc 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -20,13 +20,14 @@
import torch
import math
import struct
-import comfy.checkpoint_pickle
+import comfy.memory_management
import safetensors.torch
import numpy as np
from PIL import Image
import logging
import itertools
from torch.nn.functional import interpolate
+from tqdm.auto import trange
from einops import rearrange
from comfy.cli_args import args, enables_dynamic_vram
import json
@@ -37,26 +38,26 @@ import warnings
MMAP_TORCH_FILES = args.mmap_torch_files
DISABLE_MMAP = args.disable_mmap
-ALWAYS_SAFE_LOAD = False
-if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated
+
+if True: # ckpt/pt file whitelist for safe loading of old sd files
class ModelCheckpoint:
pass
ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint"
def scalar(*args, **kwargs):
- from numpy.core.multiarray import scalar as sc
- return sc(*args, **kwargs)
+ return None
scalar.__module__ = "numpy.core.multiarray"
from numpy import dtype
from numpy.dtypes import Float64DType
- from _codecs import encode
+
+ def encode(*args, **kwargs): # no longer necessary on newer torch
+ return None
+ encode.__module__ = "_codecs"
torch.serialization.add_safe_globals([ModelCheckpoint, scalar, dtype, Float64DType, encode])
- ALWAYS_SAFE_LOAD = True
logging.info("Checkpoint files will always be loaded safely.")
-else:
- logging.warning("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended as older versions of pytorch are no longer supported.")
+
# Current as of safetensors 0.7.0
_TYPES = {
@@ -139,11 +140,8 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
if MMAP_TORCH_FILES:
torch_args["mmap"] = True
- if safe_load or ALWAYS_SAFE_LOAD:
- pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
- else:
- logging.warning("WARNING: loading {} unsafely, upgrade your pytorch to 2.4 or newer to load this file safely.".format(ckpt))
- pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
+ pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
+
if "state_dict" in pl_sd:
sd = pl_sd["state_dict"]
else:
@@ -674,10 +672,10 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
"ff_context.linear_in.bias": "txt_mlp.0.bias",
"ff_context.linear_out.weight": "txt_mlp.2.weight",
"ff_context.linear_out.bias": "txt_mlp.2.bias",
- "attn.norm_q.weight": "img_attn.norm.query_norm.scale",
- "attn.norm_k.weight": "img_attn.norm.key_norm.scale",
- "attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",
- "attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale",
+ "attn.norm_q.weight": "img_attn.norm.query_norm.weight",
+ "attn.norm_k.weight": "img_attn.norm.key_norm.weight",
+ "attn.norm_added_q.weight": "txt_attn.norm.query_norm.weight",
+ "attn.norm_added_k.weight": "txt_attn.norm.key_norm.weight",
}
for k in block_map:
@@ -700,8 +698,8 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
"norm.linear.bias": "modulation.lin.bias",
"proj_out.weight": "linear2.weight",
"proj_out.bias": "linear2.bias",
- "attn.norm_q.weight": "norm.query_norm.scale",
- "attn.norm_k.weight": "norm.key_norm.scale",
+ "attn.norm_q.weight": "norm.query_norm.weight",
+ "attn.norm_k.weight": "norm.key_norm.weight",
"attn.to_qkv_mlp_proj.weight": "linear1.weight", # Flux 2
"attn.to_out.weight": "linear2.weight", # Flux 2
}
@@ -1155,6 +1153,32 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar)
+def model_trange(*args, **kwargs):
+ if comfy.memory_management.aimdo_allocator is None:
+ return trange(*args, **kwargs)
+
+ pbar = trange(*args, **kwargs, smoothing=1.0)
+ pbar._i = 0
+ pbar.set_postfix_str(" Model Initializing ... ")
+
+ _update = pbar.update
+
+ def warmup_update(n=1):
+ pbar._i += 1
+ if pbar._i == 1:
+ pbar.i1_time = time.time()
+ pbar.set_postfix_str(" Model Initialization complete! ")
+ elif pbar._i == 2:
+ #bring forward the effective start time based the the diff between first and second iteration
+ #to attempt to remove load overhead from the final step rate estimate.
+ pbar.start_t = pbar.i1_time - (time.time() - pbar.i1_time)
+ pbar.set_postfix_str("")
+
+ _update(n)
+
+ pbar.update = warmup_update
+ return pbar
+
PROGRESS_BAR_ENABLED = True
def set_progress_bar_enabled(enabled):
global PROGRESS_BAR_ENABLED
@@ -1376,3 +1400,29 @@ def string_to_seed(data):
else:
crc >>= 1
return crc ^ 0xFFFFFFFF
+
+def deepcopy_list_dict(obj, memo=None):
+ if memo is None:
+ memo = {}
+
+ obj_id = id(obj)
+ if obj_id in memo:
+ return memo[obj_id]
+
+ if isinstance(obj, dict):
+ res = {deepcopy_list_dict(k, memo): deepcopy_list_dict(v, memo) for k, v in obj.items()}
+ elif isinstance(obj, list):
+ res = [deepcopy_list_dict(i, memo) for i in obj]
+ else:
+ res = obj
+
+ memo[obj_id] = res
+ return res
+
+def normalize_image_embeddings(embeds, embeds_info, scale_factor):
+ """Normalize image embeddings to match text embedding scale"""
+ for info in embeds_info:
+ if info.get("type") == "image":
+ start_idx = info["index"]
+ end_idx = start_idx + info["size"]
+ embeds[:, start_idx:end_idx, :] /= scale_factor
diff --git a/comfy/weight_adapter/base.py b/comfy/weight_adapter/base.py
index bce89a0e2..d352e066b 100644
--- a/comfy/weight_adapter/base.py
+++ b/comfy/weight_adapter/base.py
@@ -49,6 +49,12 @@ class WeightAdapterBase:
"""
raise NotImplementedError
+    def calculate_shape(self, key):
+        """Return the shape associated with ``key``; this base implementation always returns ``None``."""
+        return None
+
def calculate_weight(
self,
weight,
diff --git a/comfy/weight_adapter/bypass.py b/comfy/weight_adapter/bypass.py
index d4aaf98ca..b9d5ec7d9 100644
--- a/comfy/weight_adapter/bypass.py
+++ b/comfy/weight_adapter/bypass.py
@@ -21,6 +21,7 @@ from typing import Optional, Union
import torch
import torch.nn as nn
+import comfy.model_management
from .base import WeightAdapterBase, WeightAdapterTrainBase
from comfy.patcher_extension import PatcherInjection
@@ -181,18 +182,21 @@ class BypassForwardHook:
)
return # Already injected
- # Move adapter weights to module's device to avoid CPU-GPU transfer on every forward
- device = None
+ # Move adapter weights to compute device (GPU)
+ # Use get_torch_device() instead of module.weight.device because
+ # with offloading, module weights may be on CPU while compute happens on GPU
+ device = comfy.model_management.get_torch_device()
+
+ # Get dtype from module weight if available
dtype = None
if hasattr(self.module, "weight") and self.module.weight is not None:
- device = self.module.weight.device
dtype = self.module.weight.dtype
- elif hasattr(self.module, "W_q"): # Quantized layers might use different attr
- device = self.module.W_q.device
- dtype = self.module.W_q.dtype
- if device is not None:
- self._move_adapter_weights_to_device(device, dtype)
+ # Only use dtype if it's a standard float type, not quantized
+ if dtype is not None and dtype not in (torch.float32, torch.float16, torch.bfloat16):
+ dtype = None
+
+ self._move_adapter_weights_to_device(device, dtype)
self.original_forward = self.module.forward
self.module.forward = self._bypass_forward
diff --git a/comfy/weight_adapter/lora.py b/comfy/weight_adapter/lora.py
index bc4260a8f..8e1261a12 100644
--- a/comfy/weight_adapter/lora.py
+++ b/comfy/weight_adapter/lora.py
@@ -214,6 +214,13 @@ class LoRAAdapter(WeightAdapterBase):
else:
return None
+    def calculate_shape(self, key):
+        """Return ``self.weights[5]`` as a tuple, or ``None`` when that entry is ``None``.
+
+        ``key`` is accepted for interface compatibility and is not used here.
+        """
+        target_shape = self.weights[5]
+        if target_shape is None:
+            return None
+        return tuple(target_shape)
+
def calculate_weight(
self,
weight,
diff --git a/comfy_api/feature_flags.py b/comfy_api/feature_flags.py
index de167f037..a90a5ca40 100644
--- a/comfy_api/feature_flags.py
+++ b/comfy_api/feature_flags.py
@@ -14,6 +14,7 @@ SERVER_FEATURE_FLAGS: dict[str, Any] = {
"supports_preview_metadata": True,
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
"extension": {"manager": {"supports_v4": True}},
+ "node_replacements": True,
}
diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py
index 8542a1dbc..f2399422b 100644
--- a/comfy_api/latest/__init__.py
+++ b/comfy_api/latest/__init__.py
@@ -21,6 +21,17 @@ class ComfyAPI_latest(ComfyAPIBase):
VERSION = "latest"
STABLE = False
+ def __init__(self):
+ # Create the nested API groups as instance attributes (these replace the
+ # class-level ``execution: Execution`` annotation removed further down in this patch).
+ super().__init__()
+ self.node_replacement = self.NodeReplacement()
+ self.execution = self.Execution()
+
+ class NodeReplacement(ProxiedSingleton):
+ # API group exposing node replacement registration.
+ async def register(self, node_replace: io.NodeReplace) -> None:
+ """Register a node replacement mapping."""
+ # Imported at call time rather than at module import
+ # (presumably to avoid an import cycle with ``server`` - TODO confirm).
+ from server import PromptServer
+ PromptServer.instance.node_replace_manager.register(node_replace)
+
class Execution(ProxiedSingleton):
async def set_progress(
self,
@@ -73,8 +84,6 @@ class ComfyAPI_latest(ComfyAPIBase):
image=to_display,
)
- execution: Execution
-
class ComfyExtension(ABC):
async def on_load(self) -> None:
"""
diff --git a/comfy_api/latest/_input/video_types.py b/comfy_api/latest/_input/video_types.py
index e634a0311..451e9526e 100644
--- a/comfy_api/latest/_input/video_types.py
+++ b/comfy_api/latest/_input/video_types.py
@@ -34,6 +34,21 @@ class VideoInput(ABC):
"""
pass
+ @abstractmethod
+ def as_trimmed(
+ self,
+ start_time: float | None = None,
+ duration: float | None = None,
+ strict_duration: bool = False,
+ ) -> VideoInput | None:
+ """
+ Create a new VideoInput which is trimmed to have the corresponding start_time and duration
+
+ Args:
+ start_time: Offset, in seconds, at which the trimmed video starts.
+ duration: Requested length, in seconds, of the trimmed video.
+ strict_duration: When true, implementations return None if the trimmed
+ video would be shorter than the requested duration.
+ NOTE(review): the implementations in this patch default this to True,
+ while this declaration defaults it to False - confirm which is intended.
+
+ Returns:
+ A new VideoInput, or None if the result would have negative duration
+ """
+ pass
+
def get_stream_source(self) -> Union[str, io.BytesIO]:
"""
Get a streamable source for the video. This allows processing without
diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py
index 1405d0b81..3463ed1c9 100644
--- a/comfy_api/latest/_input_impl/video_types.py
+++ b/comfy_api/latest/_input_impl/video_types.py
@@ -6,6 +6,7 @@ from typing import Optional
from .._input import AudioInput, VideoInput
import av
import io
+import itertools
import json
import numpy as np
import math
@@ -29,7 +30,6 @@ def container_to_output_format(container_format: str | None) -> str | None:
formats = container_format.split(",")
return formats[0]
-
def get_open_write_kwargs(
dest: str | io.BytesIO, container_format: str, to_format: str | None
) -> dict:
@@ -57,12 +57,14 @@ class VideoFromFile(VideoInput):
Class representing video input from a file.
"""
- def __init__(self, file: str | io.BytesIO):
+ def __init__(self, file: str | io.BytesIO, *, start_time: float=0, duration: float=0):
"""
Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object
containing the file contents.
"""
self.__file = file
+ self.__start_time = start_time
+ self.__duration = duration
def get_stream_source(self) -> str | io.BytesIO:
"""
@@ -96,6 +98,16 @@ class VideoFromFile(VideoInput):
Returns:
Duration in seconds
"""
+ raw_duration = self._get_raw_duration()
+ if self.__start_time < 0:
+ duration_from_start = min(raw_duration, -self.__start_time)
+ else:
+ duration_from_start = raw_duration - self.__start_time
+ if self.__duration:
+ return min(self.__duration, duration_from_start)
+ return duration_from_start
+
+ def _get_raw_duration(self) -> float:
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0)
with av.open(self.__file, mode="r") as container:
@@ -113,9 +125,13 @@ class VideoFromFile(VideoInput):
if video_stream and video_stream.average_rate:
frame_count = 0
container.seek(0)
- for packet in container.demux(video_stream):
- for _ in packet.decode():
- frame_count += 1
+ frame_iterator = (
+ container.decode(video_stream)
+ if video_stream.codec.capabilities & 0x100
+ else container.demux(video_stream)
+ )
+ for packet in frame_iterator:
+ frame_count += 1
if frame_count > 0:
return float(frame_count / video_stream.average_rate)
@@ -131,36 +147,54 @@ class VideoFromFile(VideoInput):
with av.open(self.__file, mode="r") as container:
video_stream = self._get_first_video_stream(container)
- # 1. Prefer the frames field if available
- if video_stream.frames and video_stream.frames > 0:
+ # 1. Prefer the frames field if available and usable
+ if (
+ video_stream.frames
+ and video_stream.frames > 0
+ and not self.__start_time
+ and not self.__duration
+ ):
return int(video_stream.frames)
# 2. Try to estimate from duration and average_rate using only metadata
- if container.duration is not None and video_stream.average_rate:
- duration_seconds = float(container.duration / av.time_base)
- estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
- if estimated_frames > 0:
- return estimated_frames
-
if (
getattr(video_stream, "duration", None) is not None
and getattr(video_stream, "time_base", None) is not None
and video_stream.average_rate
):
- duration_seconds = float(video_stream.duration * video_stream.time_base)
+ raw_duration = float(video_stream.duration * video_stream.time_base)
+ if self.__start_time < 0:
+ duration_from_start = min(raw_duration, -self.__start_time)
+ else:
+ duration_from_start = raw_duration - self.__start_time
+ duration_seconds = min(self.__duration, duration_from_start)
estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
if estimated_frames > 0:
return estimated_frames
# 3. Last resort: decode frames and count them (streaming)
- frame_count = 0
- container.seek(0)
- for packet in container.demux(video_stream):
- for _ in packet.decode():
- frame_count += 1
-
- if frame_count == 0:
- raise ValueError(f"Could not determine frame count for file '{self.__file}'")
+ if self.__start_time < 0:
+ start_time = max(self._get_raw_duration() + self.__start_time, 0)
+ else:
+ start_time = self.__start_time
+ frame_count = 1
+ start_pts = int(start_time / video_stream.time_base)
+ end_pts = int((start_time + self.__duration) / video_stream.time_base)
+ container.seek(start_pts, stream=video_stream)
+ frame_iterator = (
+ container.decode(video_stream)
+ if video_stream.codec.capabilities & 0x100
+ else container.demux(video_stream)
+ )
+ for frame in frame_iterator:
+ if frame.pts >= start_pts:
+ break
+ else:
+ raise ValueError(f"Could not determine frame count for file '{self.__file}'\nNo frames exist for start_time {self.__start_time}")
+ for frame in frame_iterator:
+ if frame.pts >= end_pts:
+ break
+ frame_count += 1
return frame_count
def get_frame_rate(self) -> Fraction:
@@ -199,9 +233,21 @@ class VideoFromFile(VideoInput):
return container.format.name
def get_components_internal(self, container: InputContainer) -> VideoComponents:
+ video_stream = self._get_first_video_stream(container)
+ if self.__start_time < 0:
+ start_time = max(self._get_raw_duration() + self.__start_time, 0)
+ else:
+ start_time = self.__start_time
# Get video frames
frames = []
- for frame in container.decode(video=0):
+ start_pts = int(start_time / video_stream.time_base)
+ end_pts = int((start_time + self.__duration) / video_stream.time_base)
+ container.seek(start_pts, stream=video_stream)
+ for frame in container.decode(video_stream):
+ if frame.pts < start_pts:
+ continue
+ if self.__duration and frame.pts >= end_pts:
+ break
img = frame.to_ndarray(format='rgb24') # shape: (H, W, 3)
img = torch.from_numpy(img) / 255.0 # shape: (H, W, 3)
frames.append(img)
@@ -209,31 +255,44 @@ class VideoFromFile(VideoInput):
images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0)
# Get frame rate
- video_stream = next(s for s in container.streams if s.type == 'video')
- frame_rate = Fraction(video_stream.average_rate) if video_stream and video_stream.average_rate else Fraction(1)
+ frame_rate = Fraction(video_stream.average_rate) if video_stream.average_rate else Fraction(1)
# Get audio if available
audio = None
- try:
- container.seek(0) # Reset the container to the beginning
- for stream in container.streams:
- if stream.type != 'audio':
- continue
- assert isinstance(stream, av.AudioStream)
- audio_frames = []
- for packet in container.demux(stream):
- for frame in packet.decode():
- assert isinstance(frame, av.AudioFrame)
- audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
- if len(audio_frames) > 0:
- audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
- audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
- audio = AudioInput({
- "waveform": audio_tensor,
- "sample_rate": int(stream.sample_rate) if stream.sample_rate else 1,
- })
- except StopIteration:
- pass # No audio stream
+ container.seek(start_pts, stream=video_stream)
+ # Use last stream for consistency
+ if len(container.streams.audio):
+ audio_stream = container.streams.audio[-1]
+ audio_frames = []
+ resample = av.audio.resampler.AudioResampler(format='fltp').resample
+ frames = itertools.chain.from_iterable(
+ map(resample, container.decode(audio_stream))
+ )
+
+ has_first_frame = False
+ for frame in frames:
+ offset_seconds = start_time - frame.pts * audio_stream.time_base
+ to_skip = max(0, int(offset_seconds * audio_stream.sample_rate)) # clamp: a negative value would become a from-the-end slice below
+ if to_skip < frame.samples:
+ has_first_frame = True
+ break
+ if has_first_frame:
+ audio_frames.append(frame.to_ndarray()[..., to_skip:])
+
+ for frame in frames:
+ if self.__duration and frame.time > start_time + self.__duration: # duration 0 means "no limit", as in the video loop above
+ break
+ audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
+ if len(audio_frames) > 0:
+ audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
+ if self.__duration:
+ audio_data = audio_data[..., :int(self.__duration * audio_stream.sample_rate)]
+
+ audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
+ audio = AudioInput({
+ "waveform": audio_tensor,
+ "sample_rate": int(audio_stream.sample_rate) if audio_stream.sample_rate else 1,
+ })
metadata = container.metadata
return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
@@ -250,7 +309,7 @@ class VideoFromFile(VideoInput):
path: str | io.BytesIO,
format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO,
- metadata: Optional[dict] = None
+ metadata: Optional[dict] = None,
):
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0) # Reset the BytesIO object to the beginning
@@ -262,15 +321,14 @@ class VideoFromFile(VideoInput):
reuse_streams = False
if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
reuse_streams = False
+ if self.__start_time or self.__duration:
+ reuse_streams = False
if not reuse_streams:
components = self.get_components_internal(container)
video = VideoFromComponents(components)
return video.save_to(
- path,
- format=format,
- codec=codec,
- metadata=metadata
+ path, format=format, codec=codec, metadata=metadata
)
streams = container.streams
@@ -304,10 +362,21 @@ class VideoFromFile(VideoInput):
output_container.mux(packet)
def _get_first_video_stream(self, container: InputContainer):
- video_stream = next((s for s in container.streams if s.type == "video"), None)
- if video_stream is None:
- raise ValueError(f"No video stream found in file '{self.__file}'")
- return video_stream
+ if len(container.streams.video):
+ return container.streams.video[0]
+ raise ValueError(f"No video stream found in file '{self.__file}'")
+
+ def as_trimmed(
+ self, start_time: float = 0, duration: float = 0, strict_duration: bool = True
+ ) -> VideoInput | None:
+ """
+ Return a new VideoFromFile over the same source, trimmed relative to this video.
+
+ start_time is added to this video's own start offset; duration is passed through
+ unchanged. Returns None when strict_duration is set and the trimmed video's
+ duration is shorter than the requested duration.
+ """
+ # NOTE(review): this video's own duration limit (self.__duration) is not carried
+ # over, so trimming an already trimmed video may reach past the earlier window - confirm.
+ # NOTE(review): None arguments, allowed by the abstract signature, would fail in the
+ # addition below.
+ trimmed = VideoFromFile(
+ self.get_stream_source(),
+ start_time=start_time + self.__start_time,
+ duration=duration,
+ )
+ if trimmed.get_duration() < duration and strict_duration:
+ return None
+ return trimmed
class VideoFromComponents(VideoInput):
@@ -322,7 +391,7 @@ class VideoFromComponents(VideoInput):
return VideoComponents(
images=self.__components.images,
audio=self.__components.audio,
- frame_rate=self.__components.frame_rate
+ frame_rate=self.__components.frame_rate,
)
def save_to(
@@ -330,7 +399,7 @@ class VideoFromComponents(VideoInput):
path: str,
format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO,
- metadata: Optional[dict] = None
+ metadata: Optional[dict] = None,
):
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
raise ValueError("Only MP4 format is supported for now")
@@ -357,7 +426,10 @@ class VideoFromComponents(VideoInput):
audio_stream: Optional[av.AudioStream] = None
if self.__components.audio:
audio_sample_rate = int(self.__components.audio['sample_rate'])
- audio_stream = output.add_stream('aac', rate=audio_sample_rate)
+ waveform = self.__components.audio['waveform']
+ waveform = waveform[0, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])]
+ layout = {1: 'mono', 2: 'stereo', 6: '5.1'}.get(waveform.shape[0], 'stereo')
+ audio_stream = output.add_stream('aac', rate=audio_sample_rate, layout=layout)
# Encode video
for i, frame in enumerate(self.__components.images):
@@ -372,12 +444,21 @@ class VideoFromComponents(VideoInput):
output.mux(packet)
if audio_stream and self.__components.audio:
- waveform = self.__components.audio['waveform']
- waveform = waveform[:, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])]
- frame = av.AudioFrame.from_ndarray(waveform.movedim(2, 1).reshape(1, -1).float().cpu().numpy(), format='flt', layout='mono' if waveform.shape[1] == 1 else 'stereo')
+ frame = av.AudioFrame.from_ndarray(waveform.float().cpu().numpy(), format='fltp', layout=layout)
frame.sample_rate = audio_sample_rate
frame.pts = 0
output.mux(audio_stream.encode(frame))
# Flush encoder
output.mux(audio_stream.encode(None))
+
+    def as_trimmed(
+        self,
+        start_time: float | None = None,
+        duration: float | None = None,
+        strict_duration: bool = True,
+    ) -> VideoInput | None:
+        """
+        Return a trimmed view of this video as a VideoFromFile over its stream source.
+
+        Args:
+            start_time: Offset in seconds where the trimmed video starts; None means 0.
+            duration: Requested length in seconds; None means 0.
+            strict_duration: When true, return None if this video is shorter than
+                ``start_time + duration``.
+
+        Returns:
+            A new VideoFromFile, or None when strict_duration rejects the request.
+        """
+        # The signature allows None, but the arithmetic below and VideoFromFile
+        # both need numbers.
+        start_time = 0 if start_time is None else start_time
+        duration = 0 if duration is None else duration
+        # Only enforce the full requested window when the caller asked for it,
+        # mirroring VideoFromFile.as_trimmed.
+        if strict_duration and self.get_duration() < start_time + duration:
+            return None
+        #TODO Consider tracking duration and trimming at time of save?
+        return VideoFromFile(self.get_stream_source(), start_time=start_time, duration=duration)
diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py
index 93cf482ca..dee487c92 100644
--- a/comfy_api/latest/_io.py
+++ b/comfy_api/latest/_io.py
@@ -75,6 +75,12 @@ class NumberDisplay(str, Enum):
slider = "slider"
+class ControlAfterGenerate(str, Enum):
+ # String-valued modes accepted, as an alternative to a plain bool, by the
+ # ``control_after_generate`` argument of the widget inputs changed in this patch.
+ fixed = "fixed"
+ increment = "increment"
+ decrement = "decrement"
+ randomize = "randomize"
+
class _ComfyType(ABC):
Type = Any
io_type: str = None
@@ -263,7 +269,7 @@ class Int(ComfyTypeIO):
class Input(WidgetInput):
'''Integer input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
- default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None,
+ default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool | ControlAfterGenerate=None,
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced)
self.min = min
@@ -345,7 +351,7 @@ class Combo(ComfyTypeIO):
tooltip: str=None,
lazy: bool=None,
default: str | int | Enum = None,
- control_after_generate: bool=None,
+ control_after_generate: bool | ControlAfterGenerate=None,
upload: UploadType=None,
image_folder: FolderType=None,
remote: RemoteOptions=None,
@@ -389,7 +395,7 @@ class MultiCombo(ComfyTypeI):
Type = list[str]
class Input(Combo.Input):
def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
- default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None,
+ default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool | ControlAfterGenerate=None,
socketless: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless, extra_dict=extra_dict, raw_link=raw_link, advanced=advanced)
self.multiselect = True
@@ -1203,6 +1209,30 @@ class Color(ComfyTypeIO):
def as_dict(self):
return super().as_dict()
+@comfytype(io_type="BOUNDING_BOX")
+class BoundingBox(ComfyTypeIO):
+ # IO type "BOUNDING_BOX": a dict with integer x, y, width and height entries.
+ class BoundingBoxDict(TypedDict):
+ x: int
+ y: int
+ width: int
+ height: int
+ Type = BoundingBoxDict
+
+ class Input(WidgetInput):
+ # Widget input for a bounding box; ``component`` is an optional string that is
+ # serialized only when it is set.
+ def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
+ socketless: bool=True, default: dict=None, component: str=None):
+ # The fifth positional argument is passed as None
+ # (the ``lazy`` slot, judging by the sibling Int.Input call - TODO confirm).
+ super().__init__(id, display_name, optional, tooltip, None, default, socketless)
+ self.component = component
+ # Fall back to a 512x512 box at the origin when no default is supplied.
+ if default is None:
+ self.default = {"x": 0, "y": 0, "width": 512, "height": 512}
+
+ def as_dict(self):
+ # Extend the parent's dict with ``component`` only when it is truthy.
+ d = super().as_dict()
+ if self.component:
+ d["component"] = self.component
+ return d
+
+
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
DYNAMIC_INPUT_LOOKUP[io_type] = func
@@ -1309,6 +1339,7 @@ class NodeInfoV1:
api_node: bool=None
price_badge: dict | None = None
search_aliases: list[str]=None
+ essentials_category: str=None
@dataclass
@@ -1430,6 +1461,8 @@ class Schema:
"""Flags a node as expandable, allowing NodeOutput to include 'expand' property."""
accept_all_inputs: bool=False
"""When True, all inputs from the prompt will be passed to the node as kwargs, even if not defined in the schema."""
+ essentials_category: str | None = None
+ """Optional category for the Essentials tab. Path-based like category field (e.g., 'Basic', 'Image Tools/Editing')."""
def validate(self):
'''Validate the schema:
@@ -1536,6 +1569,7 @@ class Schema:
python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes"),
price_badge=self.price_badge.as_dict(self.inputs) if self.price_badge is not None else None,
search_aliases=self.search_aliases if self.search_aliases else None,
+ essentials_category=self.essentials_category,
)
return info
@@ -2030,11 +2064,74 @@ class _UIOutput(ABC):
...
+class InputMapOldId(TypedDict):
+ """Map an old node input to a new node input by ID."""
+ new_id: str
+ old_id: str
+
+class InputMapSetValue(TypedDict):
+ """Set a specific value for a new node input."""
+ new_id: str
+ set_value: Any
+
+InputMap = InputMapOldId | InputMapSetValue
+"""
+Input mapping for node replacement. Type is inferred by dictionary keys:
+- {"new_id": str, "old_id": str} - maps old input to new input
+- {"new_id": str, "set_value": Any} - sets a specific value for new input
+"""
+
+class OutputMap(TypedDict):
+ """Map outputs of node replacement via indexes."""
+ new_idx: int
+ old_idx: int
+
+class NodeReplace:
+    """
+    Describes how a deprecated node can be swapped for a newer node.
+
+    Carries the input and output mappings from the old node to the new one and
+    can also assign fixed values to input widgets of the new node.
+
+    Args:
+        new_node_id: The class name of the new replacement node.
+        old_node_id: The class name of the deprecated node.
+        old_widget_ids: Ordered list of input IDs for widgets that may not have an
+            input slot connected. Workflow JSON stores widget values by relative
+            position index rather than by ID; this list translates those positions
+            back to input IDs so widget values can be identified during migration.
+        input_mapping: List of input mappings from old node to new node.
+        output_mapping: List of output mappings from old node to new node.
+    """
+    def __init__(
+        self,
+        new_node_id: str,
+        old_node_id: str,
+        old_widget_ids: list[str] | None = None,
+        input_mapping: list[InputMap] | None = None,
+        output_mapping: list[OutputMap] | None = None,
+    ):
+        self.new_node_id = new_node_id
+        self.old_node_id = old_node_id
+        self.old_widget_ids = old_widget_ids
+        self.input_mapping = input_mapping
+        self.output_mapping = output_mapping
+
+    def as_dict(self):
+        """Create serializable representation of the node replacement."""
+        serialized = {
+            "new_node_id": self.new_node_id,
+            "old_node_id": self.old_node_id,
+            "old_widget_ids": self.old_widget_ids,
+        }
+        # Missing or empty mappings are reported as None, otherwise as a shallow list copy.
+        for field in ("input_mapping", "output_mapping"):
+            entries = getattr(self, field)
+            serialized[field] = list(entries) if entries else None
+        return serialized
+
+
__all__ = [
"FolderType",
"UploadType",
"RemoteOptions",
"NumberDisplay",
+ "ControlAfterGenerate",
"comfytype",
"Custom",
@@ -2121,4 +2218,6 @@ __all__ = [
"ImageCompare",
"PriceBadgeDepends",
"PriceBadge",
+ "BoundingBox",
+ "NodeReplace",
]
diff --git a/comfy_api_nodes/apis/__init__.py b/comfy_api_nodes/apis/__init__.py
index ee2aa1ce6..46a583b5e 100644
--- a/comfy_api_nodes/apis/__init__.py
+++ b/comfy_api_nodes/apis/__init__.py
@@ -1197,12 +1197,6 @@ class KlingImageGenImageReferenceType(str, Enum):
face = 'face'
-class KlingImageGenModelName(str, Enum):
- kling_v1 = 'kling-v1'
- kling_v1_5 = 'kling-v1-5'
- kling_v2 = 'kling-v2'
-
-
class KlingImageGenerationsRequest(BaseModel):
aspect_ratio: Optional[KlingImageGenAspectRatio] = '16:9'
callback_url: Optional[AnyUrl] = Field(
@@ -1218,7 +1212,7 @@ class KlingImageGenerationsRequest(BaseModel):
0.5, description='Reference intensity for user-uploaded images', ge=0.0, le=1.0
)
image_reference: Optional[KlingImageGenImageReferenceType] = None
- model_name: Optional[KlingImageGenModelName] = 'kling-v1'
+ model_name: str = Field(...)
n: Optional[int] = Field(1, description='Number of generated images', ge=1, le=9)
negative_prompt: Optional[str] = Field(
None, description='Negative text prompt', max_length=200
diff --git a/comfy_api_nodes/apis/bria.py b/comfy_api_nodes/apis/bria.py
index 9119cacc6..8c496b56c 100644
--- a/comfy_api_nodes/apis/bria.py
+++ b/comfy_api_nodes/apis/bria.py
@@ -45,17 +45,55 @@ class BriaEditImageRequest(BaseModel):
)
+class BriaRemoveBackgroundRequest(BaseModel):
+ image: str = Field(...)
+ sync: bool = Field(False)
+ visual_input_content_moderation: bool = Field(
+ False, description="If true, returns 422 on input image moderation failure."
+ )
+ visual_output_content_moderation: bool = Field(
+ False, description="If true, returns 422 on visual output moderation failure."
+ )
+ seed: int = Field(...)
+
+
class BriaStatusResponse(BaseModel):
request_id: str = Field(...)
status_url: str = Field(...)
warning: str | None = Field(None)
-class BriaResult(BaseModel):
+class BriaRemoveBackgroundResult(BaseModel):
+ image_url: str = Field(...)
+
+
+class BriaRemoveBackgroundResponse(BaseModel):
+ status: str = Field(...)
+ result: BriaRemoveBackgroundResult | None = Field(None)
+
+
+class BriaImageEditResult(BaseModel):
structured_prompt: str = Field(...)
image_url: str = Field(...)
-class BriaResponse(BaseModel):
+class BriaImageEditResponse(BaseModel):
status: str = Field(...)
- result: BriaResult | None = Field(None)
+ result: BriaImageEditResult | None = Field(None)
+
+
+class BriaRemoveVideoBackgroundRequest(BaseModel):
+ video: str = Field(...)
+ background_color: str = Field(default="transparent", description="Background color for the output video.")
+ output_container_and_codec: str = Field(...)
+ preserve_audio: bool = Field(True)
+ seed: int = Field(...)
+
+
+class BriaRemoveVideoBackgroundResult(BaseModel):
+ video_url: str = Field(...)
+
+
+class BriaRemoveVideoBackgroundResponse(BaseModel):
+ status: str = Field(...)
+ result: BriaRemoveVideoBackgroundResult | None = Field(None)
diff --git a/comfy_api_nodes/apis/gemini.py b/comfy_api_nodes/apis/gemini.py
index d81337dae..3304d7e76 100644
--- a/comfy_api_nodes/apis/gemini.py
+++ b/comfy_api_nodes/apis/gemini.py
@@ -116,9 +116,15 @@ class GeminiGenerationConfig(BaseModel):
topP: float | None = Field(None, ge=0.0, le=1.0)
+class GeminiImageOutputOptions(BaseModel):
+ mimeType: str = Field("image/png")
+ compressionQuality: int | None = Field(None)
+
+
class GeminiImageConfig(BaseModel):
aspectRatio: str | None = Field(None)
imageSize: str | None = Field(None)
+ imageOutputOptions: GeminiImageOutputOptions = Field(default_factory=GeminiImageOutputOptions)
class GeminiImageGenerationConfig(GeminiGenerationConfig):
diff --git a/comfy_api_nodes/apis/hunyuan3d.py b/comfy_api_nodes/apis/hunyuan3d.py
index 6421c9bd5..e84eba31e 100644
--- a/comfy_api_nodes/apis/hunyuan3d.py
+++ b/comfy_api_nodes/apis/hunyuan3d.py
@@ -64,3 +64,23 @@ class To3DProTaskResultResponse(BaseModel):
class To3DProTaskQueryRequest(BaseModel):
JobId: str = Field(...)
+
+
+class To3DUVFileInput(BaseModel):
+ Type: str = Field(..., description="File type: GLB, OBJ, or FBX")
+ Url: str = Field(...)
+
+
+class To3DUVTaskRequest(BaseModel):
+ File: To3DUVFileInput = Field(...)
+
+
+class TextureEditImageInfo(BaseModel):
+ Url: str = Field(...)
+
+
+class TextureEditTaskRequest(BaseModel):
+ File3D: To3DUVFileInput = Field(...)
+ Image: TextureEditImageInfo | None = Field(None)
+ Prompt: str | None = Field(None)
+ EnablePBR: bool | None = Field(None)
diff --git a/comfy_api_nodes/apis/kling.py b/comfy_api_nodes/apis/kling.py
index bf54ede3e..9c0446075 100644
--- a/comfy_api_nodes/apis/kling.py
+++ b/comfy_api_nodes/apis/kling.py
@@ -1,12 +1,22 @@
from pydantic import BaseModel, Field
+class MultiPromptEntry(BaseModel):
+ index: int = Field(...)
+ prompt: str = Field(...)
+ duration: str = Field(...)
+
+
class OmniProText2VideoRequest(BaseModel):
model_name: str = Field(..., description="kling-video-o1")
aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'")
duration: str = Field(..., description="'5' or '10'")
prompt: str = Field(...)
mode: str = Field("pro")
+ multi_shot: bool | None = Field(None)
+ multi_prompt: list[MultiPromptEntry] | None = Field(None)
+ shot_type: str | None = Field(None)
+ sound: str = Field(..., description="'on' or 'off'")
class OmniParamImage(BaseModel):
@@ -26,6 +36,10 @@ class OmniProFirstLastFrameRequest(BaseModel):
duration: str = Field(..., description="'5' or '10'")
prompt: str = Field(...)
mode: str = Field("pro")
+ sound: str | None = Field(None, description="'on' or 'off'")
+ multi_shot: bool | None = Field(None)
+ multi_prompt: list[MultiPromptEntry] | None = Field(None)
+ shot_type: str | None = Field(None)
class OmniProReferences2VideoRequest(BaseModel):
@@ -38,6 +52,10 @@ class OmniProReferences2VideoRequest(BaseModel):
duration: str | None = Field(..., description="From 3 to 10.")
prompt: str = Field(...)
mode: str = Field("pro")
+ sound: str | None = Field(None, description="'on' or 'off'")
+ multi_shot: bool | None = Field(None)
+ multi_prompt: list[MultiPromptEntry] | None = Field(None)
+ shot_type: str | None = Field(None)
class TaskStatusVideoResult(BaseModel):
@@ -54,6 +72,7 @@ class TaskStatusImageResult(BaseModel):
class TaskStatusResults(BaseModel):
videos: list[TaskStatusVideoResult] | None = Field(None)
images: list[TaskStatusImageResult] | None = Field(None)
+ series_images: list[TaskStatusImageResult] | None = Field(None)
class TaskStatusResponseData(BaseModel):
@@ -77,31 +96,42 @@ class OmniImageParamImage(BaseModel):
class OmniProImageRequest(BaseModel):
- model_name: str = Field(..., description="kling-image-o1")
- resolution: str = Field(..., description="'1k' or '2k'")
+ model_name: str = Field(...)
+ resolution: str = Field(...)
aspect_ratio: str | None = Field(...)
prompt: str = Field(...)
mode: str = Field("pro")
n: int | None = Field(1, le=9)
image_list: list[OmniImageParamImage] | None = Field(..., max_length=10)
+ result_type: str | None = Field(None, description="Set to 'series' for series generation")
+ series_amount: int | None = Field(None, ge=2, le=9, description="Number of images in a series")
class TextToVideoWithAudioRequest(BaseModel):
- model_name: str = Field(..., description="kling-v2-6")
+ model_name: str = Field(...)
aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'")
- duration: str = Field(..., description="'5' or '10'")
- prompt: str = Field(...)
+ duration: str = Field(...)
+ prompt: str | None = Field(...)
+ negative_prompt: str | None = Field(None)
mode: str = Field("pro")
sound: str = Field(..., description="'on' or 'off'")
+ multi_shot: bool | None = Field(None)
+ multi_prompt: list[MultiPromptEntry] | None = Field(None)
+ shot_type: str | None = Field(None)
class ImageToVideoWithAudioRequest(BaseModel):
- model_name: str = Field(..., description="kling-v2-6")
+ model_name: str = Field(...)
image: str = Field(...)
- duration: str = Field(..., description="'5' or '10'")
- prompt: str = Field(...)
+ image_tail: str | None = Field(None)
+ duration: str = Field(...)
+ prompt: str | None = Field(...)
+ negative_prompt: str | None = Field(None)
mode: str = Field("pro")
sound: str = Field(..., description="'on' or 'off'")
+ multi_shot: bool | None = Field(None)
+ multi_prompt: list[MultiPromptEntry] | None = Field(None)
+ shot_type: str | None = Field(None)
class MotionControlRequest(BaseModel):
diff --git a/comfy_api_nodes/apis/recraft.py b/comfy_api_nodes/apis/recraft.py
index 0bd7d23b3..78ededd94 100644
--- a/comfy_api_nodes/apis/recraft.py
+++ b/comfy_api_nodes/apis/recraft.py
@@ -198,11 +198,6 @@ dict_recraft_substyles_v3 = {
}
-class RecraftModel(str, Enum):
- recraftv3 = 'recraftv3'
- recraftv2 = 'recraftv2'
-
-
class RecraftImageSize(str, Enum):
res_1024x1024 = '1024x1024'
res_1365x1024 = '1365x1024'
@@ -221,6 +216,41 @@ class RecraftImageSize(str, Enum):
res_1707x1024 = '1707x1024'
+RECRAFT_V4_SIZES = [
+ "1024x1024",
+ "1536x768",
+ "768x1536",
+ "1280x832",
+ "832x1280",
+ "1216x896",
+ "896x1216",
+ "1152x896",
+ "896x1152",
+ "832x1344",
+ "1280x896",
+ "896x1280",
+ "1344x768",
+ "768x1344",
+]
+
+RECRAFT_V4_PRO_SIZES = [
+ "2048x2048",
+ "3072x1536",
+ "1536x3072",
+ "2560x1664",
+ "1664x2560",
+ "2432x1792",
+ "1792x2432",
+ "2304x1792",
+ "1792x2304",
+ "1664x2688",
+ "1434x1024",
+ "1024x1434",
+ "2560x1792",
+ "1792x2560",
+]
+
+
class RecraftColorObject(BaseModel):
rgb: list[int] = Field(..., description='An array of 3 integer values in range of 0...255 defining RGB Color Model')
@@ -234,17 +264,16 @@ class RecraftControlsObject(BaseModel):
class RecraftImageGenerationRequest(BaseModel):
prompt: str = Field(..., description='The text prompt describing the image to generate')
- size: RecraftImageSize | None = Field(None, description='The size of the generated image (e.g., "1024x1024")')
+ size: str | None = Field(None, description='The size of the generated image (e.g., "1024x1024")')
n: int = Field(..., description='The number of images to generate')
negative_prompt: str | None = Field(None, description='A text description of undesired elements on an image')
- model: RecraftModel | None = Field(RecraftModel.recraftv3, description='The model to use for generation (e.g., "recraftv3")')
+ model: str = Field(...)
style: str | None = Field(None, description='The style to apply to the generated image (e.g., "digital_illustration")')
substyle: str | None = Field(None, description='The substyle to apply to the generated image, depending on the style input')
controls: RecraftControlsObject | None = Field(None, description='A set of custom parameters to tweak generation process')
style_id: str | None = Field(None, description='Use a previously uploaded style as a reference; UUID')
strength: float | None = Field(None, description='Defines the difference with the original image, should lie in [0, 1], where 0 means almost identical, and 1 means miserable similarity')
random_seed: int | None = Field(None, description="Seed for video generation")
- # text_layout
class RecraftReturnedObject(BaseModel):
diff --git a/comfy_api_nodes/nodes_bria.py b/comfy_api_nodes/nodes_bria.py
index d3a52bc1b..4044ee3ea 100644
--- a/comfy_api_nodes/nodes_bria.py
+++ b/comfy_api_nodes/nodes_bria.py
@@ -3,7 +3,11 @@ from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.bria import (
BriaEditImageRequest,
- BriaResponse,
+ BriaRemoveBackgroundRequest,
+ BriaRemoveBackgroundResponse,
+ BriaRemoveVideoBackgroundRequest,
+ BriaRemoveVideoBackgroundResponse,
+ BriaImageEditResponse,
BriaStatusResponse,
InputModerationSettings,
)
@@ -11,10 +15,12 @@ from comfy_api_nodes.util import (
ApiEndpoint,
convert_mask_to_image,
download_url_to_image_tensor,
- get_number_of_images,
+ download_url_to_video_output,
poll_op,
sync_op,
- upload_images_to_comfyapi,
+ upload_image_to_comfyapi,
+ upload_video_to_comfyapi,
+ validate_video_duration,
)
@@ -73,21 +79,15 @@ class BriaImageEditNode(IO.ComfyNode):
IO.DynamicCombo.Input(
"moderation",
options=[
+ IO.DynamicCombo.Option("false", []),
IO.DynamicCombo.Option(
"true",
[
- IO.Boolean.Input(
- "prompt_content_moderation", default=False
- ),
- IO.Boolean.Input(
- "visual_input_moderation", default=False
- ),
- IO.Boolean.Input(
- "visual_output_moderation", default=True
- ),
+ IO.Boolean.Input("prompt_content_moderation", default=False),
+ IO.Boolean.Input("visual_input_moderation", default=False),
+ IO.Boolean.Input("visual_output_moderation", default=True),
],
),
- IO.DynamicCombo.Option("false", []),
],
tooltip="Moderation settings",
),
@@ -127,50 +127,26 @@ class BriaImageEditNode(IO.ComfyNode):
mask: Input.Image | None = None,
) -> IO.NodeOutput:
if not prompt and not structured_prompt:
- raise ValueError(
- "One of prompt or structured_prompt is required to be non-empty."
- )
- if get_number_of_images(image) != 1:
- raise ValueError("Exactly one input image is required.")
+ raise ValueError("One of prompt or structured_prompt is required to be non-empty.")
mask_url = None
if mask is not None:
- mask_url = (
- await upload_images_to_comfyapi(
- cls,
- convert_mask_to_image(mask),
- max_images=1,
- mime_type="image/png",
- wait_label="Uploading mask",
- )
- )[0]
+ mask_url = await upload_image_to_comfyapi(cls, convert_mask_to_image(mask), wait_label="Uploading mask")
response = await sync_op(
cls,
ApiEndpoint(path="proxy/bria/v2/image/edit", method="POST"),
data=BriaEditImageRequest(
instruction=prompt if prompt else None,
structured_instruction=structured_prompt if structured_prompt else None,
- images=await upload_images_to_comfyapi(
- cls,
- image,
- max_images=1,
- mime_type="image/png",
- wait_label="Uploading image",
- ),
+ images=[await upload_image_to_comfyapi(cls, image, wait_label="Uploading image")],
mask=mask_url,
negative_prompt=negative_prompt if negative_prompt else None,
guidance_scale=guidance_scale,
seed=seed,
model_version=model,
steps_num=steps,
- prompt_content_moderation=moderation.get(
- "prompt_content_moderation", False
- ),
- visual_input_content_moderation=moderation.get(
- "visual_input_moderation", False
- ),
- visual_output_content_moderation=moderation.get(
- "visual_output_moderation", False
- ),
+ prompt_content_moderation=moderation.get("prompt_content_moderation", False),
+ visual_input_content_moderation=moderation.get("visual_input_moderation", False),
+ visual_output_content_moderation=moderation.get("visual_output_moderation", False),
),
response_model=BriaStatusResponse,
)
@@ -178,7 +154,7 @@ class BriaImageEditNode(IO.ComfyNode):
cls,
ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
status_extractor=lambda r: r.status,
- response_model=BriaResponse,
+ response_model=BriaImageEditResponse,
)
return IO.NodeOutput(
await download_url_to_image_tensor(response.result.image_url),
@@ -186,11 +162,167 @@ class BriaImageEditNode(IO.ComfyNode):
)
+class BriaRemoveImageBackground(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="BriaRemoveImageBackground",
+ display_name="Bria Remove Image Background",
+ category="api node/image/Bria",
+ description="Remove the background from an image using Bria RMBG 2.0.",
+ inputs=[
+ IO.Image.Input("image"),
+ IO.DynamicCombo.Input(
+ "moderation",
+ options=[
+ IO.DynamicCombo.Option("false", []),
+ IO.DynamicCombo.Option(
+ "true",
+ [
+ IO.Boolean.Input("visual_input_moderation", default=False),
+ IO.Boolean.Input("visual_output_moderation", default=True),
+ ],
+ ),
+ ],
+ tooltip="Moderation settings",
+ ),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=2147483647,
+ display_mode=IO.NumberDisplay.number,
+ control_after_generate=True,
+ tooltip="Seed controls whether the node should re-run; "
+ "results are non-deterministic regardless of seed.",
+ ),
+ ],
+ outputs=[IO.Image.Output()],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ price_badge=IO.PriceBadge(
+ expr="""{"type":"usd","usd":0.018}""",
+ ),
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ image: Input.Image,
+ moderation: dict,
+ seed: int,
+ ) -> IO.NodeOutput:
+ response = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/bria/v2/image/edit/remove_background", method="POST"),
+ data=BriaRemoveBackgroundRequest(
+ image=await upload_image_to_comfyapi(cls, image, wait_label="Uploading image"),
+ sync=False,
+ visual_input_content_moderation=moderation.get("visual_input_moderation", False),
+ visual_output_content_moderation=moderation.get("visual_output_moderation", False),
+ seed=seed,
+ ),
+ response_model=BriaStatusResponse,
+ )
+ response = await poll_op(
+ cls,
+ ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
+ status_extractor=lambda r: r.status,
+ response_model=BriaRemoveBackgroundResponse,
+ )
+ return IO.NodeOutput(await download_url_to_image_tensor(response.result.image_url))
+
+
+class BriaRemoveVideoBackground(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="BriaRemoveVideoBackground",
+ display_name="Bria Remove Video Background",
+ category="api node/video/Bria",
+ description="Remove the background from a video using Bria. ",
+ inputs=[
+ IO.Video.Input("video"),
+ IO.Combo.Input(
+ "background_color",
+ options=[
+ "Black",
+ "White",
+ "Gray",
+ "Red",
+ "Green",
+ "Blue",
+ "Yellow",
+ "Cyan",
+ "Magenta",
+ "Orange",
+ ],
+ tooltip="Background color for the output video.",
+ ),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=2147483647,
+ display_mode=IO.NumberDisplay.number,
+ control_after_generate=True,
+ tooltip="Seed controls whether the node should re-run; "
+ "results are non-deterministic regardless of seed.",
+ ),
+ ],
+ outputs=[IO.Video.Output()],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ price_badge=IO.PriceBadge(
+ expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""",
+ ),
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ video: Input.Video,
+ background_color: str,
+ seed: int,
+ ) -> IO.NodeOutput:
+ validate_video_duration(video, max_duration=60.0)
+ response = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/bria/v2/video/edit/remove_background", method="POST"),
+ data=BriaRemoveVideoBackgroundRequest(
+ video=await upload_video_to_comfyapi(cls, video),
+ background_color=background_color,
+ output_container_and_codec="mp4_h264",
+ seed=seed,
+ ),
+ response_model=BriaStatusResponse,
+ )
+ response = await poll_op(
+ cls,
+ ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
+ status_extractor=lambda r: r.status,
+ response_model=BriaRemoveVideoBackgroundResponse,
+ )
+ return IO.NodeOutput(await download_url_to_video_output(response.result.video_url))
+
+
class BriaExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
BriaImageEditNode,
+ BriaRemoveImageBackground,
+ BriaRemoveVideoBackground,
]
diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py
index e85766259..b69285be5 100644
--- a/comfy_api_nodes/nodes_gemini.py
+++ b/comfy_api_nodes/nodes_gemini.py
@@ -6,6 +6,7 @@ See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/infer
import base64
import os
from enum import Enum
+from fnmatch import fnmatch
from io import BytesIO
from typing import Literal
@@ -119,6 +120,13 @@ async def create_image_parts(
return image_parts
+def _mime_matches(mime: GeminiMimeType | None, pattern: str) -> bool:
+ """Check if a MIME type matches a pattern. Supports fnmatch globs (e.g. 'image/*')."""
+ if mime is None:
+ return False
+ return fnmatch(mime.value, pattern)
+
+
def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Literal["text"] | str) -> list[GeminiPart]:
"""
Filter response parts by their type.
@@ -151,9 +159,9 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera
for part in candidate.content.parts:
if part_type == "text" and part.text:
parts.append(part)
- elif part.inlineData and part.inlineData.mimeType == part_type:
+ elif part.inlineData and _mime_matches(part.inlineData.mimeType, part_type):
parts.append(part)
- elif part.fileData and part.fileData.mimeType == part_type:
+ elif part.fileData and _mime_matches(part.fileData.mimeType, part_type):
parts.append(part)
if not parts and blocked_reasons:
@@ -178,7 +186,7 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str:
async def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image:
image_tensors: list[Input.Image] = []
- parts = get_parts_by_type(response, "image/png")
+ parts = get_parts_by_type(response, "image/*")
for part in parts:
if part.inlineData:
image_data = base64.b64decode(part.inlineData.data)
@@ -629,7 +637,7 @@ class GeminiImage(IO.ComfyNode):
if not aspect_ratio:
aspect_ratio = "auto" # for backward compatability with old workflows; to-do remove this in December
- image_config = GeminiImageConfig(aspectRatio=aspect_ratio)
+ image_config = GeminiImageConfig() if aspect_ratio == "auto" else GeminiImageConfig(aspectRatio=aspect_ratio)
if images is not None:
parts.extend(await create_image_parts(cls, images))
@@ -649,7 +657,7 @@ class GeminiImage(IO.ComfyNode):
],
generationConfig=GeminiImageGenerationConfig(
responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
- imageConfig=None if aspect_ratio == "auto" else image_config,
+ imageConfig=image_config,
),
systemInstruction=gemini_system_prompt,
),
diff --git a/comfy_api_nodes/nodes_hunyuan3d.py b/comfy_api_nodes/nodes_hunyuan3d.py
index 813a7c809..d1d9578ec 100644
--- a/comfy_api_nodes/nodes_hunyuan3d.py
+++ b/comfy_api_nodes/nodes_hunyuan3d.py
@@ -1,31 +1,48 @@
from typing_extensions import override
-from comfy_api.latest import IO, ComfyExtension, Input
+from comfy_api.latest import IO, ComfyExtension, Input, Types
from comfy_api_nodes.apis.hunyuan3d import (
Hunyuan3DViewImage,
InputGenerateType,
ResultFile3D,
+ TextureEditTaskRequest,
To3DProTaskCreateResponse,
To3DProTaskQueryRequest,
To3DProTaskRequest,
To3DProTaskResultResponse,
+ To3DUVFileInput,
+ To3DUVTaskRequest,
)
from comfy_api_nodes.util import (
ApiEndpoint,
download_url_to_file_3d,
+ download_url_to_image_tensor,
downscale_image_tensor_by_max_side,
poll_op,
sync_op,
+ upload_3d_model_to_comfyapi,
upload_image_to_comfyapi,
validate_image_dimensions,
validate_string,
)
-def get_file_from_response(response_objs: list[ResultFile3D], file_type: str) -> ResultFile3D | None:
+def _is_tencent_rate_limited(status: int, body: object) -> bool:
+ return (
+ status == 400
+ and isinstance(body, dict)
+ and "RequestLimitExceeded" in str(body.get("Response", {}).get("Error", {}).get("Code", ""))
+ )
+
+
+def get_file_from_response(
+ response_objs: list[ResultFile3D], file_type: str, raise_if_not_found: bool = True
+) -> ResultFile3D | None:
for i in response_objs:
if i.Type.lower() == file_type.lower():
return i
+ if raise_if_not_found:
+ raise ValueError(f"'{file_type}' file type is not found in the response.")
return None
@@ -35,8 +52,9 @@ class TencentTextToModelNode(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="TencentTextToModelNode",
- display_name="Hunyuan3D: Text to Model (Pro)",
+ display_name="Hunyuan3D: Text to Model",
category="api node/3d/Tencent",
+ essentials_category="3D",
inputs=[
IO.Combo.Input(
"model",
@@ -120,6 +138,7 @@ class TencentTextToModelNode(IO.ComfyNode):
EnablePBR=generate_type.get("pbr", None),
PolygonType=generate_type.get("polygon_type", None),
),
+ is_rate_limited=_is_tencent_rate_limited,
)
if response.Error:
raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
@@ -131,11 +150,14 @@ class TencentTextToModelNode(IO.ComfyNode):
response_model=To3DProTaskResultResponse,
status_extractor=lambda r: r.Status,
)
- glb_result = get_file_from_response(result.ResultFile3Ds, "glb")
- obj_result = get_file_from_response(result.ResultFile3Ds, "obj")
- file_glb = await download_url_to_file_3d(glb_result.Url, "glb", task_id=task_id) if glb_result else None
return IO.NodeOutput(
- file_glb, file_glb, await download_url_to_file_3d(obj_result.Url, "obj", task_id=task_id) if obj_result else None
+ f"{task_id}.glb",
+ await download_url_to_file_3d(
+ get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
+ ),
+ await download_url_to_file_3d(
+ get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj", task_id=task_id
+ ),
)
@@ -145,8 +167,9 @@ class TencentImageToModelNode(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="TencentImageToModelNode",
- display_name="Hunyuan3D: Image(s) to Model (Pro)",
+ display_name="Hunyuan3D: Image(s) to Model",
category="api node/3d/Tencent",
+ essentials_category="3D",
inputs=[
IO.Combo.Input(
"model",
@@ -268,6 +291,7 @@ class TencentImageToModelNode(IO.ComfyNode):
EnablePBR=generate_type.get("pbr", None),
PolygonType=generate_type.get("polygon_type", None),
),
+ is_rate_limited=_is_tencent_rate_limited,
)
if response.Error:
raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
@@ -279,11 +303,257 @@ class TencentImageToModelNode(IO.ComfyNode):
response_model=To3DProTaskResultResponse,
status_extractor=lambda r: r.Status,
)
- glb_result = get_file_from_response(result.ResultFile3Ds, "glb")
- obj_result = get_file_from_response(result.ResultFile3Ds, "obj")
- file_glb = await download_url_to_file_3d(glb_result.Url, "glb", task_id=task_id) if glb_result else None
return IO.NodeOutput(
- file_glb, file_glb, await download_url_to_file_3d(obj_result.Url, "obj", task_id=task_id) if obj_result else None
+ f"{task_id}.glb",
+ await download_url_to_file_3d(
+ get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
+ ),
+ await download_url_to_file_3d(
+ get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj", task_id=task_id
+ ),
+ )
+
+
+class TencentModelTo3DUVNode(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="TencentModelTo3DUVNode",
+ display_name="Hunyuan3D: Model to UV",
+ category="api node/3d/Tencent",
+ description="Perform UV unfolding on a 3D model to generate UV texture. "
+ "Input model must have less than 30000 faces.",
+ inputs=[
+ IO.MultiType.Input(
+ "model_3d",
+ types=[IO.File3DGLB, IO.File3DOBJ, IO.File3DFBX, IO.File3DAny],
+ tooltip="Input 3D model (GLB, OBJ, or FBX)",
+ ),
+ IO.Int.Input(
+ "seed",
+ default=1,
+ min=0,
+ max=2147483647,
+ display_mode=IO.NumberDisplay.number,
+ control_after_generate=True,
+ tooltip="Seed controls whether the node should re-run; "
+ "results are non-deterministic regardless of seed.",
+ ),
+ ],
+ outputs=[
+ IO.File3DOBJ.Output(display_name="OBJ"),
+ IO.File3DFBX.Output(display_name="FBX"),
+ IO.Image.Output(),
+ ],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ price_badge=IO.PriceBadge(expr='{"type":"usd","usd":0.2}'),
+ )
+
+ SUPPORTED_FORMATS = {"glb", "obj", "fbx"}
+
+ @classmethod
+ async def execute(
+ cls,
+ model_3d: Types.File3D,
+ seed: int,
+ ) -> IO.NodeOutput:
+ _ = seed
+ file_format = model_3d.format.lower()
+ if file_format not in cls.SUPPORTED_FORMATS:
+ raise ValueError(
+ f"Unsupported file format: '{file_format}'. "
+ f"Supported formats: {', '.join(sorted(cls.SUPPORTED_FORMATS))}."
+ )
+ response = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/tencent/hunyuan/3d-uv", method="POST"),
+ response_model=To3DProTaskCreateResponse,
+ data=To3DUVTaskRequest(
+ File=To3DUVFileInput(
+ Type=file_format.upper(),
+ Url=await upload_3d_model_to_comfyapi(cls, model_3d, file_format),
+ )
+ ),
+ is_rate_limited=_is_tencent_rate_limited,
+ )
+ if response.Error:
+ raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
+ result = await poll_op(
+ cls,
+ ApiEndpoint(path="/proxy/tencent/hunyuan/3d-uv/query", method="POST"),
+ data=To3DProTaskQueryRequest(JobId=response.JobId),
+ response_model=To3DProTaskResultResponse,
+ status_extractor=lambda r: r.Status,
+ )
+ return IO.NodeOutput(
+ await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj"),
+ await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"),
+ await download_url_to_image_tensor(get_file_from_response(result.ResultFile3Ds, "image").Url),
+ )
+
+
+class Tencent3DTextureEditNode(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="Tencent3DTextureEditNode",
+ display_name="Hunyuan3D: 3D Texture Edit",
+ category="api node/3d/Tencent",
+ description="After inputting the 3D model, perform 3D model texture redrawing.",
+ inputs=[
+ IO.MultiType.Input(
+ "model_3d",
+ types=[IO.File3DFBX, IO.File3DAny],
+ tooltip="3D model in FBX format. Model should have less than 100000 faces.",
+ ),
+ IO.String.Input(
+ "prompt",
+ multiline=True,
+ default="",
+ tooltip="Describes texture editing. Supports up to 1024 UTF-8 characters.",
+ ),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=2147483647,
+ display_mode=IO.NumberDisplay.number,
+ control_after_generate=True,
+ tooltip="Seed controls whether the node should re-run; "
+ "results are non-deterministic regardless of seed.",
+ ),
+ ],
+ outputs=[
+ IO.File3DGLB.Output(display_name="GLB"),
+ IO.File3DFBX.Output(display_name="FBX"),
+ ],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ price_badge=IO.PriceBadge(
+ expr="""{"type":"usd","usd": 0.6}""",
+ ),
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ model_3d: Types.File3D,
+ prompt: str,
+ seed: int,
+ ) -> IO.NodeOutput:
+ _ = seed
+ file_format = model_3d.format.lower()
+ if file_format != "fbx":
+ raise ValueError(f"Unsupported file format: '{file_format}'. Only FBX format is supported.")
+ validate_string(prompt, field_name="prompt", min_length=1, max_length=1024)
+ model_url = await upload_3d_model_to_comfyapi(cls, model_3d, file_format)
+ response = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/tencent/hunyuan/3d-texture-edit", method="POST"),
+ response_model=To3DProTaskCreateResponse,
+ data=TextureEditTaskRequest(
+ File3D=To3DUVFileInput(Type=file_format.upper(), Url=model_url),
+ Prompt=prompt,
+ EnablePBR=True,
+ ),
+ is_rate_limited=_is_tencent_rate_limited,
+ )
+ if response.Error:
+ raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
+
+ result = await poll_op(
+ cls,
+ ApiEndpoint(path="/proxy/tencent/hunyuan/3d-texture-edit/query", method="POST"),
+ data=To3DProTaskQueryRequest(JobId=response.JobId),
+ response_model=To3DProTaskResultResponse,
+ status_extractor=lambda r: r.Status,
+ )
+ return IO.NodeOutput(
+ await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb"),
+ await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"),
+ )
+
+
+class Tencent3DPartNode(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="Tencent3DPartNode",
+ display_name="Hunyuan3D: 3D Part",
+ category="api node/3d/Tencent",
+ description="Automatically perform component identification and generation based on the model structure.",
+ inputs=[
+ IO.MultiType.Input(
+ "model_3d",
+ types=[IO.File3DFBX, IO.File3DAny],
+ tooltip="3D model in FBX format. Model should have less than 30000 faces.",
+ ),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=2147483647,
+ display_mode=IO.NumberDisplay.number,
+ control_after_generate=True,
+ tooltip="Seed controls whether the node should re-run; "
+ "results are non-deterministic regardless of seed.",
+ ),
+ ],
+ outputs=[
+ IO.File3DFBX.Output(display_name="FBX"),
+ ],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ price_badge=IO.PriceBadge(expr='{"type":"usd","usd":0.6}'),
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ model_3d: Types.File3D,
+ seed: int,
+ ) -> IO.NodeOutput:
+ _ = seed
+ file_format = model_3d.format.lower()
+ if file_format != "fbx":
+ raise ValueError(f"Unsupported file format: '{file_format}'. Only FBX format is supported.")
+ model_url = await upload_3d_model_to_comfyapi(cls, model_3d, file_format)
+ response = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/tencent/hunyuan/3d-part", method="POST"),
+ response_model=To3DProTaskCreateResponse,
+ data=To3DUVTaskRequest(
+ File=To3DUVFileInput(Type=file_format.upper(), Url=model_url),
+ ),
+ is_rate_limited=_is_tencent_rate_limited,
+ )
+ if response.Error:
+ raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
+ result = await poll_op(
+ cls,
+ ApiEndpoint(path="/proxy/tencent/hunyuan/3d-part/query", method="POST"),
+ data=To3DProTaskQueryRequest(JobId=response.JobId),
+ response_model=To3DProTaskResultResponse,
+ status_extractor=lambda r: r.Status,
+ )
+ return IO.NodeOutput(
+ await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"),
)
@@ -293,6 +563,9 @@ class TencentHunyuan3DExtension(ComfyExtension):
return [
TencentTextToModelNode,
TencentImageToModelNode,
+ # TencentModelTo3DUVNode,
+ # Tencent3DTextureEditNode,
+ Tencent3DPartNode,
]
diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py
index f3995a342..fa0f9e87c 100644
--- a/comfy_api_nodes/nodes_kling.py
+++ b/comfy_api_nodes/nodes_kling.py
@@ -38,7 +38,6 @@ from comfy_api_nodes.apis import (
KlingImageGenerationsRequest,
KlingImageGenerationsResponse,
KlingImageGenImageReferenceType,
- KlingImageGenModelName,
KlingImageGenAspectRatio,
KlingVideoEffectsRequest,
KlingVideoEffectsResponse,
@@ -52,6 +51,7 @@ from comfy_api_nodes.apis import (
from comfy_api_nodes.apis.kling import (
ImageToVideoWithAudioRequest,
MotionControlRequest,
+ MultiPromptEntry,
OmniImageParamImage,
OmniParamImage,
OmniParamVideo,
@@ -71,6 +71,7 @@ from comfy_api_nodes.util import (
sync_op,
tensor_to_base64_string,
upload_audio_to_comfyapi,
+ upload_image_to_comfyapi,
upload_images_to_comfyapi,
upload_video_to_comfyapi,
validate_image_aspect_ratio,
@@ -80,6 +81,31 @@ from comfy_api_nodes.util import (
validate_video_duration,
)
+
+def _generate_storyboard_inputs(count: int) -> list:
+ inputs = []
+ for i in range(1, count + 1):
+ inputs.extend(
+ [
+ IO.String.Input(
+ f"storyboard_{i}_prompt",
+ multiline=True,
+ default="",
+ tooltip=f"Prompt for storyboard segment {i}. Max 512 characters.",
+ ),
+ IO.Int.Input(
+ f"storyboard_{i}_duration",
+ default=4,
+ min=1,
+ max=15,
+ display_mode=IO.NumberDisplay.slider,
+ tooltip=f"Duration for storyboard segment {i} in seconds.",
+ ),
+ ]
+ )
+ return inputs
+
+
KLING_API_VERSION = "v1"
PATH_TEXT_TO_VIDEO = f"/proxy/kling/{KLING_API_VERSION}/videos/text2video"
PATH_IMAGE_TO_VIDEO = f"/proxy/kling/{KLING_API_VERSION}/videos/image2video"
@@ -820,20 +846,48 @@ class OmniProTextToVideoNode(IO.ComfyNode):
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingOmniProTextToVideoNode",
- display_name="Kling Omni Text to Video (Pro)",
+ display_name="Kling 3.0 Omni Text to Video",
category="api node/video/Kling",
description="Use text prompts to generate videos with the latest Kling model.",
inputs=[
- IO.Combo.Input("model_name", options=["kling-video-o1"]),
+ IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),
IO.String.Input(
"prompt",
multiline=True,
tooltip="A text prompt describing the video content. "
- "This can include both positive and negative descriptions.",
+ "This can include both positive and negative descriptions. "
+ "Ignored when storyboards are enabled.",
),
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
- IO.Combo.Input("duration", options=[5, 10]),
+ IO.Int.Input("duration", default=5, min=3, max=15, display_mode=IO.NumberDisplay.slider),
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
+ IO.DynamicCombo.Input(
+ "storyboards",
+ options=[
+ IO.DynamicCombo.Option("disabled", []),
+ IO.DynamicCombo.Option("1 storyboard", _generate_storyboard_inputs(1)),
+ IO.DynamicCombo.Option("2 storyboards", _generate_storyboard_inputs(2)),
+ IO.DynamicCombo.Option("3 storyboards", _generate_storyboard_inputs(3)),
+ IO.DynamicCombo.Option("4 storyboards", _generate_storyboard_inputs(4)),
+ IO.DynamicCombo.Option("5 storyboards", _generate_storyboard_inputs(5)),
+ IO.DynamicCombo.Option("6 storyboards", _generate_storyboard_inputs(6)),
+ ],
+ tooltip="Generate a series of video segments with individual prompts and durations. "
+ "Ignored for o1 model.",
+ optional=True,
+ ),
+ IO.Boolean.Input("generate_audio", default=False, optional=True),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=2147483647,
+ display_mode=IO.NumberDisplay.number,
+ control_after_generate=True,
+ tooltip="Seed controls whether the node should re-run; "
+ "results are non-deterministic regardless of seed.",
+ optional=True,
+ ),
],
outputs=[
IO.Video.Output(),
@@ -845,11 +899,15 @@ class OmniProTextToVideoNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
- depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]),
+ depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution", "model_name", "generate_audio"]),
expr="""
(
$mode := (widgets.resolution = "720p") ? "std" : "pro";
- $rates := {"std": 0.084, "pro": 0.112};
+ $isV3 := $contains(widgets.model_name, "v3");
+ $audio := $isV3 and widgets.generate_audio;
+ $rates := $audio
+ ? {"std": 0.112, "pro": 0.14}
+ : {"std": 0.084, "pro": 0.112};
{"type":"usd","usd": $lookup($rates, $mode) * widgets.duration}
)
""",
@@ -864,8 +922,45 @@ class OmniProTextToVideoNode(IO.ComfyNode):
aspect_ratio: str,
duration: int,
resolution: str = "1080p",
+ storyboards: dict | None = None,
+ generate_audio: bool = False,
+ seed: int = 0,
) -> IO.NodeOutput:
- validate_string(prompt, min_length=1, max_length=2500)
+ _ = seed
+ if model_name == "kling-video-o1":
+ if duration not in (5, 10):
+ raise ValueError("kling-video-o1 only supports durations of 5 or 10 seconds.")
+ if generate_audio:
+ raise ValueError("kling-video-o1 does not support audio generation.")
+ stories_enabled = storyboards is not None and storyboards["storyboards"] != "disabled"
+ if stories_enabled and model_name == "kling-video-o1":
+ raise ValueError("kling-video-o1 does not support storyboards.")
+ validate_string(prompt, strip_whitespace=True, min_length=0 if stories_enabled else 1, max_length=2500)
+
+ multi_shot = None
+ multi_prompt_list = None
+ if stories_enabled:
+ count = int(storyboards["storyboards"].split()[0])
+ multi_shot = True
+ multi_prompt_list = []
+ for i in range(1, count + 1):
+ sb_prompt = storyboards[f"storyboard_{i}_prompt"]
+ sb_duration = storyboards[f"storyboard_{i}_duration"]
+ validate_string(sb_prompt, field_name=f"storyboard_{i}_prompt", min_length=1, max_length=512)
+ multi_prompt_list.append(
+ MultiPromptEntry(
+ index=i,
+ prompt=sb_prompt,
+ duration=str(sb_duration),
+ )
+ )
+ total_storyboard_duration = sum(int(e.duration) for e in multi_prompt_list)
+ if total_storyboard_duration != duration:
+ raise ValueError(
+ f"Total storyboard duration ({total_storyboard_duration}s) "
+ f"must equal the global duration ({duration}s)."
+ )
+
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
@@ -876,6 +971,10 @@ class OmniProTextToVideoNode(IO.ComfyNode):
aspect_ratio=aspect_ratio,
duration=str(duration),
mode="pro" if resolution == "1080p" else "std",
+ multi_shot=multi_shot,
+ multi_prompt=multi_prompt_list,
+ shot_type="customize" if multi_shot else None,
+ sound="on" if generate_audio else "off",
),
)
return await finish_omni_video_task(cls, response)
@@ -887,24 +986,26 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingOmniProFirstLastFrameNode",
- display_name="Kling Omni First-Last-Frame to Video (Pro)",
+ display_name="Kling 3.0 Omni First-Last-Frame to Video",
category="api node/video/Kling",
description="Use a start frame, an optional end frame, or reference images with the latest Kling model.",
inputs=[
- IO.Combo.Input("model_name", options=["kling-video-o1"]),
+ IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),
IO.String.Input(
"prompt",
multiline=True,
tooltip="A text prompt describing the video content. "
- "This can include both positive and negative descriptions.",
+ "This can include both positive and negative descriptions. "
+ "Ignored when storyboards are enabled.",
),
- IO.Int.Input("duration", default=5, min=3, max=10, display_mode=IO.NumberDisplay.slider),
+ IO.Int.Input("duration", default=5, min=3, max=15, display_mode=IO.NumberDisplay.slider),
IO.Image.Input("first_frame"),
IO.Image.Input(
"end_frame",
optional=True,
tooltip="An optional end frame for the video. "
- "This cannot be used simultaneously with 'reference_images'.",
+ "This cannot be used simultaneously with 'reference_images'. "
+ "Does not work with storyboards.",
),
IO.Image.Input(
"reference_images",
@@ -912,6 +1013,38 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
tooltip="Up to 6 additional reference images.",
),
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
+ IO.DynamicCombo.Input(
+ "storyboards",
+ options=[
+ IO.DynamicCombo.Option("disabled", []),
+ IO.DynamicCombo.Option("1 storyboard", _generate_storyboard_inputs(1)),
+ IO.DynamicCombo.Option("2 storyboards", _generate_storyboard_inputs(2)),
+ IO.DynamicCombo.Option("3 storyboards", _generate_storyboard_inputs(3)),
+ IO.DynamicCombo.Option("4 storyboards", _generate_storyboard_inputs(4)),
+ IO.DynamicCombo.Option("5 storyboards", _generate_storyboard_inputs(5)),
+ IO.DynamicCombo.Option("6 storyboards", _generate_storyboard_inputs(6)),
+ ],
+ tooltip="Generate a series of video segments with individual prompts and durations. "
+ "Only supported for kling-v3-omni.",
+ optional=True,
+ ),
+ IO.Boolean.Input(
+ "generate_audio",
+ default=False,
+ optional=True,
+ tooltip="Generate audio for the video. Only supported for kling-v3-omni.",
+ ),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=2147483647,
+ display_mode=IO.NumberDisplay.number,
+ control_after_generate=True,
+ tooltip="Seed controls whether the node should re-run; "
+ "results are non-deterministic regardless of seed.",
+ optional=True,
+ ),
],
outputs=[
IO.Video.Output(),
@@ -923,11 +1056,15 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
- depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]),
+ depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution", "model_name", "generate_audio"]),
expr="""
(
$mode := (widgets.resolution = "720p") ? "std" : "pro";
- $rates := {"std": 0.084, "pro": 0.112};
+ $isV3 := $contains(widgets.model_name, "v3");
+ $audio := $isV3 and widgets.generate_audio;
+ $rates := $audio
+ ? {"std": 0.112, "pro": 0.14}
+ : {"std": 0.084, "pro": 0.112};
{"type":"usd","usd": $lookup($rates, $mode) * widgets.duration}
)
""",
@@ -944,15 +1081,59 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
end_frame: Input.Image | None = None,
reference_images: Input.Image | None = None,
resolution: str = "1080p",
+ storyboards: dict | None = None,
+ generate_audio: bool = False,
+ seed: int = 0,
) -> IO.NodeOutput:
+ _ = seed
+ if model_name == "kling-video-o1":
+ if duration > 10:
+ raise ValueError("kling-video-o1 does not support durations greater than 10 seconds.")
+ if generate_audio:
+ raise ValueError("kling-video-o1 does not support audio generation.")
+ stories_enabled = storyboards is not None and storyboards["storyboards"] != "disabled"
+ if stories_enabled and model_name == "kling-video-o1":
+ raise ValueError("kling-video-o1 does not support storyboards.")
prompt = normalize_omni_prompt_references(prompt)
- validate_string(prompt, min_length=1, max_length=2500)
+ validate_string(prompt, strip_whitespace=True, min_length=0 if stories_enabled else 1, max_length=2500)
if end_frame is not None and reference_images is not None:
raise ValueError("The 'end_frame' input cannot be used simultaneously with 'reference_images'.")
- if duration not in (5, 10) and end_frame is None and reference_images is None:
+ if end_frame is not None and stories_enabled:
+ raise ValueError("The 'end_frame' input cannot be used simultaneously with storyboards.")
+ if (
+ model_name == "kling-video-o1"
+ and duration not in (5, 10)
+ and end_frame is None
+ and reference_images is None
+ ):
raise ValueError(
"Duration is only supported for 5 or 10 seconds if there is no end frame or reference images."
)
+
+ multi_shot = None
+ multi_prompt_list = None
+ if stories_enabled:
+ count = int(storyboards["storyboards"].split()[0])
+ multi_shot = True
+ multi_prompt_list = []
+ for i in range(1, count + 1):
+ sb_prompt = storyboards[f"storyboard_{i}_prompt"]
+ sb_duration = storyboards[f"storyboard_{i}_duration"]
+ validate_string(sb_prompt, field_name=f"storyboard_{i}_prompt", min_length=1, max_length=512)
+ multi_prompt_list.append(
+ MultiPromptEntry(
+ index=i,
+ prompt=sb_prompt,
+ duration=str(sb_duration),
+ )
+ )
+ total_storyboard_duration = sum(int(e.duration) for e in multi_prompt_list)
+ if total_storyboard_duration != duration:
+ raise ValueError(
+ f"Total storyboard duration ({total_storyboard_duration}s) "
+ f"must equal the global duration ({duration}s)."
+ )
+
validate_image_dimensions(first_frame, min_width=300, min_height=300)
validate_image_aspect_ratio(first_frame, (1, 2.5), (2.5, 1))
image_list: list[OmniParamImage] = [
@@ -988,6 +1169,10 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
duration=str(duration),
image_list=image_list,
mode="pro" if resolution == "1080p" else "std",
+ sound="on" if generate_audio else "off",
+ multi_shot=multi_shot,
+ multi_prompt=multi_prompt_list,
+ shot_type="customize" if multi_shot else None,
),
)
return await finish_omni_video_task(cls, response)
@@ -999,24 +1184,57 @@ class OmniProImageToVideoNode(IO.ComfyNode):
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingOmniProImageToVideoNode",
- display_name="Kling Omni Image to Video (Pro)",
+ display_name="Kling 3.0 Omni Image to Video",
category="api node/video/Kling",
description="Use up to 7 reference images to generate a video with the latest Kling model.",
inputs=[
- IO.Combo.Input("model_name", options=["kling-video-o1"]),
+ IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),
IO.String.Input(
"prompt",
multiline=True,
tooltip="A text prompt describing the video content. "
- "This can include both positive and negative descriptions.",
+ "This can include both positive and negative descriptions. "
+ "Ignored when storyboards are enabled.",
),
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
- IO.Int.Input("duration", default=3, min=3, max=10, display_mode=IO.NumberDisplay.slider),
+ IO.Int.Input("duration", default=5, min=3, max=15, display_mode=IO.NumberDisplay.slider),
IO.Image.Input(
"reference_images",
tooltip="Up to 7 reference images.",
),
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
+ IO.DynamicCombo.Input(
+ "storyboards",
+ options=[
+ IO.DynamicCombo.Option("disabled", []),
+ IO.DynamicCombo.Option("1 storyboard", _generate_storyboard_inputs(1)),
+ IO.DynamicCombo.Option("2 storyboards", _generate_storyboard_inputs(2)),
+ IO.DynamicCombo.Option("3 storyboards", _generate_storyboard_inputs(3)),
+ IO.DynamicCombo.Option("4 storyboards", _generate_storyboard_inputs(4)),
+ IO.DynamicCombo.Option("5 storyboards", _generate_storyboard_inputs(5)),
+ IO.DynamicCombo.Option("6 storyboards", _generate_storyboard_inputs(6)),
+ ],
+ tooltip="Generate a series of video segments with individual prompts and durations. "
+ "Only supported for kling-v3-omni.",
+ optional=True,
+ ),
+ IO.Boolean.Input(
+ "generate_audio",
+ default=False,
+ optional=True,
+ tooltip="Generate audio for the video. Only supported for kling-v3-omni.",
+ ),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=2147483647,
+ display_mode=IO.NumberDisplay.number,
+ control_after_generate=True,
+ tooltip="Seed controls whether the node should re-run; "
+ "results are non-deterministic regardless of seed.",
+ optional=True,
+ ),
],
outputs=[
IO.Video.Output(),
@@ -1028,11 +1246,15 @@ class OmniProImageToVideoNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
- depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]),
+ depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution", "model_name", "generate_audio"]),
expr="""
(
$mode := (widgets.resolution = "720p") ? "std" : "pro";
- $rates := {"std": 0.084, "pro": 0.112};
+ $isV3 := $contains(widgets.model_name, "v3");
+ $audio := $isV3 and widgets.generate_audio;
+ $rates := $audio
+ ? {"std": 0.112, "pro": 0.14}
+ : {"std": 0.084, "pro": 0.112};
{"type":"usd","usd": $lookup($rates, $mode) * widgets.duration}
)
""",
@@ -1048,9 +1270,46 @@ class OmniProImageToVideoNode(IO.ComfyNode):
duration: int,
reference_images: Input.Image,
resolution: str = "1080p",
+ storyboards: dict | None = None,
+ generate_audio: bool = False,
+ seed: int = 0,
) -> IO.NodeOutput:
+ _ = seed
+ if model_name == "kling-video-o1":
+ if duration > 10:
+ raise ValueError("kling-video-o1 does not support durations greater than 10 seconds.")
+ if generate_audio:
+ raise ValueError("kling-video-o1 does not support audio generation.")
+ stories_enabled = storyboards is not None and storyboards["storyboards"] != "disabled"
+ if stories_enabled and model_name == "kling-video-o1":
+ raise ValueError("kling-video-o1 does not support storyboards.")
prompt = normalize_omni_prompt_references(prompt)
- validate_string(prompt, min_length=1, max_length=2500)
+ validate_string(prompt, strip_whitespace=True, min_length=0 if stories_enabled else 1, max_length=2500)
+
+ multi_shot = None
+ multi_prompt_list = None
+ if stories_enabled:
+ count = int(storyboards["storyboards"].split()[0])
+ multi_shot = True
+ multi_prompt_list = []
+ for i in range(1, count + 1):
+ sb_prompt = storyboards[f"storyboard_{i}_prompt"]
+ sb_duration = storyboards[f"storyboard_{i}_duration"]
+ validate_string(sb_prompt, field_name=f"storyboard_{i}_prompt", min_length=1, max_length=512)
+ multi_prompt_list.append(
+ MultiPromptEntry(
+ index=i,
+ prompt=sb_prompt,
+ duration=str(sb_duration),
+ )
+ )
+ total_storyboard_duration = sum(int(e.duration) for e in multi_prompt_list)
+ if total_storyboard_duration != duration:
+ raise ValueError(
+ f"Total storyboard duration ({total_storyboard_duration}s) "
+ f"must equal the global duration ({duration}s)."
+ )
+
if get_number_of_images(reference_images) > 7:
raise ValueError("The maximum number of reference images is 7.")
for i in reference_images:
@@ -1070,6 +1329,10 @@ class OmniProImageToVideoNode(IO.ComfyNode):
duration=str(duration),
image_list=image_list,
mode="pro" if resolution == "1080p" else "std",
+ sound="on" if generate_audio else "off",
+ multi_shot=multi_shot,
+ multi_prompt=multi_prompt_list,
+ shot_type="customize" if multi_shot else None,
),
)
return await finish_omni_video_task(cls, response)
@@ -1081,11 +1344,11 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingOmniProVideoToVideoNode",
- display_name="Kling Omni Video to Video (Pro)",
+ display_name="Kling 3.0 Omni Video to Video",
category="api node/video/Kling",
description="Use a video and up to 4 reference images to generate a video with the latest Kling model.",
inputs=[
- IO.Combo.Input("model_name", options=["kling-video-o1"]),
+ IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),
IO.String.Input(
"prompt",
multiline=True,
@@ -1102,6 +1365,17 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
optional=True,
),
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=2147483647,
+ display_mode=IO.NumberDisplay.number,
+ control_after_generate=True,
+ tooltip="Seed controls whether the node should re-run; "
+ "results are non-deterministic regardless of seed.",
+ optional=True,
+ ),
],
outputs=[
IO.Video.Output(),
@@ -1135,7 +1409,9 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
keep_original_sound: bool,
reference_images: Input.Image | None = None,
resolution: str = "1080p",
+ seed: int = 0,
) -> IO.NodeOutput:
+ _ = seed
prompt = normalize_omni_prompt_references(prompt)
validate_string(prompt, min_length=1, max_length=2500)
validate_video_duration(reference_video, min_duration=3.0, max_duration=10.05)
@@ -1179,11 +1455,11 @@ class OmniProEditVideoNode(IO.ComfyNode):
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingOmniProEditVideoNode",
- display_name="Kling Omni Edit Video (Pro)",
+ display_name="Kling 3.0 Omni Edit Video",
category="api node/video/Kling",
description="Edit an existing video with the latest model from Kling.",
inputs=[
- IO.Combo.Input("model_name", options=["kling-video-o1"]),
+ IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),
IO.String.Input(
"prompt",
multiline=True,
@@ -1198,6 +1474,17 @@ class OmniProEditVideoNode(IO.ComfyNode):
optional=True,
),
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=2147483647,
+ display_mode=IO.NumberDisplay.number,
+ control_after_generate=True,
+ tooltip="Seed controls whether the node should re-run; "
+ "results are non-deterministic regardless of seed.",
+ optional=True,
+ ),
],
outputs=[
IO.Video.Output(),
@@ -1229,7 +1516,9 @@ class OmniProEditVideoNode(IO.ComfyNode):
keep_original_sound: bool,
reference_images: Input.Image | None = None,
resolution: str = "1080p",
+ seed: int = 0,
) -> IO.NodeOutput:
+ _ = seed
prompt = normalize_omni_prompt_references(prompt)
validate_string(prompt, min_length=1, max_length=2500)
validate_video_duration(video, min_duration=3.0, max_duration=10.05)
@@ -1273,27 +1562,43 @@ class OmniProImageNode(IO.ComfyNode):
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingOmniProImageNode",
- display_name="Kling Omni Image (Pro)",
+ display_name="Kling 3.0 Omni Image",
category="api node/image/Kling",
description="Create or edit images with the latest model from Kling.",
inputs=[
- IO.Combo.Input("model_name", options=["kling-image-o1"]),
+ IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-image-o1"]),
IO.String.Input(
"prompt",
multiline=True,
tooltip="A text prompt describing the image content. "
"This can include both positive and negative descriptions.",
),
- IO.Combo.Input("resolution", options=["1K", "2K"]),
+ IO.Combo.Input("resolution", options=["1K", "2K", "4K"]),
IO.Combo.Input(
"aspect_ratio",
options=["16:9", "9:16", "1:1", "4:3", "3:4", "3:2", "2:3", "21:9"],
),
+ IO.Combo.Input(
+ "series_amount",
+ options=["disabled", "2", "3", "4", "5", "6", "7", "8", "9"],
+ tooltip="Generate a series of images. Not supported for kling-image-o1.",
+ ),
IO.Image.Input(
"reference_images",
tooltip="Up to 10 additional reference images.",
optional=True,
),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=2147483647,
+ display_mode=IO.NumberDisplay.number,
+ control_after_generate=True,
+ tooltip="Seed controls whether the node should re-run; "
+ "results are non-deterministic regardless of seed.",
+ optional=True,
+ ),
],
outputs=[
IO.Image.Output(),
@@ -1305,7 +1610,16 @@ class OmniProImageNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
- expr="""{"type":"usd","usd":0.028}""",
+ depends_on=IO.PriceBadgeDepends(widgets=["resolution", "series_amount", "model_name"]),
+ expr="""
+ (
+ $prices := {"1k": 0.028, "2k": 0.028, "4k": 0.056};
+ $base := $lookup($prices, widgets.resolution);
+ $isO1 := widgets.model_name = "kling-image-o1";
+ $mult := ($isO1 or widgets.series_amount = "disabled") ? 1 : $number(widgets.series_amount);
+ {"type":"usd","usd": $base * $mult}
+ )
+ """,
),
)
@@ -1316,8 +1630,13 @@ class OmniProImageNode(IO.ComfyNode):
prompt: str,
resolution: str,
aspect_ratio: str,
+ series_amount: str = "disabled",
reference_images: Input.Image | None = None,
+ seed: int = 0,
) -> IO.NodeOutput:
+ _ = seed
+ if model_name == "kling-image-o1" and resolution == "4K":
+ raise ValueError("4K resolution is not supported for kling-image-o1 model.")
prompt = normalize_omni_prompt_references(prompt)
validate_string(prompt, min_length=1, max_length=2500)
image_list: list[OmniImageParamImage] = []
@@ -1329,6 +1648,9 @@ class OmniProImageNode(IO.ComfyNode):
validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"):
image_list.append(OmniImageParamImage(image=i))
+ use_series = series_amount != "disabled"
+ if use_series and model_name == "kling-image-o1":
+ raise ValueError("kling-image-o1 does not support series generation.")
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/images/omni-image", method="POST"),
@@ -1339,6 +1661,8 @@ class OmniProImageNode(IO.ComfyNode):
resolution=resolution.lower(),
aspect_ratio=aspect_ratio,
image_list=image_list if image_list else None,
+ result_type="series" if use_series else None,
+ series_amount=int(series_amount) if use_series else None,
),
)
if response.code:
@@ -1351,7 +1675,9 @@ class OmniProImageNode(IO.ComfyNode):
response_model=TaskStatusResponse,
status_extractor=lambda r: (r.data.task_status if r.data else None),
)
- return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.task_result.images[0].url))
+ images = final_response.data.task_result.series_images or final_response.data.task_result.images
+ tensors = [await download_url_to_image_tensor(img.url) for img in images]
+ return IO.NodeOutput(torch.cat(tensors, dim=0))
class KlingCameraControlT2VNode(IO.ComfyNode):
@@ -1936,6 +2262,7 @@ class KlingLipSyncAudioToVideoNode(IO.ComfyNode):
node_id="KlingLipSyncAudioToVideoNode",
display_name="Kling Lip Sync Video with Audio",
category="api node/video/Kling",
+ essentials_category="Video Generation",
description="Kling Lip Sync Audio to Video Node. Syncs mouth movements in a video file to the audio content of an audio file. When using, ensure that the audio contains clearly distinguishable vocals and that the video contains a distinct face. The audio file should not be larger than 5MB. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length.",
inputs=[
IO.Video.Input("video"),
@@ -2120,7 +2447,7 @@ class KlingImageGenerationNode(IO.ComfyNode):
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingImageGenerationNode",
- display_name="Kling Image Generation",
+ display_name="Kling 3.0 Image",
category="api node/image/Kling",
description="Kling Image Generation Node. Generate an image from a text prompt with an optional reference image.",
inputs=[
@@ -2151,11 +2478,7 @@ class KlingImageGenerationNode(IO.ComfyNode):
tooltip="Subject reference similarity",
advanced=True,
),
- IO.Combo.Input(
- "model_name",
- options=[i.value for i in KlingImageGenModelName],
- default="kling-v2",
- ),
+ IO.Combo.Input("model_name", options=["kling-v3", "kling-v2", "kling-v1-5"]),
IO.Combo.Input(
"aspect_ratio",
options=[i.value for i in KlingImageGenAspectRatio],
@@ -2169,6 +2492,17 @@ class KlingImageGenerationNode(IO.ComfyNode):
tooltip="Number of generated images",
),
IO.Image.Input("image", optional=True),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=2147483647,
+ display_mode=IO.NumberDisplay.number,
+ control_after_generate=True,
+ tooltip="Seed controls whether the node should re-run; "
+ "results are non-deterministic regardless of seed.",
+ optional=True,
+ ),
],
outputs=[
IO.Image.Output(),
@@ -2187,7 +2521,7 @@ class KlingImageGenerationNode(IO.ComfyNode):
$base :=
$contains($m,"kling-v1-5")
? (inputs.image.connected ? 0.028 : 0.014)
- : ($contains($m,"kling-v1") ? 0.0035 : 0.014);
+ : $contains($m,"kling-v3") ? 0.028 : 0.014;
{"type":"usd","usd": $base * widgets.n}
)
""",
@@ -2197,7 +2531,7 @@ class KlingImageGenerationNode(IO.ComfyNode):
@classmethod
async def execute(
cls,
- model_name: KlingImageGenModelName,
+ model_name: str,
prompt: str,
negative_prompt: str,
image_type: KlingImageGenImageReferenceType,
@@ -2206,17 +2540,11 @@ class KlingImageGenerationNode(IO.ComfyNode):
n: int,
aspect_ratio: KlingImageGenAspectRatio,
image: torch.Tensor | None = None,
+ seed: int = 0,
) -> IO.NodeOutput:
+ _ = seed
validate_string(prompt, field_name="prompt", min_length=1, max_length=MAX_PROMPT_LENGTH_IMAGE_GEN)
validate_string(negative_prompt, field_name="negative_prompt", max_length=MAX_PROMPT_LENGTH_IMAGE_GEN)
-
- if image is None:
- image_type = None
- elif model_name == KlingImageGenModelName.kling_v1:
- raise ValueError(f"The model {KlingImageGenModelName.kling_v1.value} does not support reference images.")
- else:
- image = tensor_to_base64_string(image)
-
task_creation_response = await sync_op(
cls,
ApiEndpoint(path=PATH_IMAGE_GENERATIONS, method="POST"),
@@ -2225,8 +2553,8 @@ class KlingImageGenerationNode(IO.ComfyNode):
model_name=model_name,
prompt=prompt,
negative_prompt=negative_prompt,
- image=image,
- image_reference=image_type,
+ image=tensor_to_base64_string(image) if image is not None else None,
+ image_reference=image_type if image is not None else None,
image_fidelity=image_fidelity,
human_fidelity=human_fidelity,
n=n,
@@ -2256,7 +2584,7 @@ class TextToVideoWithAudio(IO.ComfyNode):
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingTextToVideoWithAudio",
- display_name="Kling Text to Video with Audio",
+ display_name="Kling 2.6 Text to Video with Audio",
category="api node/video/Kling",
inputs=[
IO.Combo.Input("model_name", options=["kling-v2-6"]),
@@ -2324,7 +2652,7 @@ class ImageToVideoWithAudio(IO.ComfyNode):
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingImageToVideoWithAudio",
- display_name="Kling Image(First Frame) to Video with Audio",
+ display_name="Kling 2.6 Image(First Frame) to Video with Audio",
category="api node/video/Kling",
inputs=[
IO.Combo.Input("model_name", options=["kling-v2-6"]),
@@ -2482,6 +2810,335 @@ class MotionControl(IO.ComfyNode):
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
+class KlingVideoNode(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls) -> IO.Schema:
+ return IO.Schema(
+ node_id="KlingVideoNode",
+ display_name="Kling 3.0 Video",
+ category="api node/video/Kling",
+ description="Generate videos with Kling V3. "
+ "Supports text-to-video and image-to-video with optional storyboard multi-prompt and audio generation.",
+ inputs=[
+ IO.DynamicCombo.Input(
+ "multi_shot",
+ options=[
+ IO.DynamicCombo.Option(
+ "disabled",
+ [
+ IO.String.Input("prompt", multiline=True, default=""),
+ IO.String.Input("negative_prompt", multiline=True, default=""),
+ IO.Int.Input(
+ "duration",
+ default=5,
+ min=3,
+ max=15,
+ display_mode=IO.NumberDisplay.slider,
+ ),
+ ],
+ ),
+ IO.DynamicCombo.Option("1 storyboard", _generate_storyboard_inputs(1)),
+ IO.DynamicCombo.Option("2 storyboards", _generate_storyboard_inputs(2)),
+ IO.DynamicCombo.Option("3 storyboards", _generate_storyboard_inputs(3)),
+ IO.DynamicCombo.Option("4 storyboards", _generate_storyboard_inputs(4)),
+ IO.DynamicCombo.Option("5 storyboards", _generate_storyboard_inputs(5)),
+ IO.DynamicCombo.Option("6 storyboards", _generate_storyboard_inputs(6)),
+ ],
+ tooltip="Generate a series of video segments with individual prompts and durations.",
+ ),
+ IO.Boolean.Input("generate_audio", default=True),
+ IO.DynamicCombo.Input(
+ "model",
+ options=[
+ IO.DynamicCombo.Option(
+ "kling-v3",
+ [
+ IO.Combo.Input("resolution", options=["1080p", "720p"]),
+ IO.Combo.Input(
+ "aspect_ratio",
+ options=["16:9", "9:16", "1:1"],
+ tooltip="Ignored in image-to-video mode.",
+ ),
+ ],
+ ),
+ ],
+ tooltip="Model and generation settings.",
+ ),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=2147483647,
+ display_mode=IO.NumberDisplay.number,
+ control_after_generate=True,
+ tooltip="Seed controls whether the node should re-run; "
+ "results are non-deterministic regardless of seed.",
+ ),
+ IO.Image.Input(
+ "start_frame",
+ optional=True,
+ tooltip="Optional start frame image. When connected, switches to image-to-video mode.",
+ ),
+ ],
+ outputs=[
+ IO.Video.Output(),
+ ],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ price_badge=IO.PriceBadge(
+ depends_on=IO.PriceBadgeDepends(
+ widgets=[
+ "model.resolution",
+ "generate_audio",
+ "multi_shot",
+ "multi_shot.duration",
+ "multi_shot.storyboard_1_duration",
+ "multi_shot.storyboard_2_duration",
+ "multi_shot.storyboard_3_duration",
+ "multi_shot.storyboard_4_duration",
+ "multi_shot.storyboard_5_duration",
+ "multi_shot.storyboard_6_duration",
+ ],
+ ),
+ expr="""
+ (
+ $rates := {"1080p": {"off": 0.112, "on": 0.168}, "720p": {"off": 0.084, "on": 0.126}};
+ $res := $lookup(widgets, "model.resolution");
+ $audio := widgets.generate_audio ? "on" : "off";
+ $rate := $lookup($lookup($rates, $res), $audio);
+ $ms := widgets.multi_shot;
+ $isSb := $ms != "disabled";
+ $n := $isSb ? $number($substring($ms, 0, 1)) : 0;
+ $d1 := $lookup(widgets, "multi_shot.storyboard_1_duration");
+ $d2 := $n >= 2 ? $lookup(widgets, "multi_shot.storyboard_2_duration") : 0;
+ $d3 := $n >= 3 ? $lookup(widgets, "multi_shot.storyboard_3_duration") : 0;
+ $d4 := $n >= 4 ? $lookup(widgets, "multi_shot.storyboard_4_duration") : 0;
+ $d5 := $n >= 5 ? $lookup(widgets, "multi_shot.storyboard_5_duration") : 0;
+ $d6 := $n >= 6 ? $lookup(widgets, "multi_shot.storyboard_6_duration") : 0;
+ $dur := $isSb ? $d1 + $d2 + $d3 + $d4 + $d5 + $d6 : $lookup(widgets, "multi_shot.duration");
+ {"type":"usd","usd": $rate * $dur}
+ )
+ """,
+ ),
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ multi_shot: dict,
+ generate_audio: bool,
+ model: dict,
+ seed: int,
+ start_frame: Input.Image | None = None,
+ ) -> IO.NodeOutput:
+ _ = seed
+ mode = "pro" if model["resolution"] == "1080p" else "std"
+ custom_multi_shot = False
+ if multi_shot["multi_shot"] == "disabled":
+ shot_type = None
+ else:
+ shot_type = "customize"
+ custom_multi_shot = True
+
+ multi_prompt_list = None
+ if shot_type == "customize":
+ count = int(multi_shot["multi_shot"].split()[0])
+ multi_prompt_list = []
+ for i in range(1, count + 1):
+ sb_prompt = multi_shot[f"storyboard_{i}_prompt"]
+ sb_duration = multi_shot[f"storyboard_{i}_duration"]
+ validate_string(sb_prompt, field_name=f"storyboard_{i}_prompt", min_length=1, max_length=512)
+ multi_prompt_list.append(
+ MultiPromptEntry(
+ index=i,
+ prompt=sb_prompt,
+ duration=str(sb_duration),
+ )
+ )
+ duration = sum(int(e.duration) for e in multi_prompt_list)
+ if duration < 3 or duration > 15:
+ raise ValueError(
+ f"Total storyboard duration ({duration}s) must be between 3 and 15 seconds."
+ )
+ else:
+ duration = multi_shot["duration"]
+ validate_string(multi_shot["prompt"], min_length=1, max_length=2500)
+
+ if start_frame is not None:
+ validate_image_dimensions(start_frame, min_width=300, min_height=300)
+ validate_image_aspect_ratio(start_frame, (1, 2.5), (2.5, 1))
+ image_url = await upload_image_to_comfyapi(cls, start_frame, wait_label="Uploading start frame")
+ response = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/kling/v1/videos/image2video", method="POST"),
+ response_model=TaskStatusResponse,
+ data=ImageToVideoWithAudioRequest(
+ model_name=model["model"],
+ image=image_url,
+ prompt=None if custom_multi_shot else multi_shot["prompt"],
+ negative_prompt=None if custom_multi_shot else multi_shot["negative_prompt"],
+ mode=mode,
+ duration=str(duration),
+ sound="on" if generate_audio else "off",
+ multi_shot=True if shot_type else None,
+ multi_prompt=multi_prompt_list,
+ shot_type=shot_type,
+ ),
+ )
+ poll_path = f"/proxy/kling/v1/videos/image2video/{response.data.task_id}"
+ else:
+ response = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/kling/v1/videos/text2video", method="POST"),
+ response_model=TaskStatusResponse,
+ data=TextToVideoWithAudioRequest(
+ model_name=model["model"],
+ aspect_ratio=model["aspect_ratio"],
+ prompt=None if custom_multi_shot else multi_shot["prompt"],
+ negative_prompt=None if custom_multi_shot else multi_shot["negative_prompt"],
+ mode=mode,
+ duration=str(duration),
+ sound="on" if generate_audio else "off",
+ multi_shot=True if shot_type else None,
+ multi_prompt=multi_prompt_list,
+ shot_type=shot_type,
+ ),
+ )
+ poll_path = f"/proxy/kling/v1/videos/text2video/{response.data.task_id}"
+
+ if response.code:
+ raise RuntimeError(
+ f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
+ )
+ final_response = await poll_op(
+ cls,
+ ApiEndpoint(path=poll_path),
+ response_model=TaskStatusResponse,
+ status_extractor=lambda r: (r.data.task_status if r.data else None),
+ )
+ return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
+
+
+class KlingFirstLastFrameNode(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls) -> IO.Schema:
+ return IO.Schema(
+ node_id="KlingFirstLastFrameNode",
+ display_name="Kling 3.0 First-Last-Frame to Video",
+ category="api node/video/Kling",
+ description="Generate videos with Kling V3 using first and last frames.",
+ inputs=[
+ IO.String.Input("prompt", multiline=True, default=""),
+ IO.Int.Input(
+ "duration",
+ default=5,
+ min=3,
+ max=15,
+ display_mode=IO.NumberDisplay.slider,
+ ),
+ IO.Image.Input("first_frame"),
+ IO.Image.Input("end_frame"),
+ IO.Boolean.Input("generate_audio", default=True),
+ IO.DynamicCombo.Input(
+ "model",
+ options=[
+ IO.DynamicCombo.Option(
+ "kling-v3",
+ [
+ IO.Combo.Input("resolution", options=["1080p", "720p"]),
+ ],
+ ),
+ ],
+ tooltip="Model and generation settings.",
+ ),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=2147483647,
+ display_mode=IO.NumberDisplay.number,
+ control_after_generate=True,
+ tooltip="Seed controls whether the node should re-run; "
+ "results are non-deterministic regardless of seed.",
+ ),
+ ],
+ outputs=[
+ IO.Video.Output(),
+ ],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ price_badge=IO.PriceBadge(
+ depends_on=IO.PriceBadgeDepends(
+ widgets=["model.resolution", "generate_audio", "duration"],
+ ),
+ expr="""
+ (
+ $rates := {"1080p": {"off": 0.112, "on": 0.168}, "720p": {"off": 0.084, "on": 0.126}};
+ $res := $lookup(widgets, "model.resolution");
+ $audio := widgets.generate_audio ? "on" : "off";
+ $rate := $lookup($lookup($rates, $res), $audio);
+ {"type":"usd","usd": $rate * widgets.duration}
+ )
+ """,
+ ),
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ prompt: str,
+ duration: int,
+ first_frame: Input.Image,
+ end_frame: Input.Image,
+ generate_audio: bool,
+ model: dict,
+ seed: int,
+ ) -> IO.NodeOutput:
+ _ = seed
+ validate_string(prompt, min_length=1, max_length=2500)
+ validate_image_dimensions(first_frame, min_width=300, min_height=300)
+ validate_image_aspect_ratio(first_frame, (1, 2.5), (2.5, 1))
+ validate_image_dimensions(end_frame, min_width=300, min_height=300)
+ validate_image_aspect_ratio(end_frame, (1, 2.5), (2.5, 1))
+ image_url = await upload_image_to_comfyapi(cls, first_frame, wait_label="Uploading first frame")
+ image_tail_url = await upload_image_to_comfyapi(cls, end_frame, wait_label="Uploading end frame")
+ response = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/kling/v1/videos/image2video", method="POST"),
+ response_model=TaskStatusResponse,
+ data=ImageToVideoWithAudioRequest(
+ model_name=model["model"],
+ image=image_url,
+ image_tail=image_tail_url,
+ prompt=prompt,
+ mode="pro" if model["resolution"] == "1080p" else "std",
+ duration=str(duration),
+ sound="on" if generate_audio else "off",
+ ),
+ )
+ if response.code:
+ raise RuntimeError(
+ f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
+ )
+ final_response = await poll_op(
+ cls,
+ ApiEndpoint(path=f"/proxy/kling/v1/videos/image2video/{response.data.task_id}"),
+ response_model=TaskStatusResponse,
+ status_extractor=lambda r: (r.data.task_status if r.data else None),
+ )
+ return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
+
+
class KlingExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@@ -2508,6 +3165,8 @@ class KlingExtension(ComfyExtension):
TextToVideoWithAudio,
ImageToVideoWithAudio,
MotionControl,
+ KlingVideoNode,
+ KlingFirstLastFrameNode,
]
diff --git a/comfy_api_nodes/nodes_magnific.py b/comfy_api_nodes/nodes_magnific.py
index 7e1bd0989..0f53208d4 100644
--- a/comfy_api_nodes/nodes_magnific.py
+++ b/comfy_api_nodes/nodes_magnific.py
@@ -30,6 +30,30 @@ from comfy_api_nodes.util import (
validate_image_dimensions,
)
+_EUR_TO_USD = 1.19
+
+
+def _tier_price_eur(megapixels: float) -> float:
+ """Price in EUR for a single Magnific upscaling step based on input megapixels."""
+ if megapixels <= 1.3:
+ return 0.143
+ if megapixels <= 3.0:
+ return 0.286
+ if megapixels <= 6.4:
+ return 0.429
+ return 1.716
+
+
+def _calculate_magnific_upscale_price_usd(width: int, height: int, scale: int) -> float:
+ """Calculate total Magnific upscale price in USD for given input dimensions and scale factor."""
+ num_steps = int(math.log2(scale))
+ total_eur = 0.0
+ pixels = width * height
+ for _ in range(num_steps):
+ total_eur += _tier_price_eur(pixels / 1_000_000)
+ pixels *= 4
+ return round(total_eur * _EUR_TO_USD, 2)
+
class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
@classmethod
@@ -105,11 +129,20 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
- depends_on=IO.PriceBadgeDepends(widgets=["scale_factor"]),
+ depends_on=IO.PriceBadgeDepends(widgets=["scale_factor", "auto_downscale"]),
expr="""
(
- $max := widgets.scale_factor = "2x" ? 1.326 : 1.657;
- {"type": "range_usd", "min_usd": 0.11, "max_usd": $max}
+ $ad := widgets.auto_downscale;
+ $mins := $ad
+ ? {"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.515}
+ : {"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.844};
+ $maxs := {"2x": 0.515, "4x": 0.844, "8x": 1.015, "16x": 1.187};
+ {
+ "type": "range_usd",
+ "min_usd": $lookup($mins, widgets.scale_factor),
+ "max_usd": $lookup($maxs, widgets.scale_factor),
+ "format": { "approximate": true }
+ }
)
""",
),
@@ -170,6 +203,10 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
f"Use a smaller input image or lower scale factor."
)
+ final_height, final_width = get_image_dimensions(image)
+ actual_scale = int(scale_factor.rstrip("x"))
+ price_usd = _calculate_magnific_upscale_price_usd(final_width, final_height, actual_scale)
+
initial_res = await sync_op(
cls,
ApiEndpoint(path="/proxy/freepik/v1/ai/image-upscaler", method="POST"),
@@ -191,6 +228,7 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-upscaler/{initial_res.task_id}"),
response_model=TaskResponse,
status_extractor=lambda x: x.status,
+ price_extractor=lambda _: price_usd,
poll_interval=10.0,
max_poll_attempts=480,
)
@@ -260,8 +298,14 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
depends_on=IO.PriceBadgeDepends(widgets=["scale_factor"]),
expr="""
(
- $max := widgets.scale_factor = "2x" ? 1.326 : 1.657;
- {"type": "range_usd", "min_usd": 0.11, "max_usd": $max}
+ $mins := {"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.844};
+ $maxs := {"2x": 2.045, "4x": 2.545, "8x": 2.889, "16x": 3.06};
+ {
+ "type": "range_usd",
+ "min_usd": $lookup($mins, widgets.scale_factor),
+ "max_usd": $lookup($maxs, widgets.scale_factor),
+ "format": { "approximate": true }
+ }
)
""",
),
@@ -324,6 +368,9 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
f"Use a smaller input image or lower scale factor."
)
+ final_height, final_width = get_image_dimensions(image)
+ price_usd = _calculate_magnific_upscale_price_usd(final_width, final_height, requested_scale)
+
initial_res = await sync_op(
cls,
ApiEndpoint(path="/proxy/freepik/v1/ai/image-upscaler-precision-v2", method="POST"),
@@ -342,6 +389,7 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-upscaler-precision-v2/{initial_res.task_id}"),
response_model=TaskResponse,
status_extractor=lambda x: x.status,
+ price_extractor=lambda _: price_usd,
poll_interval=10.0,
max_poll_attempts=480,
)
@@ -885,8 +933,8 @@ class MagnificExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
- # MagnificImageUpscalerCreativeNode,
- # MagnificImageUpscalerPreciseV2Node,
+ MagnificImageUpscalerCreativeNode,
+ MagnificImageUpscalerPreciseV2Node,
MagnificImageStyleTransferNode,
MagnificImageRelightNode,
MagnificImageSkinEnhancerNode,
diff --git a/comfy_api_nodes/nodes_moonvalley.py b/comfy_api_nodes/nodes_moonvalley.py
index 08315fa2b..78a230529 100644
--- a/comfy_api_nodes/nodes_moonvalley.py
+++ b/comfy_api_nodes/nodes_moonvalley.py
@@ -219,8 +219,8 @@ class MoonvalleyImg2VideoNode(IO.ComfyNode):
),
IO.Int.Input(
"steps",
- default=33,
- min=1,
+ default=80,
+ min=75, # steps should be greater than or equal to cooldown_steps(75) + warmup_steps(0)
max=100,
step=1,
tooltip="Number of denoising steps",
@@ -340,8 +340,8 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode):
),
IO.Int.Input(
"steps",
- default=33,
- min=1,
+ default=60,
+ min=60, # steps should be greater than or equal to cooldown_steps(36) + warmup_steps(24)
max=100,
step=1,
display_mode=IO.NumberDisplay.number,
@@ -370,7 +370,7 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode):
video: Input.Video | None = None,
control_type: str = "Motion Transfer",
motion_intensity: int | None = 100,
- steps=33,
+ steps=60,
prompt_adherence=4.5,
) -> IO.NodeOutput:
validated_video = validate_video_to_video_input(video)
@@ -465,8 +465,8 @@ class MoonvalleyTxt2VideoNode(IO.ComfyNode):
),
IO.Int.Input(
"steps",
- default=33,
- min=1,
+ default=80,
+ min=75, # steps should be greater than or equal to cooldown_steps(75) + warmup_steps(0)
max=100,
step=1,
tooltip="Inference steps",
diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py
index 879bdb1da..4ee896fa8 100644
--- a/comfy_api_nodes/nodes_openai.py
+++ b/comfy_api_nodes/nodes_openai.py
@@ -43,7 +43,6 @@ class SupportedOpenAIModel(str, Enum):
o1 = "o1"
o3 = "o3"
o1_pro = "o1-pro"
- gpt_4o = "gpt-4o"
gpt_4_1 = "gpt-4.1"
gpt_4_1_mini = "gpt-4.1-mini"
gpt_4_1_nano = "gpt-4.1-nano"
@@ -576,6 +575,7 @@ class OpenAIChatNode(IO.ComfyNode):
node_id="OpenAIChatNode",
display_name="OpenAI ChatGPT",
category="api node/text/OpenAI",
+ essentials_category="Text Generation",
description="Generate text responses from an OpenAI model.",
inputs=[
IO.String.Input(
@@ -650,11 +650,6 @@ class OpenAIChatNode(IO.ComfyNode):
"usd": [0.01, 0.04],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
- : $contains($m, "gpt-4o") ? {
- "type": "list_usd",
- "usd": [0.0025, 0.01],
- "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
- }
: $contains($m, "gpt-4.1-nano") ? {
"type": "list_usd",
"usd": [0.0001, 0.0004],
diff --git a/comfy_api_nodes/nodes_recraft.py b/comfy_api_nodes/nodes_recraft.py
index 3a1f32263..4d1d508fa 100644
--- a/comfy_api_nodes/nodes_recraft.py
+++ b/comfy_api_nodes/nodes_recraft.py
@@ -1,5 +1,4 @@
from io import BytesIO
-from typing import Optional, Union
import aiohttp
import torch
@@ -9,6 +8,8 @@ from typing_extensions import override
from comfy.utils import ProgressBar
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apis.recraft import (
+ RECRAFT_V4_PRO_SIZES,
+ RECRAFT_V4_SIZES,
RecraftColor,
RecraftColorChain,
RecraftControls,
@@ -18,7 +19,6 @@ from comfy_api_nodes.apis.recraft import (
RecraftImageGenerationResponse,
RecraftImageSize,
RecraftIO,
- RecraftModel,
RecraftStyle,
RecraftStyleV3,
get_v3_substyles,
@@ -39,7 +39,7 @@ async def handle_recraft_file_request(
cls: type[IO.ComfyNode],
image: torch.Tensor,
path: str,
- mask: Optional[torch.Tensor] = None,
+ mask: torch.Tensor | None = None,
total_pixels: int = 4096 * 4096,
timeout: int = 1024,
request=None,
@@ -73,11 +73,11 @@ async def handle_recraft_file_request(
def recraft_multipart_parser(
data,
parent_key=None,
- formatter: Optional[type[callable]] = None,
- converted_to_check: Optional[list[list]] = None,
+ formatter: type[callable] | None = None,
+ converted_to_check: list[list] | None = None,
is_list: bool = False,
return_mode: str = "formdata", # "dict" | "formdata"
-) -> Union[dict, aiohttp.FormData]:
+) -> dict | aiohttp.FormData:
"""
Formats data such that multipart/form-data will work with aiohttp library when both files and data are present.
@@ -309,7 +309,7 @@ class RecraftStyleInfiniteStyleLibrary(IO.ComfyNode):
node_id="RecraftStyleV3InfiniteStyleLibrary",
display_name="Recraft Style - Infinite Style Library",
category="api node/image/Recraft",
- description="Select style based on preexisting UUID from Recraft's Infinite Style Library.",
+ description="Choose style based on preexisting UUID from Recraft's Infinite Style Library.",
inputs=[
IO.String.Input("style_id", default="", tooltip="UUID of style from Infinite Style Library."),
],
@@ -485,7 +485,7 @@ class RecraftTextToImageNode(IO.ComfyNode):
data=RecraftImageGenerationRequest(
prompt=prompt,
negative_prompt=negative_prompt,
- model=RecraftModel.recraftv3,
+ model="recraftv3",
size=size,
n=n,
style=recraft_style.style,
@@ -598,7 +598,7 @@ class RecraftImageToImageNode(IO.ComfyNode):
request = RecraftImageGenerationRequest(
prompt=prompt,
negative_prompt=negative_prompt,
- model=RecraftModel.recraftv3,
+ model="recraftv3",
n=n,
strength=round(strength, 2),
style=recraft_style.style,
@@ -698,7 +698,7 @@ class RecraftImageInpaintingNode(IO.ComfyNode):
request = RecraftImageGenerationRequest(
prompt=prompt,
negative_prompt=negative_prompt,
- model=RecraftModel.recraftv3,
+ model="recraftv3",
n=n,
style=recraft_style.style,
substyle=recraft_style.substyle,
@@ -810,7 +810,7 @@ class RecraftTextToVectorNode(IO.ComfyNode):
data=RecraftImageGenerationRequest(
prompt=prompt,
negative_prompt=negative_prompt,
- model=RecraftModel.recraftv3,
+ model="recraftv3",
size=size,
n=n,
style=recraft_style.style,
@@ -933,7 +933,7 @@ class RecraftReplaceBackgroundNode(IO.ComfyNode):
request = RecraftImageGenerationRequest(
prompt=prompt,
negative_prompt=negative_prompt,
- model=RecraftModel.recraftv3,
+ model="recraftv3",
n=n,
style=recraft_style.style,
substyle=recraft_style.substyle,
@@ -963,6 +963,7 @@ class RecraftRemoveBackgroundNode(IO.ComfyNode):
node_id="RecraftRemoveBackgroundNode",
display_name="Recraft Remove Background",
category="api node/image/Recraft",
+ essentials_category="Image Tools",
description="Remove background from image, and return processed image and mask.",
inputs=[
IO.Image.Input("image"),
@@ -1078,6 +1079,252 @@ class RecraftCreativeUpscaleNode(RecraftCrispUpscaleNode):
)
+class RecraftV4TextToImageNode(IO.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="RecraftV4TextToImageNode",
+ display_name="Recraft V4 Text to Image",
+ category="api node/image/Recraft",
+ description="Generates images using Recraft V4 or V4 Pro models.",
+ inputs=[
+ IO.String.Input(
+ "prompt",
+ multiline=True,
+ tooltip="Prompt for the image generation. Maximum 10,000 characters.",
+ ),
+ IO.String.Input(
+ "negative_prompt",
+ multiline=True,
+ tooltip="An optional text description of undesired elements on an image.",
+ ),
+ IO.DynamicCombo.Input(
+ "model",
+ options=[
+ IO.DynamicCombo.Option(
+ "recraftv4",
+ [
+ IO.Combo.Input(
+ "size",
+ options=RECRAFT_V4_SIZES,
+ default="1024x1024",
+ tooltip="The size of the generated image.",
+ ),
+ ],
+ ),
+ IO.DynamicCombo.Option(
+ "recraftv4_pro",
+ [
+ IO.Combo.Input(
+ "size",
+ options=RECRAFT_V4_PRO_SIZES,
+ default="2048x2048",
+ tooltip="The size of the generated image.",
+ ),
+ ],
+ ),
+ ],
+ tooltip="The model to use for generation.",
+ ),
+ IO.Int.Input(
+ "n",
+ default=1,
+ min=1,
+ max=6,
+ tooltip="The number of images to generate.",
+ ),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=0xFFFFFFFFFFFFFFFF,
+ control_after_generate=True,
+ tooltip="Seed to determine if node should re-run; "
+ "actual results are nondeterministic regardless of seed.",
+ ),
+ IO.Custom(RecraftIO.CONTROLS).Input(
+ "recraft_controls",
+ tooltip="Optional additional controls over the generation via the Recraft Controls node.",
+ optional=True,
+ ),
+ ],
+ outputs=[
+ IO.Image.Output(),
+ ],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ price_badge=IO.PriceBadge(
+ depends_on=IO.PriceBadgeDepends(widgets=["model", "n"]),
+ expr="""
+ (
+ $prices := {"recraftv4": 0.04, "recraftv4_pro": 0.25};
+ {"type":"usd","usd": $lookup($prices, widgets.model) * widgets.n}
+ )
+ """,
+ ),
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ prompt: str,
+ negative_prompt: str,
+ model: dict,
+ n: int,
+ seed: int,
+ recraft_controls: RecraftControls | None = None,
+ ) -> IO.NodeOutput:
+ validate_string(prompt, strip_whitespace=False, min_length=1, max_length=10000)
+ response = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/recraft/image_generation", method="POST"),
+ response_model=RecraftImageGenerationResponse,
+ data=RecraftImageGenerationRequest(
+ prompt=prompt,
+ negative_prompt=negative_prompt if negative_prompt else None,
+ model=model["model"],
+ size=model["size"],
+ n=n,
+ controls=recraft_controls.create_api_model() if recraft_controls else None,
+ ),
+ max_retries=1,
+ )
+ images = []
+ for data in response.data:
+ with handle_recraft_image_output():
+ image = bytesio_to_image_tensor(await download_url_as_bytesio(data.url, timeout=1024))
+ if len(image.shape) < 4:
+ image = image.unsqueeze(0)
+ images.append(image)
+ return IO.NodeOutput(torch.cat(images, dim=0))
+
+
+class RecraftV4TextToVectorNode(IO.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="RecraftV4TextToVectorNode",
+ display_name="Recraft V4 Text to Vector",
+ category="api node/image/Recraft",
+ description="Generates SVG using Recraft V4 or V4 Pro models.",
+ inputs=[
+ IO.String.Input(
+ "prompt",
+ multiline=True,
+ tooltip="Prompt for the image generation. Maximum 10,000 characters.",
+ ),
+ IO.String.Input(
+ "negative_prompt",
+ multiline=True,
+ tooltip="An optional text description of undesired elements on an image.",
+ ),
+ IO.DynamicCombo.Input(
+ "model",
+ options=[
+ IO.DynamicCombo.Option(
+ "recraftv4",
+ [
+ IO.Combo.Input(
+ "size",
+ options=RECRAFT_V4_SIZES,
+ default="1024x1024",
+ tooltip="The size of the generated image.",
+ ),
+ ],
+ ),
+ IO.DynamicCombo.Option(
+ "recraftv4_pro",
+ [
+ IO.Combo.Input(
+ "size",
+ options=RECRAFT_V4_PRO_SIZES,
+ default="2048x2048",
+ tooltip="The size of the generated image.",
+ ),
+ ],
+ ),
+ ],
+ tooltip="The model to use for generation.",
+ ),
+ IO.Int.Input(
+ "n",
+ default=1,
+ min=1,
+ max=6,
+ tooltip="The number of images to generate.",
+ ),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=0xFFFFFFFFFFFFFFFF,
+ control_after_generate=True,
+ tooltip="Seed to determine if node should re-run; "
+ "actual results are nondeterministic regardless of seed.",
+ ),
+ IO.Custom(RecraftIO.CONTROLS).Input(
+ "recraft_controls",
+ tooltip="Optional additional controls over the generation via the Recraft Controls node.",
+ optional=True,
+ ),
+ ],
+ outputs=[
+ IO.SVG.Output(),
+ ],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ price_badge=IO.PriceBadge(
+ depends_on=IO.PriceBadgeDepends(widgets=["model", "n"]),
+ expr="""
+ (
+ $prices := {"recraftv4": 0.08, "recraftv4_pro": 0.30};
+ {"type":"usd","usd": $lookup($prices, widgets.model) * widgets.n}
+ )
+ """,
+ ),
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ prompt: str,
+ negative_prompt: str,
+ model: dict,
+ n: int,
+ seed: int,
+ recraft_controls: RecraftControls | None = None,
+ ) -> IO.NodeOutput:
+ validate_string(prompt, strip_whitespace=False, min_length=1, max_length=10000)
+ response = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/recraft/image_generation", method="POST"),
+ response_model=RecraftImageGenerationResponse,
+ data=RecraftImageGenerationRequest(
+ prompt=prompt,
+ negative_prompt=negative_prompt if negative_prompt else None,
+ model=model["model"],
+ size=model["size"],
+ n=n,
+ style="vector_illustration",
+ substyle=None,
+ controls=recraft_controls.create_api_model() if recraft_controls else None,
+ ),
+ max_retries=1,
+ )
+ svg_data = []
+ for data in response.data:
+ svg_data.append(await download_url_as_bytesio(data.url, timeout=1024))
+ return IO.NodeOutput(SVG(svg_data))
+
+
class RecraftExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@@ -1098,6 +1345,8 @@ class RecraftExtension(ComfyExtension):
RecraftCreateStyleNode,
RecraftColorRGBNode,
RecraftControlsNode,
+ RecraftV4TextToImageNode,
+ RecraftV4TextToVectorNode,
]
diff --git a/comfy_api_nodes/nodes_rodin.py b/comfy_api_nodes/nodes_rodin.py
index 492f5c2a1..2b829b8db 100644
--- a/comfy_api_nodes/nodes_rodin.py
+++ b/comfy_api_nodes/nodes_rodin.py
@@ -505,6 +505,9 @@ class Rodin3D_Gen2(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
+ price_badge=IO.PriceBadge(
+ expr="""{"type":"usd","usd":0.4}""",
+ ),
)
@classmethod
diff --git a/comfy_api_nodes/nodes_stability.py b/comfy_api_nodes/nodes_stability.py
index d6798bd9c..9ef13c83b 100644
--- a/comfy_api_nodes/nodes_stability.py
+++ b/comfy_api_nodes/nodes_stability.py
@@ -631,6 +631,7 @@ class StabilityTextToAudio(IO.ComfyNode):
node_id="StabilityTextToAudio",
display_name="Stability AI Text To Audio",
category="api node/audio/Stability AI",
+ essentials_category="Audio",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Combo.Input(
diff --git a/comfy_api_nodes/nodes_vidu.py b/comfy_api_nodes/nodes_vidu.py
index e5bf97c85..f04407eb5 100644
--- a/comfy_api_nodes/nodes_vidu.py
+++ b/comfy_api_nodes/nodes_vidu.py
@@ -54,6 +54,7 @@ async def execute_task(
response_model=TaskStatusResponse,
status_extractor=lambda r: r.state,
progress_extractor=lambda r: r.progress,
+ price_extractor=lambda r: r.credits * 0.005 if r.credits is not None else None,
max_poll_attempts=max_poll_attempts,
)
if not response.creations:
@@ -1320,6 +1321,36 @@ class Vidu3TextToVideoNode(IO.ComfyNode):
),
],
),
+ IO.DynamicCombo.Option(
+ "viduq3-turbo",
+ [
+ IO.Combo.Input(
+ "aspect_ratio",
+ options=["16:9", "9:16", "3:4", "4:3", "1:1"],
+ tooltip="The aspect ratio of the output video.",
+ ),
+ IO.Combo.Input(
+ "resolution",
+ options=["720p", "1080p"],
+ tooltip="Resolution of the output video.",
+ ),
+ IO.Int.Input(
+ "duration",
+ default=5,
+ min=1,
+ max=16,
+ step=1,
+ display_mode=IO.NumberDisplay.slider,
+ tooltip="Duration of the output video in seconds.",
+ ),
+ IO.Boolean.Input(
+ "audio",
+ default=False,
+ tooltip="When enabled, outputs video with sound "
+ "(including dialogue and sound effects).",
+ ),
+ ],
+ ),
],
tooltip="Model to use for video generation.",
),
@@ -1348,13 +1379,20 @@ class Vidu3TextToVideoNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
- depends_on=IO.PriceBadgeDepends(widgets=["model.duration", "model.resolution"]),
+ depends_on=IO.PriceBadgeDepends(widgets=["model", "model.duration", "model.resolution"]),
expr="""
(
$res := $lookup(widgets, "model.resolution");
- $base := $lookup({"720p": 0.075, "1080p": 0.1}, $res);
- $perSec := $lookup({"720p": 0.025, "1080p": 0.05}, $res);
- {"type":"usd","usd": $base + $perSec * ($lookup(widgets, "model.duration") - 1)}
+ $d := $lookup(widgets, "model.duration");
+ $contains(widgets.model, "turbo")
+ ? (
+ $rate := $lookup({"720p": 0.06, "1080p": 0.08}, $res);
+ {"type":"usd","usd": $rate * $d}
+ )
+ : (
+ $rate := $lookup({"720p": 0.15, "1080p": 0.16}, $res);
+ {"type":"usd","usd": $rate * $d}
+ )
)
""",
),
@@ -1423,6 +1461,31 @@ class Vidu3ImageToVideoNode(IO.ComfyNode):
),
],
),
+ IO.DynamicCombo.Option(
+ "viduq3-turbo",
+ [
+ IO.Combo.Input(
+ "resolution",
+ options=["720p", "1080p"],
+ tooltip="Resolution of the output video.",
+ ),
+ IO.Int.Input(
+ "duration",
+ default=5,
+ min=1,
+ max=16,
+ step=1,
+ display_mode=IO.NumberDisplay.slider,
+ tooltip="Duration of the output video in seconds.",
+ ),
+ IO.Boolean.Input(
+ "audio",
+ default=False,
+ tooltip="When enabled, outputs video with sound "
+ "(including dialogue and sound effects).",
+ ),
+ ],
+ ),
],
tooltip="Model to use for video generation.",
),
@@ -1456,13 +1519,20 @@ class Vidu3ImageToVideoNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
- depends_on=IO.PriceBadgeDepends(widgets=["model.duration", "model.resolution"]),
+ depends_on=IO.PriceBadgeDepends(widgets=["model", "model.duration", "model.resolution"]),
expr="""
(
$res := $lookup(widgets, "model.resolution");
- $base := $lookup({"720p": 0.075, "1080p": 0.275, "2k": 0.35}, $res);
- $perSec := $lookup({"720p": 0.05, "1080p": 0.075, "2k": 0.075}, $res);
- {"type":"usd","usd": $base + $perSec * ($lookup(widgets, "model.duration") - 1)}
+ $d := $lookup(widgets, "model.duration");
+ $contains(widgets.model, "turbo")
+ ? (
+ $rate := $lookup({"720p": 0.06, "1080p": 0.08}, $res);
+ {"type":"usd","usd": $rate * $d}
+ )
+ : (
+ $rate := $lookup({"720p": 0.15, "1080p": 0.16, "2k": 0.2}, $res);
+ {"type":"usd","usd": $rate * $d}
+ )
)
""",
),
@@ -1495,6 +1565,145 @@ class Vidu3ImageToVideoNode(IO.ComfyNode):
return IO.NodeOutput(await download_url_to_video_output(results[0].url))
+class Vidu3StartEndToVideoNode(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="Vidu3StartEndToVideoNode",
+ display_name="Vidu Q3 Start/End Frame-to-Video Generation",
+ category="api node/video/Vidu",
+ description="Generate a video from a start frame, an end frame, and a prompt.",
+ inputs=[
+ IO.DynamicCombo.Input(
+ "model",
+ options=[
+ IO.DynamicCombo.Option(
+ "viduq3-pro",
+ [
+ IO.Combo.Input(
+ "resolution",
+ options=["720p", "1080p"],
+ tooltip="Resolution of the output video.",
+ ),
+ IO.Int.Input(
+ "duration",
+ default=5,
+ min=1,
+ max=16,
+ step=1,
+ display_mode=IO.NumberDisplay.slider,
+ tooltip="Duration of the output video in seconds.",
+ ),
+ IO.Boolean.Input(
+ "audio",
+ default=False,
+ tooltip="When enabled, outputs video with sound "
+ "(including dialogue and sound effects).",
+ ),
+ ],
+ ),
+ IO.DynamicCombo.Option(
+ "viduq3-turbo",
+ [
+ IO.Combo.Input(
+ "resolution",
+ options=["720p", "1080p"],
+ tooltip="Resolution of the output video.",
+ ),
+ IO.Int.Input(
+ "duration",
+ default=5,
+ min=1,
+ max=16,
+ step=1,
+ display_mode=IO.NumberDisplay.slider,
+ tooltip="Duration of the output video in seconds.",
+ ),
+ IO.Boolean.Input(
+ "audio",
+ default=False,
+ tooltip="When enabled, outputs video with sound "
+ "(including dialogue and sound effects).",
+ ),
+ ],
+ ),
+ ],
+ tooltip="Model to use for video generation.",
+ ),
+ IO.Image.Input("first_frame"),
+ IO.Image.Input("end_frame"),
+ IO.String.Input(
+ "prompt",
+ multiline=True,
+ tooltip="Prompt description (max 2000 characters).",
+ ),
+ IO.Int.Input(
+ "seed",
+ default=1,
+ min=0,
+ max=2147483647,
+ step=1,
+ display_mode=IO.NumberDisplay.number,
+ control_after_generate=True,
+ ),
+ ],
+ outputs=[
+ IO.Video.Output(),
+ ],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ price_badge=IO.PriceBadge(
+ depends_on=IO.PriceBadgeDepends(widgets=["model", "model.duration", "model.resolution"]),
+ expr="""
+ (
+ $res := $lookup(widgets, "model.resolution");
+ $d := $lookup(widgets, "model.duration");
+ $contains(widgets.model, "turbo")
+ ? (
+ $rate := $lookup({"720p": 0.06, "1080p": 0.08}, $res);
+ {"type":"usd","usd": $rate * $d}
+ )
+ : (
+ $rate := $lookup({"720p": 0.15, "1080p": 0.16}, $res);
+ {"type":"usd","usd": $rate * $d}
+ )
+ )
+ """,
+ ),
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ model: dict,
+ first_frame: Input.Image,
+ end_frame: Input.Image,
+ prompt: str,
+ seed: int,
+ ) -> IO.NodeOutput:
+ validate_string(prompt, max_length=2000)
+ validate_images_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False)
+ payload = TaskCreationRequest(
+ model=model["model"],
+ prompt=prompt,
+ duration=model["duration"],
+ seed=seed,
+ resolution=model["resolution"],
+ audio=model["audio"],
+ images=[
+ (await upload_images_to_comfyapi(cls, frame, max_images=1, mime_type="image/png"))[0]
+ for frame in (first_frame, end_frame)
+ ],
+ )
+ results = await execute_task(cls, VIDU_START_END_VIDEO, payload)
+ return IO.NodeOutput(await download_url_to_video_output(results[0].url))
+
+
class ViduExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@@ -1511,6 +1720,7 @@ class ViduExtension(ComfyExtension):
ViduMultiFrameVideoNode,
Vidu3TextToVideoNode,
Vidu3ImageToVideoNode,
+ Vidu3StartEndToVideoNode,
]
diff --git a/comfy_api_nodes/util/__init__.py b/comfy_api_nodes/util/__init__.py
index 18b020eef..f8a0ba8af 100644
--- a/comfy_api_nodes/util/__init__.py
+++ b/comfy_api_nodes/util/__init__.py
@@ -33,6 +33,7 @@ from .download_helpers import (
download_url_to_video_output,
)
from .upload_helpers import (
+ upload_3d_model_to_comfyapi,
upload_audio_to_comfyapi,
upload_file_to_comfyapi,
upload_image_to_comfyapi,
@@ -62,6 +63,7 @@ __all__ = [
"sync_op",
"sync_op_raw",
# Upload helpers
+ "upload_3d_model_to_comfyapi",
"upload_audio_to_comfyapi",
"upload_file_to_comfyapi",
"upload_image_to_comfyapi",
diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py
index 8a1259506..94886af7b 100644
--- a/comfy_api_nodes/util/client.py
+++ b/comfy_api_nodes/util/client.py
@@ -57,6 +57,7 @@ class _RequestConfig:
files: dict[str, Any] | list[tuple[str, Any]] | None
multipart_parser: Callable | None
max_retries: int
+ max_retries_on_rate_limit: int
retry_delay: float
retry_backoff: float
wait_label: str = "Waiting"
@@ -65,6 +66,7 @@ class _RequestConfig:
final_label_on_success: str | None = "Completed"
progress_origin_ts: float | None = None
price_extractor: Callable[[dict[str, Any]], float | None] | None = None
+ is_rate_limited: Callable[[int, Any], bool] | None = None
@dataclass
@@ -78,7 +80,7 @@ class _PollUIState:
active_since: float | None = None # start time of current active interval (None if queued)
-_RETRY_STATUS = {408, 429, 500, 502, 503, 504}
+_RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"]
FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"]
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing"]
@@ -103,6 +105,8 @@ async def sync_op(
final_label_on_success: str | None = "Completed",
progress_origin_ts: float | None = None,
monitor_progress: bool = True,
+ max_retries_on_rate_limit: int = 16,
+ is_rate_limited: Callable[[int, Any], bool] | None = None,
) -> M:
raw = await sync_op_raw(
cls,
@@ -122,6 +126,8 @@ async def sync_op(
final_label_on_success=final_label_on_success,
progress_origin_ts=progress_origin_ts,
monitor_progress=monitor_progress,
+ max_retries_on_rate_limit=max_retries_on_rate_limit,
+ is_rate_limited=is_rate_limited,
)
if not isinstance(raw, dict):
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
@@ -143,9 +149,9 @@ async def poll_op(
poll_interval: float = 5.0,
max_poll_attempts: int = 160,
timeout_per_poll: float = 120.0,
- max_retries_per_poll: int = 3,
+ max_retries_per_poll: int = 10,
retry_delay_per_poll: float = 1.0,
- retry_backoff_per_poll: float = 2.0,
+ retry_backoff_per_poll: float = 1.4,
estimated_duration: int | None = None,
cancel_endpoint: ApiEndpoint | None = None,
cancel_timeout: float = 10.0,
@@ -194,6 +200,8 @@ async def sync_op_raw(
final_label_on_success: str | None = "Completed",
progress_origin_ts: float | None = None,
monitor_progress: bool = True,
+ max_retries_on_rate_limit: int = 16,
+ is_rate_limited: Callable[[int, Any], bool] | None = None,
) -> dict[str, Any] | bytes:
"""
Make a single network request.
@@ -222,6 +230,8 @@ async def sync_op_raw(
final_label_on_success=final_label_on_success,
progress_origin_ts=progress_origin_ts,
price_extractor=price_extractor,
+ max_retries_on_rate_limit=max_retries_on_rate_limit,
+ is_rate_limited=is_rate_limited,
)
return await _request_base(cfg, expect_binary=as_binary)
@@ -240,9 +250,9 @@ async def poll_op_raw(
poll_interval: float = 5.0,
max_poll_attempts: int = 160,
timeout_per_poll: float = 120.0,
- max_retries_per_poll: int = 3,
+ max_retries_per_poll: int = 10,
retry_delay_per_poll: float = 1.0,
- retry_backoff_per_poll: float = 2.0,
+ retry_backoff_per_poll: float = 1.4,
estimated_duration: int | None = None,
cancel_endpoint: ApiEndpoint | None = None,
cancel_timeout: float = 10.0,
@@ -506,7 +516,7 @@ def _friendly_http_message(status: int, body: Any) -> str:
if status == 409:
return "There is a problem with your account. Please contact support@comfy.org."
if status == 429:
- return "Rate Limit Exceeded: Please try again later."
+ return "Rate Limit Exceeded: The server returned 429 after all retry attempts. Please wait and try again."
try:
if isinstance(body, dict):
err = body.get("error")
@@ -586,6 +596,8 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
start_time = cfg.progress_origin_ts if cfg.progress_origin_ts is not None else time.monotonic()
attempt = 0
delay = cfg.retry_delay
+ rate_limit_attempts = 0
+ rate_limit_delay = cfg.retry_delay
operation_succeeded: bool = False
final_elapsed_seconds: int | None = None
extracted_price: float | None = None
@@ -653,17 +665,14 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
payload_headers["Content-Type"] = "application/json"
payload_kw["json"] = cfg.data or {}
- try:
- request_logger.log_request_response(
- operation_id=operation_id,
- request_method=method,
- request_url=url,
- request_headers=dict(payload_headers) if payload_headers else None,
- request_params=dict(params) if params else None,
- request_data=request_body_log,
- )
- except Exception as _log_e:
- logging.debug("[DEBUG] request logging failed: %s", _log_e)
+ request_logger.log_request_response(
+ operation_id=operation_id,
+ request_method=method,
+ request_url=url,
+ request_headers=dict(payload_headers) if payload_headers else None,
+ request_params=dict(params) if params else None,
+ request_data=request_body_log,
+ )
req_coro = sess.request(method, url, params=params, **payload_kw)
req_task = asyncio.create_task(req_coro)
@@ -688,41 +697,33 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
body = await resp.json()
except (ContentTypeError, json.JSONDecodeError):
body = await resp.text()
- if resp.status in _RETRY_STATUS and attempt <= cfg.max_retries:
+ should_retry = False
+ wait_time = 0.0
+ retry_label = ""
+ is_rl = resp.status == 429 or (
+ cfg.is_rate_limited is not None and cfg.is_rate_limited(resp.status, body)
+ )
+ if is_rl and rate_limit_attempts < cfg.max_retries_on_rate_limit:
+ rate_limit_attempts += 1
+ wait_time = min(rate_limit_delay, 30.0)
+ rate_limit_delay *= cfg.retry_backoff
+ retry_label = f"rate-limit retry {rate_limit_attempts} of {cfg.max_retries_on_rate_limit}"
+ should_retry = True
+ elif resp.status in _RETRY_STATUS and (attempt - rate_limit_attempts) <= cfg.max_retries:
+ wait_time = delay
+ delay *= cfg.retry_backoff
+ retry_label = f"retry {attempt - rate_limit_attempts} of {cfg.max_retries}"
+ should_retry = True
+
+ if should_retry:
logging.warning(
- "HTTP %s %s -> %s. Retrying in %.2fs (retry %d of %d).",
+ "HTTP %s %s -> %s. Waiting %.2fs (%s).",
method,
url,
resp.status,
- delay,
- attempt,
- cfg.max_retries,
+ wait_time,
+ retry_label,
)
- try:
- request_logger.log_request_response(
- operation_id=operation_id,
- request_method=method,
- request_url=url,
- response_status_code=resp.status,
- response_headers=dict(resp.headers),
- response_content=body,
- error_message=_friendly_http_message(resp.status, body),
- )
- except Exception as _log_e:
- logging.debug("[DEBUG] response logging failed: %s", _log_e)
-
- await sleep_with_interrupt(
- delay,
- cfg.node_cls,
- cfg.wait_label if cfg.monitor_progress else None,
- start_time if cfg.monitor_progress else None,
- cfg.estimated_total,
- display_callback=_display_time_progress if cfg.monitor_progress else None,
- )
- delay *= cfg.retry_backoff
- continue
- msg = _friendly_http_message(resp.status, body)
- try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
@@ -730,10 +731,27 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=body,
- error_message=msg,
+ error_message=f"HTTP {resp.status} ({retry_label}, will retry in {wait_time:.1f}s)",
)
- except Exception as _log_e:
- logging.debug("[DEBUG] response logging failed: %s", _log_e)
+ await sleep_with_interrupt(
+ wait_time,
+ cfg.node_cls,
+ cfg.wait_label if cfg.monitor_progress else None,
+ start_time if cfg.monitor_progress else None,
+ cfg.estimated_total,
+ display_callback=_display_time_progress if cfg.monitor_progress else None,
+ )
+ continue
+ msg = _friendly_http_message(resp.status, body)
+ request_logger.log_request_response(
+ operation_id=operation_id,
+ request_method=method,
+ request_url=url,
+ response_status_code=resp.status,
+ response_headers=dict(resp.headers),
+ response_content=body,
+ error_message=msg,
+ )
raise Exception(msg)
if expect_binary:
@@ -753,17 +771,14 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
bytes_payload = bytes(buff)
operation_succeeded = True
final_elapsed_seconds = int(time.monotonic() - start_time)
- try:
- request_logger.log_request_response(
- operation_id=operation_id,
- request_method=method,
- request_url=url,
- response_status_code=resp.status,
- response_headers=dict(resp.headers),
- response_content=bytes_payload,
- )
- except Exception as _log_e:
- logging.debug("[DEBUG] response logging failed: %s", _log_e)
+ request_logger.log_request_response(
+ operation_id=operation_id,
+ request_method=method,
+ request_url=url,
+ response_status_code=resp.status,
+ response_headers=dict(resp.headers),
+ response_content=bytes_payload,
+ )
return bytes_payload
else:
try:
@@ -780,45 +795,39 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
extracted_price = cfg.price_extractor(payload) if cfg.price_extractor else None
operation_succeeded = True
final_elapsed_seconds = int(time.monotonic() - start_time)
- try:
- request_logger.log_request_response(
- operation_id=operation_id,
- request_method=method,
- request_url=url,
- response_status_code=resp.status,
- response_headers=dict(resp.headers),
- response_content=response_content_to_log,
- )
- except Exception as _log_e:
- logging.debug("[DEBUG] response logging failed: %s", _log_e)
+ request_logger.log_request_response(
+ operation_id=operation_id,
+ request_method=method,
+ request_url=url,
+ response_status_code=resp.status,
+ response_headers=dict(resp.headers),
+ response_content=response_content_to_log,
+ )
return payload
except ProcessingInterrupted:
logging.debug("Polling was interrupted by user")
raise
except (ClientError, OSError) as e:
- if attempt <= cfg.max_retries:
+ if (attempt - rate_limit_attempts) <= cfg.max_retries:
logging.warning(
"Connection error calling %s %s. Retrying in %.2fs (%d/%d): %s",
method,
url,
delay,
- attempt,
+ attempt - rate_limit_attempts,
cfg.max_retries,
str(e),
)
- try:
- request_logger.log_request_response(
- operation_id=operation_id,
- request_method=method,
- request_url=url,
- request_headers=dict(payload_headers) if payload_headers else None,
- request_params=dict(params) if params else None,
- request_data=request_body_log,
- error_message=f"{type(e).__name__}: {str(e)} (will retry)",
- )
- except Exception as _log_e:
- logging.debug("[DEBUG] request error logging failed: %s", _log_e)
+ request_logger.log_request_response(
+ operation_id=operation_id,
+ request_method=method,
+ request_url=url,
+ request_headers=dict(payload_headers) if payload_headers else None,
+ request_params=dict(params) if params else None,
+ request_data=request_body_log,
+ error_message=f"{type(e).__name__}: {str(e)} (will retry)",
+ )
await sleep_with_interrupt(
delay,
cfg.node_cls,
@@ -831,23 +840,6 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
continue
diag = await _diagnose_connectivity()
if not diag["internet_accessible"]:
- try:
- request_logger.log_request_response(
- operation_id=operation_id,
- request_method=method,
- request_url=url,
- request_headers=dict(payload_headers) if payload_headers else None,
- request_params=dict(params) if params else None,
- request_data=request_body_log,
- error_message=f"LocalNetworkError: {str(e)}",
- )
- except Exception as _log_e:
- logging.debug("[DEBUG] final error logging failed: %s", _log_e)
- raise LocalNetworkError(
- "Unable to connect to the API server due to local network issues. "
- "Please check your internet connection and try again."
- ) from e
- try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
@@ -855,10 +847,21 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
request_headers=dict(payload_headers) if payload_headers else None,
request_params=dict(params) if params else None,
request_data=request_body_log,
- error_message=f"ApiServerError: {str(e)}",
+ error_message=f"LocalNetworkError: {str(e)}",
)
- except Exception as _log_e:
- logging.debug("[DEBUG] final error logging failed: %s", _log_e)
+ raise LocalNetworkError(
+ "Unable to connect to the API server due to local network issues. "
+ "Please check your internet connection and try again."
+ ) from e
+ request_logger.log_request_response(
+ operation_id=operation_id,
+ request_method=method,
+ request_url=url,
+ request_headers=dict(payload_headers) if payload_headers else None,
+ request_params=dict(params) if params else None,
+ request_data=request_body_log,
+ error_message=f"ApiServerError: {str(e)}",
+ )
raise ApiServerError(
f"The API server at {default_base_url()} is currently unreachable. "
f"The service may be experiencing issues."
diff --git a/comfy_api_nodes/util/conversions.py b/comfy_api_nodes/util/conversions.py
index 3e37e8a8c..82b6d22a5 100644
--- a/comfy_api_nodes/util/conversions.py
+++ b/comfy_api_nodes/util/conversions.py
@@ -57,7 +57,7 @@ def tensor_to_bytesio(
image: torch.Tensor,
*,
total_pixels: int | None = 2048 * 2048,
- mime_type: str = "image/png",
+ mime_type: str | None = "image/png",
) -> BytesIO:
"""Converts a torch.Tensor image to a named BytesIO object.
diff --git a/comfy_api_nodes/util/download_helpers.py b/comfy_api_nodes/util/download_helpers.py
index 78bcf1fa1..aa588d038 100644
--- a/comfy_api_nodes/util/download_helpers.py
+++ b/comfy_api_nodes/util/download_helpers.py
@@ -167,27 +167,25 @@ async def download_url_to_bytesio(
with contextlib.suppress(Exception):
dest.seek(0)
- with contextlib.suppress(Exception):
- request_logger.log_request_response(
- operation_id=op_id,
- request_method="GET",
- request_url=url,
- response_status_code=resp.status,
- response_headers=dict(resp.headers),
- response_content=f"[streamed {written} bytes to dest]",
- )
+ request_logger.log_request_response(
+ operation_id=op_id,
+ request_method="GET",
+ request_url=url,
+ response_status_code=resp.status,
+ response_headers=dict(resp.headers),
+ response_content=f"[streamed {written} bytes to dest]",
+ )
return
except asyncio.CancelledError:
raise ProcessingInterrupted("Task cancelled") from None
except (ClientError, OSError) as e:
if attempt <= max_retries:
- with contextlib.suppress(Exception):
- request_logger.log_request_response(
- operation_id=op_id,
- request_method="GET",
- request_url=url,
- error_message=f"{type(e).__name__}: {str(e)} (will retry)",
- )
+ request_logger.log_request_response(
+ operation_id=op_id,
+ request_method="GET",
+ request_url=url,
+ error_message=f"{type(e).__name__}: {str(e)} (will retry)",
+ )
await sleep_with_interrupt(delay, cls, None, None, None)
delay *= retry_backoff
continue
diff --git a/comfy_api_nodes/util/request_logger.py b/comfy_api_nodes/util/request_logger.py
index e0cb4428d..fe0543d9b 100644
--- a/comfy_api_nodes/util/request_logger.py
+++ b/comfy_api_nodes/util/request_logger.py
@@ -8,7 +8,6 @@ from typing import Any
import folder_paths
-# Get the logger instance
logger = logging.getLogger(__name__)
@@ -91,38 +90,41 @@ def log_request_response(
Filenames are sanitized and length-limited for cross-platform safety.
If we still fail to write, we fall back to appending into api.log.
"""
- log_dir = get_log_directory()
- filepath = _build_log_filepath(log_dir, operation_id, request_url)
-
- log_content: list[str] = []
- log_content.append(f"Timestamp: {datetime.datetime.now().isoformat()}")
- log_content.append(f"Operation ID: {operation_id}")
- log_content.append("-" * 30 + " REQUEST " + "-" * 30)
- log_content.append(f"Method: {request_method}")
- log_content.append(f"URL: {request_url}")
- if request_headers:
- log_content.append(f"Headers:\n{_format_data_for_logging(request_headers)}")
- if request_params:
- log_content.append(f"Params:\n{_format_data_for_logging(request_params)}")
- if request_data is not None:
- log_content.append(f"Data/Body:\n{_format_data_for_logging(request_data)}")
-
- log_content.append("\n" + "-" * 30 + " RESPONSE " + "-" * 30)
- if response_status_code is not None:
- log_content.append(f"Status Code: {response_status_code}")
- if response_headers:
- log_content.append(f"Headers:\n{_format_data_for_logging(response_headers)}")
- if response_content is not None:
- log_content.append(f"Content:\n{_format_data_for_logging(response_content)}")
- if error_message:
- log_content.append(f"Error:\n{error_message}")
-
try:
- with open(filepath, "w", encoding="utf-8") as f:
- f.write("\n".join(log_content))
- logger.debug("API log saved to: %s", filepath)
- except Exception as e:
- logger.error("Error writing API log to %s: %s", filepath, str(e))
+ log_dir = get_log_directory()
+ filepath = _build_log_filepath(log_dir, operation_id, request_url)
+
+ log_content: list[str] = []
+ log_content.append(f"Timestamp: {datetime.datetime.now().isoformat()}")
+ log_content.append(f"Operation ID: {operation_id}")
+ log_content.append("-" * 30 + " REQUEST " + "-" * 30)
+ log_content.append(f"Method: {request_method}")
+ log_content.append(f"URL: {request_url}")
+ if request_headers:
+ log_content.append(f"Headers:\n{_format_data_for_logging(request_headers)}")
+ if request_params:
+ log_content.append(f"Params:\n{_format_data_for_logging(request_params)}")
+ if request_data is not None:
+ log_content.append(f"Data/Body:\n{_format_data_for_logging(request_data)}")
+
+ log_content.append("\n" + "-" * 30 + " RESPONSE " + "-" * 30)
+ if response_status_code is not None:
+ log_content.append(f"Status Code: {response_status_code}")
+ if response_headers:
+ log_content.append(f"Headers:\n{_format_data_for_logging(response_headers)}")
+ if response_content is not None:
+ log_content.append(f"Content:\n{_format_data_for_logging(response_content)}")
+ if error_message:
+ log_content.append(f"Error:\n{error_message}")
+
+ try:
+ with open(filepath, "w", encoding="utf-8") as f:
+ f.write("\n".join(log_content))
+ logger.debug("API log saved to: %s", filepath)
+ except Exception as e:
+ logger.error("Error writing API log to %s: %s", filepath, str(e))
+ except Exception as _log_e:
+ logging.debug("[DEBUG] log_request_response failed: %s", _log_e)
if __name__ == '__main__':
diff --git a/comfy_api_nodes/util/upload_helpers.py b/comfy_api_nodes/util/upload_helpers.py
index 83d936ce1..6d1d107a1 100644
--- a/comfy_api_nodes/util/upload_helpers.py
+++ b/comfy_api_nodes/util/upload_helpers.py
@@ -164,6 +164,27 @@ async def upload_video_to_comfyapi(
return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type, wait_label)
+_3D_MIME_TYPES = {
+ "glb": "model/gltf-binary",
+ "obj": "model/obj",
+ "fbx": "application/octet-stream",
+}
+
+
+async def upload_3d_model_to_comfyapi(
+ cls: type[IO.ComfyNode],
+ model_3d: Types.File3D,
+ file_format: str,
+) -> str:
+ """Uploads a 3D model file to ComfyUI API and returns its download URL."""
+ return await upload_file_to_comfyapi(
+ cls,
+ model_3d.get_data(),
+ f"{uuid.uuid4()}.{file_format}",
+ _3D_MIME_TYPES.get(file_format, "application/octet-stream"),
+ )
+
+
async def upload_file_to_comfyapi(
cls: type[IO.ComfyNode],
file_bytes_io: BytesIO,
@@ -255,17 +276,14 @@ async def upload_file(
monitor_task = asyncio.create_task(_monitor())
sess: aiohttp.ClientSession | None = None
try:
- try:
- request_logger.log_request_response(
- operation_id=operation_id,
- request_method="PUT",
- request_url=upload_url,
- request_headers=headers or None,
- request_params=None,
- request_data=f"[File data {len(data)} bytes]",
- )
- except Exception as e:
- logging.debug("[DEBUG] upload request logging failed: %s", e)
+ request_logger.log_request_response(
+ operation_id=operation_id,
+ request_method="PUT",
+ request_url=upload_url,
+ request_headers=headers or None,
+ request_params=None,
+ request_data=f"[File data {len(data)} bytes]",
+ )
sess = aiohttp.ClientSession(timeout=timeout)
req = sess.put(upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers)
@@ -311,31 +329,27 @@ async def upload_file(
delay *= retry_backoff
continue
raise Exception(f"Failed to upload (HTTP {resp.status}).")
- try:
- request_logger.log_request_response(
- operation_id=operation_id,
- request_method="PUT",
- request_url=upload_url,
- response_status_code=resp.status,
- response_headers=dict(resp.headers),
- response_content="File uploaded successfully.",
- )
- except Exception as e:
- logging.debug("[DEBUG] upload response logging failed: %s", e)
+ request_logger.log_request_response(
+ operation_id=operation_id,
+ request_method="PUT",
+ request_url=upload_url,
+ response_status_code=resp.status,
+ response_headers=dict(resp.headers),
+ response_content="File uploaded successfully.",
+ )
return
except asyncio.CancelledError:
raise ProcessingInterrupted("Task cancelled") from None
except (aiohttp.ClientError, OSError) as e:
if attempt <= max_retries:
- with contextlib.suppress(Exception):
- request_logger.log_request_response(
- operation_id=operation_id,
- request_method="PUT",
- request_url=upload_url,
- request_headers=headers or None,
- request_data=f"[File data {len(data)} bytes]",
- error_message=f"{type(e).__name__}: {str(e)} (will retry)",
- )
+ request_logger.log_request_response(
+ operation_id=operation_id,
+ request_method="PUT",
+ request_url=upload_url,
+ request_headers=headers or None,
+ request_data=f"[File data {len(data)} bytes]",
+ error_message=f"{type(e).__name__}: {str(e)} (will retry)",
+ )
await sleep_with_interrupt(
delay,
cls,
diff --git a/comfy_execution/jobs.py b/comfy_execution/jobs.py
index bf091a448..370014fb6 100644
--- a/comfy_execution/jobs.py
+++ b/comfy_execution/jobs.py
@@ -20,10 +20,60 @@ class JobStatus:
# Media types that can be previewed in the frontend
-PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio'})
+PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio', '3d'})
# 3D file extensions for preview fallback (no dedicated media_type exists)
-THREE_D_EXTENSIONS = frozenset({'.obj', '.fbx', '.gltf', '.glb'})
+THREE_D_EXTENSIONS = frozenset({'.obj', '.fbx', '.gltf', '.glb', '.usdz'})
+
+
+def has_3d_extension(filename: str) -> bool:
+ lower = filename.lower()
+ return any(lower.endswith(ext) for ext in THREE_D_EXTENSIONS)
+
+
+def normalize_output_item(item):
+ """Normalize a single output list item for the jobs API.
+
+ Returns the normalized item, or None to exclude it.
+ String items with 3D extensions become {filename, type, subfolder} dicts.
+ """
+ if item is None:
+ return None
+ if isinstance(item, str):
+ if has_3d_extension(item):
+ return {'filename': item, 'type': 'output', 'subfolder': '', 'mediaType': '3d'}
+ return None
+ if isinstance(item, dict):
+ return item
+ return None
+
+
+def normalize_outputs(outputs: dict) -> dict:
+ """Normalize raw node outputs for the jobs API.
+
+ Transforms string 3D filenames into file output dicts and removes
+ None items. All other items (non-3D strings, dicts, etc.) are
+ preserved as-is.
+ """
+ normalized = {}
+ for node_id, node_outputs in outputs.items():
+ if not isinstance(node_outputs, dict):
+ normalized[node_id] = node_outputs
+ continue
+ normalized_node = {}
+ for media_type, items in node_outputs.items():
+ if media_type == 'animated' or not isinstance(items, list):
+ normalized_node[media_type] = items
+ continue
+ normalized_items = []
+ for item in items:
+ if item is None:
+ continue
+ norm = normalize_output_item(item)
+ normalized_items.append(norm if norm is not None else item)
+ normalized_node[media_type] = normalized_items
+ normalized[node_id] = normalized_node
+ return normalized
def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]:
@@ -45,9 +95,9 @@ def is_previewable(media_type: str, item: dict) -> bool:
Maintains backwards compatibility with existing logic.
Priority:
- 1. media_type is 'images', 'video', or 'audio'
+ 1. media_type is 'images', 'video', 'audio', or '3d'
2. format field starts with 'video/' or 'audio/'
- 3. filename has a 3D extension (.obj, .fbx, .gltf, .glb)
+ 3. filename has a 3D extension (.obj, .fbx, .gltf, .glb, .usdz)
"""
if media_type in PREVIEWABLE_MEDIA_TYPES:
return True
@@ -139,7 +189,7 @@ def normalize_history_item(prompt_id: str, history_item: dict, include_outputs:
})
if include_outputs:
- job['outputs'] = outputs
+ job['outputs'] = normalize_outputs(outputs)
job['execution_status'] = status_info
job['workflow'] = {
'prompt': prompt,
@@ -171,18 +221,23 @@ def get_outputs_summary(outputs: dict) -> tuple[int, Optional[dict]]:
continue
for item in items:
- count += 1
-
- if not isinstance(item, dict):
+ normalized = normalize_output_item(item)
+ if normalized is None:
continue
- if preview_output is None and is_previewable(media_type, item):
+ count += 1
+
+ if preview_output is not None:
+ continue
+
+ if isinstance(normalized, dict) and is_previewable(media_type, normalized):
enriched = {
- **item,
+ **normalized,
'nodeId': node_id,
- 'mediaType': media_type
}
- if item.get('type') == 'output':
+ if 'mediaType' not in normalized:
+ enriched['mediaType'] = media_type
+ if normalized.get('type') == 'output':
preview_output = enriched
elif fallback_preview is None:
fallback_preview = enriched
diff --git a/comfy_extras/nodes_ace.py b/comfy_extras/nodes_ace.py
index dde5bbd2a..9cf84ab4d 100644
--- a/comfy_extras/nodes_ace.py
+++ b/comfy_extras/nodes_ace.py
@@ -49,13 +49,14 @@ class TextEncodeAceStepAudio15(io.ComfyNode):
io.Float.Input("temperature", default=0.85, min=0.0, max=2.0, step=0.01, advanced=True),
io.Float.Input("top_p", default=0.9, min=0.0, max=2000.0, step=0.01, advanced=True),
io.Int.Input("top_k", default=0, min=0, max=100, advanced=True),
+ io.Float.Input("min_p", default=0.000, min=0.0, max=1.0, step=0.001, advanced=True),
],
outputs=[io.Conditioning.Output()],
)
@classmethod
- def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k) -> io.NodeOutput:
- tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k)
+ def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k, min_p) -> io.NodeOutput:
+ tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p)
conditioning = clip.encode_from_tokens_scheduled(tokens)
return io.NodeOutput(conditioning)
diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py
index 3462ecef9..43df0512f 100644
--- a/comfy_extras/nodes_audio.py
+++ b/comfy_extras/nodes_audio.py
@@ -159,6 +159,7 @@ class SaveAudio(IO.ComfyNode):
search_aliases=["export flac"],
display_name="Save Audio (FLAC)",
category="audio",
+ essentials_category="Audio",
inputs=[
IO.Audio.Input("audio"),
IO.String.Input("filename_prefix", default="audio/ComfyUI"),
@@ -300,6 +301,7 @@ class LoadAudio(IO.ComfyNode):
search_aliases=["import audio", "open audio", "audio file"],
display_name="Load Audio",
category="audio",
+ essentials_category="Audio",
inputs=[
IO.Combo.Input("audio", upload=IO.UploadType.audio, options=sorted(files)),
],
@@ -700,6 +702,67 @@ class EmptyAudio(IO.ComfyNode):
create_empty_audio = execute # TODO: remove
+class AudioEqualizer3Band(IO.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="AudioEqualizer3Band",
+ search_aliases=["eq", "bass boost", "treble boost", "equalizer"],
+ display_name="Audio Equalizer (3-Band)",
+ category="audio",
+ is_experimental=True,
+ inputs=[
+ IO.Audio.Input("audio"),
+ IO.Float.Input("low_gain_dB", default=0.0, min=-24.0, max=24.0, step=0.1, tooltip="Gain for Low frequencies (Bass)"),
+ IO.Int.Input("low_freq", default=100, min=20, max=500, tooltip="Cutoff frequency for Low shelf"),
+ IO.Float.Input("mid_gain_dB", default=0.0, min=-24.0, max=24.0, step=0.1, tooltip="Gain for Mid frequencies"),
+ IO.Int.Input("mid_freq", default=1000, min=200, max=4000, tooltip="Center frequency for Mids"),
+ IO.Float.Input("mid_q", default=0.707, min=0.1, max=10.0, step=0.1, tooltip="Q factor (bandwidth) for Mids"),
+ IO.Float.Input("high_gain_dB", default=0.0, min=-24.0, max=24.0, step=0.1, tooltip="Gain for High frequencies (Treble)"),
+ IO.Int.Input("high_freq", default=5000, min=1000, max=15000, tooltip="Cutoff frequency for High shelf"),
+ ],
+ outputs=[IO.Audio.Output()],
+ )
+
+ @classmethod
+ def execute(cls, audio, low_gain_dB, low_freq, mid_gain_dB, mid_freq, mid_q, high_gain_dB, high_freq) -> IO.NodeOutput:
+ waveform = audio["waveform"]
+ sample_rate = audio["sample_rate"]
+ eq_waveform = waveform.clone()
+
+ # 1. Apply Low Shelf (Bass)
+ if low_gain_dB != 0:
+ eq_waveform = torchaudio.functional.bass_biquad(
+ eq_waveform,
+ sample_rate,
+ gain=low_gain_dB,
+ central_freq=float(low_freq),
+ Q=0.707
+ )
+
+ # 2. Apply Peaking EQ (Mids)
+ if mid_gain_dB != 0:
+ eq_waveform = torchaudio.functional.equalizer_biquad(
+ eq_waveform,
+ sample_rate,
+ center_freq=float(mid_freq),
+ gain=mid_gain_dB,
+ Q=mid_q
+ )
+
+ # 3. Apply High Shelf (Treble)
+ if high_gain_dB != 0:
+ eq_waveform = torchaudio.functional.treble_biquad(
+ eq_waveform,
+ sample_rate,
+ gain=high_gain_dB,
+ central_freq=float(high_freq),
+ Q=0.707
+ )
+
+ return IO.NodeOutput({"waveform": eq_waveform, "sample_rate": sample_rate})
+
+
class AudioExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@@ -722,6 +785,7 @@ class AudioExtension(ComfyExtension):
AudioMerge,
AudioAdjustVolume,
EmptyAudio,
+ AudioEqualizer3Band,
]
async def comfy_entrypoint() -> AudioExtension:
diff --git a/comfy_extras/nodes_canny.py b/comfy_extras/nodes_canny.py
index 6e0fadca5..956c1977c 100644
--- a/comfy_extras/nodes_canny.py
+++ b/comfy_extras/nodes_canny.py
@@ -12,6 +12,7 @@ class Canny(io.ComfyNode):
node_id="Canny",
search_aliases=["edge detection", "outline", "contour detection", "line art"],
category="image/preprocessors",
+ essentials_category="Image Tools",
inputs=[
io.Image.Input("image"),
io.Float.Input("low_threshold", default=0.4, min=0.01, max=0.99, step=0.01),
diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py
index 6ad0dc5d0..1e957c09b 100644
--- a/comfy_extras/nodes_custom_sampler.py
+++ b/comfy_extras/nodes_custom_sampler.py
@@ -622,6 +622,7 @@ class SamplerSASolver(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="SamplerSASolver",
+ search_aliases=["sde"],
category="sampling/custom_sampling/samplers",
inputs=[
io.Model.Input("model"),
@@ -666,6 +667,7 @@ class SamplerSEEDS2(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="SamplerSEEDS2",
+ search_aliases=["sde", "exp heun"],
category="sampling/custom_sampling/samplers",
inputs=[
io.Combo.Input("solver_type", options=["phi_1", "phi_2"]),
diff --git a/comfy_extras/nodes_easycache.py b/comfy_extras/nodes_easycache.py
index e030148a8..923c2bb05 100644
--- a/comfy_extras/nodes_easycache.py
+++ b/comfy_extras/nodes_easycache.py
@@ -108,7 +108,7 @@ def lazycache_predict_noise_wrapper(executor, *args, **kwargs):
easycache: LazyCacheHolder = model_options["transformer_options"]["easycache"]
if easycache.is_past_end_timestep(timestep):
return executor(*args, **kwargs)
- x: torch.Tensor = _extract_tensor(args[0], easycache.output_channels)
+ x: torch.Tensor = args[0][:, :easycache.output_channels]
# prepare next x_prev
next_x_prev = x
input_change = None
diff --git a/comfy_extras/nodes_hunyuan3d.py b/comfy_extras/nodes_hunyuan3d.py
index 6f94cba45..df0c3e4b1 100644
--- a/comfy_extras/nodes_hunyuan3d.py
+++ b/comfy_extras/nodes_hunyuan3d.py
@@ -621,6 +621,7 @@ class SaveGLB(IO.ComfyNode):
display_name="Save 3D Model",
search_aliases=["export 3d model", "save mesh"],
category="3d",
+ essentials_category="Basics",
is_output_node=True,
inputs=[
IO.MultiType.Input(
diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py
index d27f49ece..983e2a92f 100644
--- a/comfy_extras/nodes_images.py
+++ b/comfy_extras/nodes_images.py
@@ -23,8 +23,10 @@ class ImageCrop(IO.ComfyNode):
return IO.Schema(
node_id="ImageCrop",
search_aliases=["trim"],
- display_name="Image Crop",
+ display_name="Image Crop (Deprecated)",
category="image/transform",
+ is_deprecated=True,
+ essentials_category="Image Tools",
inputs=[
IO.Image.Input("image"),
IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
@@ -47,6 +49,57 @@ class ImageCrop(IO.ComfyNode):
crop = execute # TODO: remove
+class ImageCropV2(IO.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="ImageCropV2",
+ search_aliases=["trim"],
+ display_name="Image Crop",
+ category="image/transform",
+ inputs=[
+ IO.Image.Input("image"),
+ IO.BoundingBox.Input("crop_region", component="ImageCrop"),
+ ],
+ outputs=[IO.Image.Output()],
+ )
+
+ @classmethod
+ def execute(cls, image, crop_region) -> IO.NodeOutput:
+ x = crop_region.get("x", 0)
+ y = crop_region.get("y", 0)
+ width = crop_region.get("width", 512)
+ height = crop_region.get("height", 512)
+
+ x = min(x, image.shape[2] - 1)
+ y = min(y, image.shape[1] - 1)
+ to_x = width + x
+ to_y = height + y
+ img = image[:,y:to_y, x:to_x, :]
+ return IO.NodeOutput(img, ui=UI.PreviewImage(img))
+
+
+class BoundingBox(IO.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="PrimitiveBoundingBox",
+ display_name="Bounding Box",
+ category="utils/primitive",
+ inputs=[
+ IO.Int.Input("x", default=0, min=0, max=MAX_RESOLUTION),
+ IO.Int.Input("y", default=0, min=0, max=MAX_RESOLUTION),
+ IO.Int.Input("width", default=512, min=1, max=MAX_RESOLUTION),
+ IO.Int.Input("height", default=512, min=1, max=MAX_RESOLUTION),
+ ],
+ outputs=[IO.BoundingBox.Output()],
+ )
+
+ @classmethod
+ def execute(cls, x, y, width, height) -> IO.NodeOutput:
+ return IO.NodeOutput({"x": x, "y": y, "width": width, "height": height})
+
+
class RepeatImageBatch(IO.ComfyNode):
@classmethod
def define_schema(cls):
@@ -537,6 +590,7 @@ class ImageRotate(IO.ComfyNode):
node_id="ImageRotate",
search_aliases=["turn", "flip orientation"],
category="image/transform",
+ essentials_category="Image Tools",
inputs=[
IO.Image.Input("image"),
IO.Combo.Input("rotation", options=["none", "90 degrees", "180 degrees", "270 degrees"]),
@@ -632,6 +686,8 @@ class ImagesExtension(ComfyExtension):
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
ImageCrop,
+ ImageCropV2,
+ BoundingBox,
RepeatImageBatch,
ImageFromBatch,
ImageAddNoise,
diff --git a/comfy_extras/nodes_latent.py b/comfy_extras/nodes_latent.py
index a83ec5fb7..8bb368dec 100644
--- a/comfy_extras/nodes_latent.py
+++ b/comfy_extras/nodes_latent.py
@@ -391,8 +391,9 @@ class LatentOperationTonemapReinhard(io.ComfyNode):
latent_vector_magnitude = (torch.linalg.vector_norm(latent, dim=(1)) + 0.0000000001)[:,None]
normalized_latent = latent / latent_vector_magnitude
- mean = torch.mean(latent_vector_magnitude, dim=(1,2,3), keepdim=True)
- std = torch.std(latent_vector_magnitude, dim=(1,2,3), keepdim=True)
+ dims = list(range(1, latent_vector_magnitude.ndim))
+ mean = torch.mean(latent_vector_magnitude, dim=dims, keepdim=True)
+ std = torch.std(latent_vector_magnitude, dim=dims, keepdim=True)
top = (std * 5 + mean) * multiplier
diff --git a/comfy_extras/nodes_load_3d.py b/comfy_extras/nodes_load_3d.py
index 9d70fd38b..9112bdd0a 100644
--- a/comfy_extras/nodes_load_3d.py
+++ b/comfy_extras/nodes_load_3d.py
@@ -31,6 +31,7 @@ class Load3D(IO.ComfyNode):
node_id="Load3D",
display_name="Load 3D & Animation",
category="3d",
+ essentials_category="Basics",
is_experimental=True,
inputs=[
IO.Combo.Input("model_file", options=sorted(files), upload=IO.UploadType.model),
diff --git a/comfy_extras/nodes_lora_extract.py b/comfy_extras/nodes_lora_extract.py
index 29605c8d0..975f90f45 100644
--- a/comfy_extras/nodes_lora_extract.py
+++ b/comfy_extras/nodes_lora_extract.py
@@ -7,6 +7,7 @@ import logging
from enum import Enum
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
+from tqdm.auto import trange
CLAMP_QUANTILE = 0.99
@@ -49,12 +50,22 @@ LORA_TYPES = {"standard": LORAType.STANDARD,
"full_diff": LORAType.FULL_DIFF}
def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora_type, bias_diff=False):
- comfy.model_management.load_models_gpu([model_diff], force_patch_weights=True)
+ comfy.model_management.load_models_gpu([model_diff])
sd = model_diff.model_state_dict(filter_prefix=prefix_model)
- for k in sd:
- if k.endswith(".weight"):
+ sd_keys = list(sd.keys())
+ for index in trange(len(sd_keys), unit="weight"):
+ k = sd_keys[index]
+ op_keys = sd_keys[index].rsplit('.', 1)
+ if len(op_keys) < 2 or op_keys[1] not in ["weight", "bias"] or (op_keys[1] == "bias" and not bias_diff):
+ continue
+ op = comfy.utils.get_attr(model_diff.model, op_keys[0])
+ if hasattr(op, "comfy_cast_weights") and not getattr(op, "comfy_patched_weights", False):
+ weight_diff = model_diff.patch_weight_to_device(k, model_diff.load_device, return_weight=True)
+ else:
weight_diff = sd[k]
+
+ if op_keys[1] == "weight":
if lora_type == LORAType.STANDARD:
if weight_diff.ndim < 2:
if bias_diff:
@@ -69,8 +80,8 @@ def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora
elif lora_type == LORAType.FULL_DIFF:
output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu()
- elif bias_diff and k.endswith(".bias"):
- output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().half().cpu()
+ elif bias_diff and op_keys[1] == "bias":
+ output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = weight_diff.contiguous().half().cpu()
return output_sd
class LoraSave(io.ComfyNode):
diff --git a/comfy_extras/nodes_nag.py b/comfy_extras/nodes_nag.py
new file mode 100644
index 000000000..033e40eb9
--- /dev/null
+++ b/comfy_extras/nodes_nag.py
@@ -0,0 +1,99 @@
+import torch
+from comfy_api.latest import ComfyExtension, io
+from typing_extensions import override
+
+
+class NAGuidance(io.ComfyNode):  # node that installs an attn1 output patch applying Normalized Attention Guidance (NAG)
+    @classmethod
+    def define_schema(cls) -> io.Schema:
+        return io.Schema(
+            node_id="NAGuidance",
+            display_name="Normalized Attention Guidance",
+            description="Applies Normalized Attention Guidance to models, enabling negative prompts on distilled/schnell models.",
+            category="",
+            is_experimental=True,
+            inputs=[
+                io.Model.Input("model", tooltip="The model to apply NAG to."),
+                io.Float.Input("nag_scale", min=0.0, default=5.0, max=50.0, step=0.1, tooltip="The guidance scale factor. Higher values push further from the negative prompt."),
+                io.Float.Input("nag_alpha", min=0.0, default=0.5, max=1.0, step=0.01, tooltip="Blending factor for the normalized attention. 1.0 is full replacement, 0.0 is no effect."),
+                io.Float.Input("nag_tau", min=1.0, default=1.5, max=10.0, step=0.01),  # cap on the guided/positive L1-norm ratio (used in execute)
+                # io.Float.Input("start_percent", min=0.0, default=0.0, max=1.0, step=0.01, tooltip="The relative sampling step to begin applying NAG."),
+                # io.Float.Input("end_percent", min=0.0, default=1.0, max=1.0, step=0.01, tooltip="The relative sampling step to stop applying NAG."),
+            ],
+            outputs=[
+                io.Model.Output(tooltip="The patched model with NAG enabled."),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, model: io.Model.Type, nag_scale: float, nag_alpha: float, nag_tau: float) -> io.NodeOutput:
+        m = model.clone()  # the patch below is registered on this clone, not on `model`
+
+        # sigma_start = m.get_model_object("model_sampling").percent_to_sigma(start_percent)
+        # sigma_end = m.get_model_object("model_sampling").percent_to_sigma(end_percent)
+
+        def nag_attention_output_patch(out, extra_options):  # closure over nag_scale / nag_alpha / nag_tau
+            cond_or_uncond = extra_options.get("cond_or_uncond", None)
+            if cond_or_uncond is None:
+                return out
+
+            if not (1 in cond_or_uncond and 0 in cond_or_uncond):  # needs both a 0 (positive) and a 1 (negative) entry; otherwise pass through
+                return out
+
+            # sigma = extra_options.get("sigmas", None)
+            # if sigma is not None and len(sigma) > 0:
+            #     sigma = sigma[0].item()
+            #     if sigma > sigma_start or sigma < sigma_end:
+            #         return out
+
+            img_slice = extra_options.get("img_slice", None)
+
+            if img_slice is not None:
+                orig_out = out  # keep the full tensor so results can be written back into it
+                out = out[:, img_slice[0]:img_slice[1]] # only apply on img part
+
+            batch_size = out.shape[0]
+            half_size = batch_size // len(cond_or_uncond)  # rows per entry; assumes entries are stacked evenly along dim 0 - TODO confirm
+
+            ind_neg = cond_or_uncond.index(1)  # position of the entry flagged 1 (used as the negative)
+            ind_pos = cond_or_uncond.index(0)  # position of the entry flagged 0 (used as the positive)
+            z_pos = out[half_size * ind_pos:half_size * (ind_pos + 1)]
+            z_neg = out[half_size * ind_neg:half_size * (ind_neg + 1)]
+
+            guided = z_pos * nag_scale - z_neg * (nag_scale - 1.0)  # == z_pos + (nag_scale - 1) * (z_pos - z_neg)
+
+            eps = 1e-6  # floor for the norms so the divisions below cannot hit zero
+            norm_pos = torch.norm(z_pos, p=1, dim=-1, keepdim=True).clamp_min(eps)  # L1 norm over the last dim
+            norm_guided = torch.norm(guided, p=1, dim=-1, keepdim=True).clamp_min(eps)
+
+            ratio = norm_guided / norm_pos
+            scale_factor = torch.minimum(ratio, torch.full_like(ratio, nag_tau)) / ratio  # == min(1, nag_tau / ratio): shrinks only when ratio > nag_tau
+
+            guided_normalized = guided * scale_factor
+
+            z_final = guided_normalized * nag_alpha + z_pos * (1.0 - nag_alpha)  # blend with the unmodified positive output
+
+            if img_slice is not None:
+                orig_out[half_size * ind_neg:half_size * (ind_neg + 1), img_slice[0]:img_slice[1]] = z_final  # with img_slice, the negative rows are overwritten too
+                orig_out[half_size * ind_pos:half_size * (ind_pos + 1), img_slice[0]:img_slice[1]] = z_final
+                return orig_out
+            else:
+                out[half_size * ind_pos:half_size * (ind_pos + 1)] = z_final  # without img_slice, only the positive rows are overwritten (in place)
+                return out
+
+        m.set_model_attn1_output_patch(nag_attention_output_patch)
+        m.disable_model_cfg1_optimization()  # NOTE(review): presumably keeps the negative branch evaluated at cfg 1 - confirm
+
+        return io.NodeOutput(m)
+
+
+class NagExtension(ComfyExtension):  # extension whose node list contains only NAGuidance
+    @override
+    async def get_node_list(self) -> list[type[io.ComfyNode]]:
+        return [
+            NAGuidance,
+        ]
+
+
+async def comfy_entrypoint() -> NagExtension:  # presumably looked up by name by the extension loader - confirm
+    return NagExtension()  # a fresh instance on every call
diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py
index 5d75148c0..5c143daab 100644
--- a/comfy_extras/nodes_post_processing.py
+++ b/comfy_extras/nodes_post_processing.py
@@ -77,6 +77,7 @@ class Blur(io.ComfyNode):
return io.Schema(
node_id="ImageBlur",
category="image/postprocessing",
+ essentials_category="Image Tools",
inputs=[
io.Image.Input("image"),
io.Int.Input("blur_radius", default=1, min=1, max=31, step=1),
@@ -655,6 +656,7 @@ class BatchImagesMasksLatentsNode(io.ComfyNode):
batched = batch_masks(values)
return io.NodeOutput(batched)
+
class PostProcessingExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
diff --git a/comfy_extras/nodes_replacements.py b/comfy_extras/nodes_replacements.py
new file mode 100644
index 000000000..7684e854c
--- /dev/null
+++ b/comfy_extras/nodes_replacements.py
@@ -0,0 +1,103 @@
+from comfy_api.latest import ComfyExtension, io, ComfyAPI
+
+api = ComfyAPI()
+
+
+async def register_replacements():  # awaits each helper below one after another, in the listed order
+    """Register all built-in node replacements."""
+    await register_replacements_longeredge()
+    await register_replacements_batchimages()
+    await register_replacements_upscaleimage()
+    await register_replacements_controlnet()
+    await register_replacements_load3d()
+    await register_replacements_preview3d()
+    await register_replacements_svdimg2vid()
+    await register_replacements_conditioningavg()
+
+async def register_replacements_longeredge():  # ResizeImagesByLongerEdge -> ImageScaleToMaxDimension
+    # No dynamic inputs here
+    await api.node_replacement.register(io.NodeReplace(
+        new_node_id="ImageScaleToMaxDimension",
+        old_node_id="ResizeImagesByLongerEdge",
+        old_widget_ids=["longer_edge"],
+        input_mapping=[
+            {"new_id": "image", "old_id": "images"},
+            {"new_id": "largest_size", "old_id": "longer_edge"},
+            {"new_id": "upscale_method", "set_value": "lanczos"},  # fixed value; no old input is mapped onto this one
+        ],
+        # just to test the frontend output_mapping code, does nothing really here
+        output_mapping=[{"new_idx": 0, "old_idx": 0}],
+    ))
+
+async def register_replacements_batchimages():  # ImageBatch -> BatchImagesNode
+    # BatchImages node uses Autogrow
+    await api.node_replacement.register(io.NodeReplace(
+        new_node_id="BatchImagesNode",
+        old_node_id="ImageBatch",
+        input_mapping=[
+            {"new_id": "images.image0", "old_id": "image1"},  # old input names count from 1, new slot names from 0
+            {"new_id": "images.image1", "old_id": "image2"},
+        ],
+    ))
+
+async def register_replacements_upscaleimage():  # ImageScaleBy -> ResizeImageMaskNode
+    # ResizeImageMaskNode uses DynamicCombo
+    await api.node_replacement.register(io.NodeReplace(
+        new_node_id="ResizeImageMaskNode",
+        old_node_id="ImageScaleBy",
+        old_widget_ids=["upscale_method", "scale_by"],
+        input_mapping=[
+            {"new_id": "input", "old_id": "image"},
+            {"new_id": "resize_type", "set_value": "scale by multiplier"},  # fixed combo choice, not taken from the old node
+            {"new_id": "resize_type.multiplier", "old_id": "scale_by"},  # dotted id: presumably the sub-input of the chosen resize_type option - confirm
+            {"new_id": "scale_method", "old_id": "upscale_method"},
+        ],
+    ))
+
+async def register_replacements_controlnet():
+    # T2IAdapterLoader → ControlNetLoader
+    await api.node_replacement.register(io.NodeReplace(
+        new_node_id="ControlNetLoader",
+        old_node_id="T2IAdapterLoader",
+        input_mapping=[
+            {"new_id": "control_net_name", "old_id": "t2i_adapter_name"},  # the single mapped input is renamed
+        ],
+    ))
+
+async def register_replacements_load3d():
+    # Load3DAnimation merged into Load3D
+    await api.node_replacement.register(io.NodeReplace(
+        new_node_id="Load3D",
+        old_node_id="Load3DAnimation",  # no input_mapping / output_mapping is supplied for this pair
+    ))
+
+async def register_replacements_preview3d():
+    # Preview3DAnimation merged into Preview3D
+    await api.node_replacement.register(io.NodeReplace(
+        new_node_id="Preview3D",
+        old_node_id="Preview3DAnimation",  # no input_mapping / output_mapping is supplied for this pair
+    ))
+
+async def register_replacements_svdimg2vid():
+    # Typo fix: SDV → SVD
+    await api.node_replacement.register(io.NodeReplace(
+        new_node_id="SVD_img2vid_Conditioning",
+        old_node_id="SDV_img2vid_Conditioning",  # the misspelled id ("SDV") is the old id being replaced; keep it verbatim
+    ))
+
+async def register_replacements_conditioningavg():
+    # Typo fix: trailing space in node name
+    await api.node_replacement.register(io.NodeReplace(
+        new_node_id="ConditioningAverage",
+        old_node_id="ConditioningAverage ",  # trailing space is intentional: it is part of the old id; do not strip
+    ))
+
+class NodeReplacementsExtension(ComfyExtension):  # contributes node replacements only, no nodes
+    async def on_load(self) -> None:
+        await register_replacements()  # registers every built-in replacement defined in this module
+
+    async def get_node_list(self) -> list[type[io.ComfyNode]]:
+        return []  # this extension defines no nodes of its own
+
+async def comfy_entrypoint() -> NodeReplacementsExtension:  # presumably looked up by name by the extension loader - confirm
+    return NodeReplacementsExtension()  # a fresh instance on every call
diff --git a/comfy_extras/nodes_textgen.py b/comfy_extras/nodes_textgen.py
new file mode 100644
index 000000000..dd4f6b0d3
--- /dev/null
+++ b/comfy_extras/nodes_textgen.py
@@ -0,0 +1,176 @@
+from comfy_api.latest import ComfyExtension, io
+from typing_extensions import override
+
+class TextGenerate(io.ComfyNode):  # node: tokenize a prompt (optionally with an image), generate ids via `clip`, decode to text
+    @classmethod
+    def define_schema(cls):
+        # Define dynamic combo options for sampling mode
+        sampling_options = [
+            io.DynamicCombo.Option(
+                key="on",  # this option carries the sampling parameters listed below
+                inputs=[
+                    io.Float.Input("temperature", default=0.7, min=0.01, max=2.0, step=0.000001),
+                    io.Int.Input("top_k", default=64, min=0, max=1000),
+                    io.Float.Input("top_p", default=0.95, min=0.0, max=1.0, step=0.01),
+                    io.Float.Input("min_p", default=0.05, min=0.0, max=1.0, step=0.01),
+                    io.Float.Input("repetition_penalty", default=1.05, min=0.0, max=5.0, step=0.01),
+                    io.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff),
+                ]
+            ),
+            io.DynamicCombo.Option(
+                key="off",
+                inputs=[]  # no extra inputs for this option
+            ),
+        ]
+
+        return io.Schema(
+            node_id="TextGenerate",
+            category="textgen/",
+            search_aliases=["LLM", "gemma"],
+            inputs=[
+                io.Clip.Input("clip"),
+                io.String.Input("prompt", multiline=True, dynamic_prompts=True, default=""),
+                io.Image.Input("image", optional=True),
+                io.Int.Input("max_length", default=256, min=1, max=2048),
+                io.DynamicCombo.Input("sampling_mode", options=sampling_options, display_name="Sampling Mode"),
+            ],
+            outputs=[
+                io.String.Output(display_name="generated_text"),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, clip, prompt, max_length, sampling_mode, image=None) -> io.NodeOutput:
+
+        tokens = clip.tokenize(prompt, image=image, skip_template=False)
+
+        # Get sampling parameters from dynamic combo
+        do_sample = sampling_mode.get("sampling_mode") == "on"  # assumes the selected option key is stored under "sampling_mode" - TODO confirm
+        temperature = sampling_mode.get("temperature", 1.0)  # the fallbacks here and below apply when a key is absent from sampling_mode
+        top_k = sampling_mode.get("top_k", 50)
+        top_p = sampling_mode.get("top_p", 1.0)
+        min_p = sampling_mode.get("min_p", 0.0)
+        seed = sampling_mode.get("seed", None)  # None is passed through to clip.generate when no seed is present
+        repetition_penalty = sampling_mode.get("repetition_penalty", 1.0)
+
+        generated_ids = clip.generate(
+            tokens,
+            do_sample=do_sample,
+            max_length=max_length,
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+            min_p=min_p,
+            repetition_penalty=repetition_penalty,
+            seed=seed
+        )
+
+        generated_text = clip.decode(generated_ids, skip_special_tokens=True)
+        return io.NodeOutput(generated_text)
+
+
+LTX2_T2V_SYSTEM_PROMPT = """You are a Creative Assistant. Given a user's raw input prompt describing a scene or concept, expand it into a detailed video generation prompt with specific visuals and integrated audio to guide a text-to-video model.
+#### Guidelines
+- Strictly follow all aspects of the user's raw input: include every element requested (style, visuals, motions, actions, camera movement, audio).
+ - If the input is vague, invent concrete details: lighting, textures, materials, scene settings, etc.
+ - For characters: describe gender, clothing, hair, expressions. DO NOT invent unrequested characters.
+- Use active language: present-progressive verbs ("is walking," "speaking"). If no action specified, describe natural movements.
+- Maintain chronological flow: use temporal connectors ("as," "then," "while").
+- Audio layer: Describe complete soundscape (background audio, ambient sounds, SFX, speech/music when requested). Integrate sounds chronologically alongside actions. Be specific (e.g., "soft footsteps on tile"), not vague (e.g., "ambient sound is present").
+- Speech (only when requested):
+ - For ANY speech-related input (talking, conversation, singing, etc.), ALWAYS include exact words in quotes with voice characteristics (e.g., "The man says in an excited voice: 'You won't believe what I just saw!'").
+ - Specify language if not English and accent if relevant.
+- Style: Include visual style at the beginning: "Style: