diff --git a/.github/workflows/release-webhook.yml b/.github/workflows/release-webhook.yml
index 6fceb7560..737e4c488 100644
--- a/.github/workflows/release-webhook.yml
+++ b/.github/workflows/release-webhook.yml
@@ -7,6 +7,8 @@ on:
jobs:
send-webhook:
runs-on: ubuntu-latest
+ env:
+ DESKTOP_REPO_DISPATCH_TOKEN: ${{ secrets.DESKTOP_REPO_DISPATCH_TOKEN }}
steps:
- name: Send release webhook
env:
@@ -106,3 +108,37 @@ jobs:
--fail --silent --show-error
echo "✅ Release webhook sent successfully"
+
+ - name: Send repository dispatch to desktop
+ env:
+ DISPATCH_TOKEN: ${{ env.DESKTOP_REPO_DISPATCH_TOKEN }}
+ RELEASE_TAG: ${{ github.event.release.tag_name }}
+ RELEASE_URL: ${{ github.event.release.html_url }}
+ run: |
+ set -euo pipefail
+
+ if [ -z "${DISPATCH_TOKEN:-}" ]; then
+ echo "::error::DESKTOP_REPO_DISPATCH_TOKEN is required but not set."
+ exit 1
+ fi
+
+ PAYLOAD="$(jq -n \
+ --arg release_tag "$RELEASE_TAG" \
+ --arg release_url "$RELEASE_URL" \
+ '{
+ event_type: "comfyui_release_published",
+ client_payload: {
+ release_tag: $release_tag,
+ release_url: $release_url
+ }
+ }')"
+
+ curl -fsSL \
+ -X POST \
+ -H "Accept: application/vnd.github+json" \
+ -H "Content-Type: application/json" \
+ -H "Authorization: Bearer ${DISPATCH_TOKEN}" \
+ https://api.github.com/repos/Comfy-Org/desktop/dispatches \
+ -d "$PAYLOAD"
+
+ echo "✅ Dispatched ComfyUI release ${RELEASE_TAG} to Comfy-Org/desktop"
diff --git a/README.md b/README.md
index 96dc2904b..3ccdc9c19 100644
--- a/README.md
+++ b/README.md
@@ -227,7 +227,7 @@ Put your VAE in: models/vae
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
-```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4```
+```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm7.1```
This is the command to install the nightly with ROCm 7.1 which might have some performance improvements:
diff --git a/app/node_replace_manager.py b/app/node_replace_manager.py
new file mode 100644
index 000000000..03b603c70
--- /dev/null
+++ b/app/node_replace_manager.py
@@ -0,0 +1,105 @@
+from __future__ import annotations
+
+from aiohttp import web
+
+from typing import TYPE_CHECKING, TypedDict
+if TYPE_CHECKING:
+ from comfy_api.latest._io_public import NodeReplace
+
+from comfy_execution.graph_utils import is_link
+import nodes
+
+class NodeStruct(TypedDict):
+    inputs: dict[str, str | int | float | bool | list[str | int]]
+ class_type: str
+ _meta: dict[str, str]
+
+def copy_node_struct(node_struct: NodeStruct, empty_inputs: bool = False) -> NodeStruct:
+ new_node_struct = node_struct.copy()
+ if empty_inputs:
+ new_node_struct["inputs"] = {}
+ else:
+ new_node_struct["inputs"] = node_struct["inputs"].copy()
+ new_node_struct["_meta"] = node_struct["_meta"].copy()
+ return new_node_struct
+
+
+class NodeReplaceManager:
+ """Manages node replacement registrations."""
+
+ def __init__(self):
+ self._replacements: dict[str, list[NodeReplace]] = {}
+
+ def register(self, node_replace: NodeReplace):
+ """Register a node replacement mapping."""
+ self._replacements.setdefault(node_replace.old_node_id, []).append(node_replace)
+
+ def get_replacement(self, old_node_id: str) -> list[NodeReplace] | None:
+ """Get replacements for an old node ID."""
+ return self._replacements.get(old_node_id)
+
+ def has_replacement(self, old_node_id: str) -> bool:
+ """Check if a replacement exists for an old node ID."""
+ return old_node_id in self._replacements
+
+ def apply_replacements(self, prompt: dict[str, NodeStruct]):
+ connections: dict[str, list[tuple[str, str, int]]] = {}
+ need_replacement: set[str] = set()
+ for node_number, node_struct in prompt.items():
+ class_type = node_struct["class_type"]
+ # need replacement if not in NODE_CLASS_MAPPINGS and has replacement
+            if class_type not in nodes.NODE_CLASS_MAPPINGS and self.has_replacement(class_type):
+ need_replacement.add(node_number)
+ # keep track of connections
+ for input_id, input_value in node_struct["inputs"].items():
+ if is_link(input_value):
+ conn_number = input_value[0]
+ connections.setdefault(conn_number, []).append((node_number, input_id, input_value[1]))
+ for node_number in need_replacement:
+ node_struct = prompt[node_number]
+ class_type = node_struct["class_type"]
+ replacements = self.get_replacement(class_type)
+ if replacements is None:
+ continue
+ # just use the first replacement
+ replacement = replacements[0]
+ new_node_id = replacement.new_node_id
+            # if the replacement is not a valid node, skip it; replacing with an unknown node would only cause confusion
+            if new_node_id not in nodes.NODE_CLASS_MAPPINGS:
+ continue
+ # first, replace node id (class_type)
+ new_node_struct = copy_node_struct(node_struct, empty_inputs=True)
+ new_node_struct["class_type"] = new_node_id
+ # TODO: consider replacing display_name in _meta as well for error reporting purposes; would need to query node schema
+ # second, replace inputs
+ if replacement.input_mapping is not None:
+ for input_map in replacement.input_mapping:
+ if "set_value" in input_map:
+ new_node_struct["inputs"][input_map["new_id"]] = input_map["set_value"]
+ elif "old_id" in input_map:
+ new_node_struct["inputs"][input_map["new_id"]] = node_struct["inputs"][input_map["old_id"]]
+ # finalize input replacement
+ prompt[node_number] = new_node_struct
+ # third, replace outputs
+ if replacement.output_mapping is not None:
+ # re-mapping outputs requires changing the input values of nodes that receive connections from this one
+ if node_number in connections:
+ for conns in connections[node_number]:
+ conn_node_number, conn_input_id, old_output_idx = conns
+ for output_map in replacement.output_mapping:
+ if output_map["old_idx"] == old_output_idx:
+ new_output_idx = output_map["new_idx"]
+ previous_input = prompt[conn_node_number]["inputs"][conn_input_id]
+ previous_input[1] = new_output_idx
+
+ def as_dict(self):
+ """Serialize all replacements to dict."""
+ return {
+ k: [v.as_dict() for v in v_list]
+ for k, v_list in self._replacements.items()
+ }
+
+ def add_routes(self, routes):
+ @routes.get("/node_replacements")
+ async def get_node_replacements(request):
+ return web.json_response(self.as_dict())
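
A toy walkthrough of `apply_replacements`. The real `NodeReplace` comes from `comfy_api.latest._io_public`; the stand-in dataclass below only mimics the attributes the manager reads, and in a real run the replacement fires only when the old class is absent from and the new class present in `nodes.NODE_CLASS_MAPPINGS`:

```python
# Sketch only: FakeNodeReplace stands in for comfy_api's NodeReplace, and the
# mapping dict keys mirror what apply_replacements reads above.
from dataclasses import dataclass

@dataclass
class FakeNodeReplace:
    old_node_id: str
    new_node_id: str
    input_mapping: list | None = None
    output_mapping: list | None = None

prompt = {
    "1": {"inputs": {"text": "hello"}, "class_type": "OldEncode", "_meta": {"title": "Encode"}},
    "2": {"inputs": {"cond": ["1", 0]}, "class_type": "KSampler", "_meta": {"title": "Sampler"}},
}

mgr = NodeReplaceManager()
mgr.register(FakeNodeReplace(
    old_node_id="OldEncode",
    new_node_id="NewEncode",  # must exist in nodes.NODE_CLASS_MAPPINGS to apply
    input_mapping=[{"new_id": "prompt", "old_id": "text"}],
    output_mapping=[{"old_idx": 0, "new_idx": 1}],
))
mgr.apply_replacements(prompt)
# Node "1" becomes class_type "NewEncode" with inputs {"prompt": "hello"}, and
# the link in node "2" is rewritten to read output 1: {"cond": ["1", 1]}
```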
diff --git a/comfy/checkpoint_pickle.py b/comfy/checkpoint_pickle.py
deleted file mode 100644
index 206551d3c..000000000
--- a/comfy/checkpoint_pickle.py
+++ /dev/null
@@ -1,13 +0,0 @@
-import pickle
-
-load = pickle.load
-
-class Empty:
- pass
-
-class Unpickler(pickle.Unpickler):
- def find_class(self, module, name):
- #TODO: safe unpickle
- if module.startswith("pytorch_lightning"):
- return Empty
- return super().find_class(module, name)
diff --git a/comfy/controlnet.py b/comfy/controlnet.py
index 9e1e704e0..ba670b16d 100644
--- a/comfy/controlnet.py
+++ b/comfy/controlnet.py
@@ -297,6 +297,30 @@ class ControlNet(ControlBase):
self.model_sampling_current = None
super().cleanup()
+
+class QwenFunControlNet(ControlNet):
+ def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
+ # Fun checkpoints are more sensitive to high strengths in the generic
+ # ControlNet merge path. Use a soft response curve so strength=1.0 stays
+ # unchanged while >1 grows more gently.
+ original_strength = self.strength
+ self.strength = math.sqrt(max(self.strength, 0.0))
+ try:
+ return super().get_control(x_noisy, t, cond, batched_number, transformer_options)
+ finally:
+ self.strength = original_strength
+
+ def pre_run(self, model, percent_to_timestep_function):
+ super().pre_run(model, percent_to_timestep_function)
+ self.set_extra_arg("base_model", model.diffusion_model)
+
+ def copy(self):
+ c = QwenFunControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
+ c.control_model = self.control_model
+ c.control_model_wrapped = self.control_model_wrapped
+ self.copy_to(c)
+ return c
+
class ControlLoraOps:
class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
def __init__(self, in_features: int, out_features: int, bias: bool = True,
@@ -560,6 +584,7 @@ def load_controlnet_hunyuandit(controlnet_data, model_options={}):
def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}):
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
+ sd = model_config.process_unet_state_dict(sd)
control_model = controlnet_load_state_dict(control_model, sd)
extra_conds = ['y', 'guidance']
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
@@ -605,6 +630,53 @@ def load_controlnet_qwen_instantx(sd, model_options={}):
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
return control
+
+def load_controlnet_qwen_fun(sd, model_options={}):
+ load_device = comfy.model_management.get_torch_device()
+ weight_dtype = comfy.utils.weight_dtype(sd)
+ unet_dtype = model_options.get("dtype", weight_dtype)
+ manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
+
+ operations = model_options.get("custom_operations", None)
+ if operations is None:
+ operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
+
+ in_features = sd["control_img_in.weight"].shape[1]
+ inner_dim = sd["control_img_in.weight"].shape[0]
+
+ block_weight = sd["control_blocks.0.attn.to_q.weight"]
+ attention_head_dim = sd["control_blocks.0.attn.norm_q.weight"].shape[0]
+ num_attention_heads = max(1, block_weight.shape[0] // max(1, attention_head_dim))
+
+ model = comfy.ldm.qwen_image.controlnet.QwenImageFunControlNetModel(
+ control_in_features=in_features,
+ inner_dim=inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ num_control_blocks=5,
+ main_model_double=60,
+ injection_layers=(0, 12, 24, 36, 48),
+ operations=operations,
+ device=comfy.model_management.unet_offload_device(),
+ dtype=unet_dtype,
+ )
+ model = controlnet_load_state_dict(model, sd)
+
+ latent_format = comfy.latent_formats.Wan21()
+ control = QwenFunControlNet(
+ model,
+ compression_ratio=1,
+ latent_format=latent_format,
+ # Fun checkpoints already expect their own 33-channel context handling.
+ # Enabling generic concat_mask injects an extra mask channel at apply-time
+ # and breaks the intended fallback packing path.
+ concat_mask=False,
+ load_device=load_device,
+ manual_cast_dtype=manual_cast_dtype,
+ extra_conds=[],
+ )
+ return control
+
def convert_mistoline(sd):
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
@@ -682,6 +754,8 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
return load_controlnet_qwen_instantx(controlnet_data, model_options=model_options)
elif "controlnet_x_embedder.weight" in controlnet_data:
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
+ elif "control_blocks.0.after_proj.weight" in controlnet_data and "control_img_in.weight" in controlnet_data:
+ return load_controlnet_qwen_fun(controlnet_data, model_options=model_options)
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)
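
A quick illustration of the soft response curve `QwenFunControlNet.get_control` applies to strength: values are pulled toward 1.0 and negative inputs clamp to 0.

```python
import math

for s in (0.0, 0.25, 1.0, 2.0, 4.0):
    print(f"strength {s} -> {math.sqrt(max(s, 0.0)):.3f}")
# strength 0.0 -> 0.000
# strength 0.25 -> 0.500
# strength 1.0 -> 1.000
# strength 2.0 -> 1.414
# strength 4.0 -> 2.000
```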
diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py
index c0c51d51a..6978eb717 100644
--- a/comfy/k_diffusion/sampling.py
+++ b/comfy/k_diffusion/sampling.py
@@ -1,12 +1,11 @@
import math
-import time
from functools import partial
from scipy import integrate
import torch
from torch import nn
import torchsde
-from tqdm.auto import trange as trange_, tqdm
+from tqdm.auto import tqdm
from . import utils
from . import deis
@@ -15,34 +14,7 @@ import comfy.model_patcher
import comfy.model_sampling
import comfy.memory_management
-
-
-def trange(*args, **kwargs):
- if comfy.memory_management.aimdo_allocator is None:
- return trange_(*args, **kwargs)
-
- pbar = trange_(*args, **kwargs, smoothing=1.0)
- pbar._i = 0
- pbar.set_postfix_str(" Model Initializing ... ")
-
- _update = pbar.update
-
- def warmup_update(n=1):
- pbar._i += 1
- if pbar._i == 1:
- pbar.i1_time = time.time()
- pbar.set_postfix_str(" Model Initialization complete! ")
- elif pbar._i == 2:
- #bring forward the effective start time based the the diff between first and second iteration
- #to attempt to remove load overhead from the final step rate estimate.
- pbar.start_t = pbar.i1_time - (time.time() - pbar.i1_time)
- pbar.set_postfix_str("")
-
- _update(n)
-
- pbar.update = warmup_update
- return pbar
-
+from comfy.utils import model_trange as trange
def append_zero(x):
return torch.cat([x, x.new_zeros([1])])
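
The deleted warmup wrapper is not gone: per the new import it now lives at `comfy.utils.model_trange`. For context, a sketch of the behavior the removed code implemented (minus the `aimdo_allocator` gate), under the assumption the moved version is equivalent:

```python
# Sketch of the warmup-aware progress bar the deleted code implemented.
import time
from tqdm.auto import trange as trange_

def model_trange(*args, **kwargs):
    pbar = trange_(*args, **kwargs, smoothing=1.0)
    pbar._i = 0
    pbar.set_postfix_str(" Model Initializing ... ")
    _update = pbar.update

    def warmup_update(n=1):
        pbar._i += 1
        if pbar._i == 1:
            pbar.i1_time = time.time()
            pbar.set_postfix_str(" Model Initialization complete! ")
        elif pbar._i == 2:
            # Shift the effective start time back by the first-iteration overhead
            # so the final it/s estimate excludes model load time.
            pbar.start_t = pbar.i1_time - (time.time() - pbar.i1_time)
            pbar.set_postfix_str("")
        _update(n)

    pbar.update = warmup_update
    return pbar
```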
diff --git a/comfy/ldm/anima/model.py b/comfy/ldm/anima/model.py
index 2e6ed58fa..6fb51c4a4 100644
--- a/comfy/ldm/anima/model.py
+++ b/comfy/ldm/anima/model.py
@@ -195,8 +195,20 @@ class Anima(MiniTrainDIT):
super().__init__(*args, **kwargs)
self.llm_adapter = LLMAdapter(device=kwargs.get("device"), dtype=kwargs.get("dtype"), operations=kwargs.get("operations"))
- def preprocess_text_embeds(self, text_embeds, text_ids):
+ def preprocess_text_embeds(self, text_embeds, text_ids, t5xxl_weights=None):
if text_ids is not None:
- return self.llm_adapter(text_embeds, text_ids)
+ out = self.llm_adapter(text_embeds, text_ids)
+ if t5xxl_weights is not None:
+ out = out * t5xxl_weights
+
+ if out.shape[1] < 512:
+ out = torch.nn.functional.pad(out, (0, 0, 0, 512 - out.shape[1]))
+ return out
else:
return text_embeds
+
+ def forward(self, x, timesteps, context, **kwargs):
+ t5xxl_ids = kwargs.pop("t5xxl_ids", None)
+ if t5xxl_ids is not None:
+ context = self.preprocess_text_embeds(context, t5xxl_ids, t5xxl_weights=kwargs.pop("t5xxl_weights", None))
+ return super().forward(x, timesteps, context, **kwargs)
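
A shape walkthrough of the new preprocessing path; the 4096 embedding dim is an arbitrary stand-in, and `t5xxl_weights` arrives already unsqueezed to `(batch, seq, 1)` from `model_base.Anima.extra_conds`:

```python
import torch

out = torch.randn(1, 77, 4096)   # llm_adapter output: (batch, seq, dim)
weights = torch.rand(1, 77, 1)   # t5xxl_weights, broadcast over the embedding dim
out = out * weights
if out.shape[1] < 512:           # right-pad the sequence axis to 512 tokens
    out = torch.nn.functional.pad(out, (0, 0, 0, 512 - out.shape[1]))
print(out.shape)                 # torch.Size([1, 512, 4096])
```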
diff --git a/comfy/ldm/chroma/layers.py b/comfy/ldm/chroma/layers.py
index 2d5684348..df348a8ed 100644
--- a/comfy/ldm/chroma/layers.py
+++ b/comfy/ldm/chroma/layers.py
@@ -3,7 +3,6 @@ from torch import Tensor, nn
from comfy.ldm.flux.layers import (
MLPEmbedder,
- RMSNorm,
ModulationOut,
)
@@ -29,7 +28,7 @@ class Approximator(nn.Module):
super().__init__()
self.in_proj = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
- self.norms = nn.ModuleList([RMSNorm(hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
+ self.norms = nn.ModuleList([operations.RMSNorm(hidden_dim, dtype=dtype, device=device) for x in range( n_layers)])
self.out_proj = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)
@property
diff --git a/comfy/ldm/chroma_radiance/layers.py b/comfy/ldm/chroma_radiance/layers.py
index 3c7bc9b6b..08d31e0ba 100644
--- a/comfy/ldm/chroma_radiance/layers.py
+++ b/comfy/ldm/chroma_radiance/layers.py
@@ -4,8 +4,6 @@ from functools import lru_cache
import torch
from torch import nn
-from comfy.ldm.flux.layers import RMSNorm
-
class NerfEmbedder(nn.Module):
"""
@@ -145,7 +143,7 @@ class NerfGLUBlock(nn.Module):
# We now need to generate parameters for 3 matrices.
total_params = 3 * hidden_size_x**2 * mlp_ratio
self.param_generator = operations.Linear(hidden_size_s, total_params, dtype=dtype, device=device)
- self.norm = RMSNorm(hidden_size_x, dtype=dtype, device=device, operations=operations)
+ self.norm = operations.RMSNorm(hidden_size_x, dtype=dtype, device=device)
self.mlp_ratio = mlp_ratio
@@ -178,7 +176,7 @@ class NerfGLUBlock(nn.Module):
class NerfFinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None):
super().__init__()
- self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
+ self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device)
self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -190,7 +188,7 @@ class NerfFinalLayer(nn.Module):
class NerfFinalLayerConv(nn.Module):
def __init__(self, hidden_size: int, out_channels: int, dtype=None, device=None, operations=None):
super().__init__()
- self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
+ self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device)
self.conv = operations.Conv2d(
in_channels=hidden_size,
out_channels=out_channels,
diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py
index 60f2bdae2..3518a1922 100644
--- a/comfy/ldm/flux/layers.py
+++ b/comfy/ldm/flux/layers.py
@@ -5,9 +5,9 @@ import torch
from torch import Tensor, nn
from .math import attention, rope
-import comfy.ops
-import comfy.ldm.common_dit
+# Kept so older custom nodes can still import RMSNorm; TODO: delete eventually.
+RMSNorm = None
class EmbedND(nn.Module):
def __init__(self, dim: int, theta: int, axes_dim: list):
@@ -87,20 +87,12 @@ def build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=False, yak_mlp=False, dt
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
-class RMSNorm(torch.nn.Module):
- def __init__(self, dim: int, dtype=None, device=None, operations=None):
- super().__init__()
- self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
-
- def forward(self, x: Tensor):
- return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6)
-
class QKNorm(torch.nn.Module):
def __init__(self, dim: int, dtype=None, device=None, operations=None):
super().__init__()
- self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
- self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
+ self.query_norm = operations.RMSNorm(dim, dtype=dtype, device=device)
+ self.key_norm = operations.RMSNorm(dim, dtype=dtype, device=device)
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
q = self.query_norm(q)
@@ -169,7 +161,7 @@ class SiLUActivation(nn.Module):
class DoubleStreamBlock(nn.Module):
- def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
+ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
@@ -197,8 +189,6 @@ class DoubleStreamBlock(nn.Module):
self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
- self.flipped_img_txt = flipped_img_txt
-
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
if self.modulation:
img_mod1, img_mod2 = self.img_mod(vec)
@@ -224,32 +214,17 @@ class DoubleStreamBlock(nn.Module):
del txt_qkv
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
- if self.flipped_img_txt:
- q = torch.cat((img_q, txt_q), dim=2)
- del img_q, txt_q
- k = torch.cat((img_k, txt_k), dim=2)
- del img_k, txt_k
- v = torch.cat((img_v, txt_v), dim=2)
- del img_v, txt_v
- # run actual attention
- attn = attention(q, k, v,
- pe=pe, mask=attn_mask, transformer_options=transformer_options)
- del q, k, v
+ q = torch.cat((txt_q, img_q), dim=2)
+ del txt_q, img_q
+ k = torch.cat((txt_k, img_k), dim=2)
+ del txt_k, img_k
+ v = torch.cat((txt_v, img_v), dim=2)
+ del txt_v, img_v
+ # run actual attention
+ attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
+ del q, k, v
- img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
- else:
- q = torch.cat((txt_q, img_q), dim=2)
- del txt_q, img_q
- k = torch.cat((txt_k, img_k), dim=2)
- del txt_k, img_k
- v = torch.cat((txt_v, img_v), dim=2)
- del txt_v, img_v
- # run actual attention
- attn = attention(q, k, v,
- pe=pe, mask=attn_mask, transformer_options=transformer_options)
- del q, k, v
-
- txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
+ txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
# calculate the img bloks
img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
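
The per-file `RMSNorm` class is gone; call sites now use `operations.RMSNorm`. Based on the deleted definition (and the `.scale` to `.weight` state-dict renames later in this diff), a functional sketch of what the replacement is expected to compute, ignoring dtype-casting details:

```python
import torch
from torch import nn

class RMSNormSketch(nn.Module):
    """Approximation of operations.RMSNorm; the parameter is registered as
    `weight` rather than the old `scale`, matching the key renames below."""
    def __init__(self, dim: int, eps: float = 1e-6, dtype=None, device=None):
        super().__init__()
        self.weight = nn.Parameter(torch.empty((dim,), dtype=dtype, device=device))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Scale by the reciprocal root-mean-square, then apply the learned gain.
        rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * rms * self.weight
```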
diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py
index f9597de5b..5e764bb46 100644
--- a/comfy/ldm/flux/math.py
+++ b/comfy/ldm/flux/math.py
@@ -29,19 +29,34 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
return out.to(dtype=torch.float32, device=pos.device)
+def _apply_rope1(x: Tensor, freqs_cis: Tensor):
+ x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
+
+ x_out = freqs_cis[..., 0] * x_[..., 0]
+ x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
+
+ return x_out.reshape(*x.shape).type_as(x)
+
+
+def _apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
+    return _apply_rope1(xq, freqs_cis), _apply_rope1(xk, freqs_cis)
+
+
try:
import comfy.quant_ops
- apply_rope = comfy.quant_ops.ck.apply_rope
- apply_rope1 = comfy.quant_ops.ck.apply_rope1
+ q_apply_rope = comfy.quant_ops.ck.apply_rope
+ q_apply_rope1 = comfy.quant_ops.ck.apply_rope1
+ def apply_rope(xq, xk, freqs_cis):
+ if comfy.model_management.in_training:
+ return _apply_rope(xq, xk, freqs_cis)
+ else:
+            return q_apply_rope(xq, xk, freqs_cis)
+ def apply_rope1(x, freqs_cis):
+ if comfy.model_management.in_training:
+ return _apply_rope1(x, freqs_cis)
+ else:
+ return q_apply_rope1(x, freqs_cis)
except:
logging.warning("No comfy kitchen, using old apply_rope functions.")
- def apply_rope1(x: Tensor, freqs_cis: Tensor):
- x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
-
- x_out = freqs_cis[..., 0] * x_[..., 0]
- x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
-
- return x_out.reshape(*x.shape).type_as(x)
-
- def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
- return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
+ apply_rope = _apply_rope
+ apply_rope1 = _apply_rope1
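
A numeric sanity check that the pairwise formula in `_apply_rope1` rotates each `(even, odd)` channel pair, using the `[[cos, -sin], [sin, cos]]` layout produced by `rope()` above:

```python
import torch

def apply_rope1_ref(x, freqs_cis):
    # Same arithmetic as _apply_rope1, without the in-place addcmul_.
    x_ = x.float().reshape(*x.shape[:-1], -1, 1, 2)
    x_out = freqs_cis[..., 0] * x_[..., 0] + freqs_cis[..., 1] * x_[..., 1]
    return x_out.reshape(*x.shape).type_as(x)

x = torch.tensor([[1.0, 0.0, 0.0, 1.0]])      # two (even, odd) pairs
theta = torch.tensor([0.5, 1.0])              # one angle per pair
cos, sin = theta.cos(), theta.sin()
freqs = torch.stack([torch.stack([cos, -sin], -1),
                     torch.stack([sin, cos], -1)], -2).view(1, 2, 2, 2)

print(apply_rope1_ref(x, freqs))
# pair k rotates by theta[k]: (x0, x1) -> (cos*x0 - sin*x1, sin*x0 + cos*x1)
# -> [[cos(0.5), sin(0.5), -sin(1.0), cos(1.0)]]
```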
diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py
index f40c2a7a9..260ccad7e 100644
--- a/comfy/ldm/flux/model.py
+++ b/comfy/ldm/flux/model.py
@@ -16,7 +16,6 @@ from .layers import (
SingleStreamBlock,
timestep_embedding,
Modulation,
- RMSNorm
)
@dataclass
@@ -81,7 +80,7 @@ class Flux(nn.Module):
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
if params.txt_norm:
- self.txt_norm = RMSNorm(params.context_in_dim, dtype=dtype, device=device, operations=operations)
+ self.txt_norm = operations.RMSNorm(params.context_in_dim, dtype=dtype, device=device)
else:
self.txt_norm = None
diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py
index 55ab550f8..563f28f6b 100644
--- a/comfy/ldm/hunyuan_video/model.py
+++ b/comfy/ldm/hunyuan_video/model.py
@@ -241,7 +241,6 @@ class HunyuanVideo(nn.Module):
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
- flipped_img_txt=True,
dtype=dtype, device=device, operations=operations
)
for _ in range(params.depth)
@@ -378,14 +377,14 @@ class HunyuanVideo(nn.Module):
extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
txt_ids = torch.cat((txt_ids, extra_txt_ids), dim=1)
- ids = torch.cat((img_ids, txt_ids), dim=1)
+ ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
img_len = img.shape[1]
if txt_mask is not None:
attn_mask_len = img_len + txt.shape[1]
attn_mask = torch.zeros((1, 1, attn_mask_len), dtype=img.dtype, device=img.device)
- attn_mask[:, 0, img_len:] = txt_mask
+ attn_mask[:, 0, :txt.shape[1]] = txt_mask
else:
attn_mask = None
@@ -413,7 +412,7 @@ class HunyuanVideo(nn.Module):
if add is not None:
img += add
- img = torch.cat((img, txt), 1)
+ img = torch.cat((txt, img), 1)
transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single"
@@ -435,9 +434,9 @@ class HunyuanVideo(nn.Module):
if i < len(control_o):
add = control_o[i]
if add is not None:
- img[:, : img_len] += add
+ img[:, txt.shape[1]: img_len + txt.shape[1]] += add
- img = img[:, : img_len]
+ img = img[:, txt.shape[1]: img_len + txt.shape[1]]
if ref_latent is not None:
img = img[:, ref_latent.shape[1]:]
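
The ordering change flips sequences from `(img, txt)` to `(txt, img)`, which is why the control additions and the final slice now index past the text tokens. A minimal shape demo:

```python
import torch

txt_len, img_len, dim = 3, 5, 2
txt = torch.zeros(1, txt_len, dim)
img = torch.ones(1, img_len, dim)

seq = torch.cat((txt, img), 1)   # new (txt, img) layout
img_out = seq[:, txt.shape[1]: img_len + txt.shape[1]]
assert torch.equal(img_out, img)
```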
diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py
index 5a22ef030..805592aa5 100644
--- a/comfy/ldm/modules/diffusionmodules/model.py
+++ b/comfy/ldm/modules/diffusionmodules/model.py
@@ -102,19 +102,7 @@ class VideoConv3d(nn.Module):
return self.conv(x)
def interpolate_up(x, scale_factor):
- try:
- return torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="nearest")
- except: #operation not implemented for bf16
- orig_shape = list(x.shape)
- out_shape = orig_shape[:2]
- for i in range(len(orig_shape) - 2):
- out_shape.append(round(orig_shape[i + 2] * scale_factor[i]))
- out = torch.empty(out_shape, dtype=x.dtype, layout=x.layout, device=x.device)
- split = 8
- l = out.shape[1] // split
- for i in range(0, out.shape[1], l):
- out[:,i:i+l] = torch.nn.functional.interpolate(x[:,i:i+l].to(torch.float32), scale_factor=scale_factor, mode="nearest").to(x.dtype)
- return out
+ return torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="nearest")
class Upsample(nn.Module):
def __init__(self, in_channels, with_conv, conv_op=ops.Conv2d, scale_factor=2.0):
diff --git a/comfy/ldm/qwen_image/controlnet.py b/comfy/ldm/qwen_image/controlnet.py
index a6d408104..c0aae9240 100644
--- a/comfy/ldm/qwen_image/controlnet.py
+++ b/comfy/ldm/qwen_image/controlnet.py
@@ -2,6 +2,196 @@ import torch
import math
from .model import QwenImageTransformer2DModel
+from .model import QwenImageTransformerBlock
+
+
+class QwenImageFunControlBlock(QwenImageTransformerBlock):
+ def __init__(self, dim, num_attention_heads, attention_head_dim, has_before_proj=False, dtype=None, device=None, operations=None):
+ super().__init__(
+ dim=dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ dtype=dtype,
+ device=device,
+ operations=operations,
+ )
+ self.has_before_proj = has_before_proj
+ if has_before_proj:
+ self.before_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
+ self.after_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
+
+
+class QwenImageFunControlNetModel(torch.nn.Module):
+ def __init__(
+ self,
+ control_in_features=132,
+ inner_dim=3072,
+ num_attention_heads=24,
+ attention_head_dim=128,
+ num_control_blocks=5,
+ main_model_double=60,
+ injection_layers=(0, 12, 24, 36, 48),
+ dtype=None,
+ device=None,
+ operations=None,
+ ):
+ super().__init__()
+ self.dtype = dtype
+ self.main_model_double = main_model_double
+ self.injection_layers = tuple(injection_layers)
+ # Keep base hint scaling at 1.0 so user-facing strength behaves similarly
+ # to the reference Gen2/VideoX implementation around strength=1.
+ self.hint_scale = 1.0
+ self.control_img_in = operations.Linear(control_in_features, inner_dim, device=device, dtype=dtype)
+
+ self.control_blocks = torch.nn.ModuleList([])
+ for i in range(num_control_blocks):
+ self.control_blocks.append(
+ QwenImageFunControlBlock(
+ dim=inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ has_before_proj=(i == 0),
+ dtype=dtype,
+ device=device,
+ operations=operations,
+ )
+ )
+
+ def _process_hint_tokens(self, hint):
+ if hint is None:
+ return None
+ if hint.ndim == 4:
+ hint = hint.unsqueeze(2)
+
+ # Fun checkpoints are trained with 33 latent channels before 2x2 packing:
+ # [control_latent(16), mask(1), inpaint_latent(16)] -> 132 features.
+ # Default behavior (no inpaint input in stock Apply ControlNet) should use
+ # zeros for mask/inpaint branches, matching VideoX fallback semantics.
+ expected_c = self.control_img_in.weight.shape[1] // 4
+ if hint.shape[1] == 16 and expected_c == 33:
+ zeros_mask = torch.zeros_like(hint[:, :1])
+ zeros_inpaint = torch.zeros_like(hint)
+ hint = torch.cat([hint, zeros_mask, zeros_inpaint], dim=1)
+
+ bs, c, t, h, w = hint.shape
+ hidden_states = torch.nn.functional.pad(hint, (0, w % 2, 0, h % 2))
+ orig_shape = hidden_states.shape
+ hidden_states = hidden_states.view(
+ orig_shape[0],
+ orig_shape[1],
+ orig_shape[-3],
+ orig_shape[-2] // 2,
+ 2,
+ orig_shape[-1] // 2,
+ 2,
+ )
+ hidden_states = hidden_states.permute(0, 2, 3, 5, 1, 4, 6)
+ hidden_states = hidden_states.reshape(
+ bs,
+ t * ((h + 1) // 2) * ((w + 1) // 2),
+ c * 4,
+ )
+
+ expected_in = self.control_img_in.weight.shape[1]
+ cur_in = hidden_states.shape[-1]
+ if cur_in < expected_in:
+ pad = torch.zeros(
+ (hidden_states.shape[0], hidden_states.shape[1], expected_in - cur_in),
+ device=hidden_states.device,
+ dtype=hidden_states.dtype,
+ )
+ hidden_states = torch.cat([hidden_states, pad], dim=-1)
+ elif cur_in > expected_in:
+ hidden_states = hidden_states[:, :, :expected_in]
+
+ return hidden_states
+
+ def forward(
+ self,
+ x,
+ timesteps,
+ context,
+ attention_mask=None,
+ guidance: torch.Tensor = None,
+ hint=None,
+ transformer_options={},
+ base_model=None,
+ **kwargs,
+ ):
+ if base_model is None:
+ raise RuntimeError("Qwen Fun ControlNet requires a QwenImage base model at runtime.")
+
+        # Keep the attention mask disabled inside Fun control blocks to mirror
+        # VideoX behavior (they rely on sequence lengths for RoPE, not masked
+        # attention); the incoming attention_mask is deliberately unused.
+        encoder_hidden_states_mask = None
+
+ hidden_states, img_ids, _ = base_model.process_img(x)
+ hint_tokens = self._process_hint_tokens(hint)
+ if hint_tokens is None:
+ raise RuntimeError("Qwen Fun ControlNet requires a control hint image.")
+
+ if hint_tokens.shape[1] != hidden_states.shape[1]:
+ max_tokens = min(hint_tokens.shape[1], hidden_states.shape[1])
+ hint_tokens = hint_tokens[:, :max_tokens]
+ hidden_states = hidden_states[:, :max_tokens]
+ img_ids = img_ids[:, :max_tokens]
+
+ txt_start = round(
+ max(
+ ((x.shape[-1] + (base_model.patch_size // 2)) // base_model.patch_size) // 2,
+ ((x.shape[-2] + (base_model.patch_size // 2)) // base_model.patch_size) // 2,
+ )
+ )
+ txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
+ ids = torch.cat((txt_ids, img_ids), dim=1)
+ image_rotary_emb = base_model.pe_embedder(ids).to(x.dtype).contiguous()
+
+ hidden_states = base_model.img_in(hidden_states)
+ encoder_hidden_states = base_model.txt_norm(context)
+ encoder_hidden_states = base_model.txt_in(encoder_hidden_states)
+
+ if guidance is not None:
+ guidance = guidance * 1000
+
+ temb = (
+ base_model.time_text_embed(timesteps, hidden_states)
+ if guidance is None
+ else base_model.time_text_embed(timesteps, guidance, hidden_states)
+ )
+
+ c = self.control_img_in(hint_tokens)
+
+ for i, block in enumerate(self.control_blocks):
+ if i == 0:
+ c_in = block.before_proj(c) + hidden_states
+ all_c = []
+ else:
+ all_c = list(torch.unbind(c, dim=0))
+ c_in = all_c.pop(-1)
+
+ encoder_hidden_states, c_out = block(
+ hidden_states=c_in,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_hidden_states_mask=encoder_hidden_states_mask,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ transformer_options=transformer_options,
+ )
+
+ c_skip = block.after_proj(c_out) * self.hint_scale
+ all_c += [c_skip, c_out]
+ c = torch.stack(all_c, dim=0)
+
+ hints = torch.unbind(c, dim=0)[:-1]
+
+ controlnet_block_samples = [None] * self.main_model_double
+ for local_idx, base_idx in enumerate(self.injection_layers):
+ if local_idx < len(hints) and base_idx < len(controlnet_block_samples):
+ controlnet_block_samples[base_idx] = hints[local_idx]
+
+ return {"input": controlnet_block_samples}
class QwenImageControlNetModel(QwenImageTransformer2DModel):
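
A shape walkthrough of `_process_hint_tokens` for the default case: even spatial dims and a plain 16-channel control latent that gets zero mask and inpaint branches appended:

```python
import torch

bs, t, h, w = 1, 1, 64, 64
hint = torch.randn(bs, 16, t, h, w)
# zero mask (1ch) + zero inpaint latent (16ch) -> 33 channels
hint = torch.cat([hint, torch.zeros_like(hint[:, :1]), torch.zeros_like(hint)], dim=1)

c = hint.shape[1]  # 33
tokens = hint.view(bs, c, t, h // 2, 2, w // 2, 2)
tokens = tokens.permute(0, 2, 3, 5, 1, 4, 6).reshape(bs, t * (h // 2) * (w // 2), c * 4)
print(tokens.shape)  # torch.Size([1, 1024, 132]) -> matches control_img_in's 132 inputs
```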
diff --git a/comfy/lora.py b/comfy/lora.py
index 44030bcab..279cf38bb 100644
--- a/comfy/lora.py
+++ b/comfy/lora.py
@@ -374,6 +374,31 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten
return padded_tensor
+def calculate_shape(patches, weight, key, original_weights=None):
+ current_shape = weight.shape
+
+ for p in patches:
+ v = p[1]
+ offset = p[3]
+
+ # Offsets restore the old shape; lists force a diff without metadata
+ if offset is not None or isinstance(v, list):
+ continue
+
+ if isinstance(v, weight_adapter.WeightAdapterBase):
+ adapter_shape = v.calculate_shape(key)
+ if adapter_shape is not None:
+ current_shape = adapter_shape
+ continue
+
+ # Standard diff logic with padding
+ if len(v) == 2:
+ patch_type, patch_data = v[0], v[1]
+ if patch_type == "diff" and len(patch_data) > 1 and patch_data[1]['pad_weight']:
+ current_shape = patch_data[0].shape
+
+ return current_shape
+
def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, original_weights=None):
for p in patches:
strength = p[0]
diff --git a/comfy/lora_convert.py b/comfy/lora_convert.py
index 9d8d21efe..749e81df3 100644
--- a/comfy/lora_convert.py
+++ b/comfy/lora_convert.py
@@ -5,7 +5,7 @@ import comfy.utils
def convert_lora_bfl_control(sd): #BFL loras for Flux
sd_out = {}
for k in sd:
- k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.scale.set_weight"))
+ k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.set_weight"))
sd_out[k_to] = sd[k]
sd_out["diffusion_model.img_in.reshape_weight"] = torch.tensor([sd["img_in.lora_B.weight"].shape[0], sd["img_in.lora_A.weight"].shape[1]])
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 858789b30..4a74cb1ce 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -1160,12 +1160,16 @@ class Anima(BaseModel):
device = kwargs["device"]
if cross_attn is not None:
if t5xxl_ids is not None:
- cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.unsqueeze(0).to(device=device))
if t5xxl_weights is not None:
- cross_attn *= t5xxl_weights.unsqueeze(0).unsqueeze(-1).to(cross_attn)
+ t5xxl_weights = t5xxl_weights.unsqueeze(0).unsqueeze(-1).to(cross_attn)
+ t5xxl_ids = t5xxl_ids.unsqueeze(0)
+
+            if torch.is_inference_mode_enabled():  # if not, we are training
+                cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype()) if t5xxl_weights is not None else None)
+ else:
+                out['t5xxl_ids'] = comfy.conds.CONDRegular(t5xxl_ids)
+                if t5xxl_weights is not None:
+                    out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights)
- if cross_attn.shape[1] < 512:
- cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, 0, 512 - cross_attn.shape[1]))
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index e8ad725df..30ea03e8e 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -19,6 +19,12 @@ def count_blocks(state_dict_keys, prefix_string):
count += 1
return count
+def any_suffix_in(keys, prefix, main, suffix_list=()):
+ for x in suffix_list:
+ if "{}{}{}".format(prefix, main, x) in keys:
+ return True
+ return False
+
def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
context_dim = None
use_linear_in_transformer = False
@@ -186,7 +192,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["meanflow_sum"] = False
return dit_config
- if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
+ if any_suffix_in(state_dict_keys, key_prefix, 'double_blocks.0.img_attn.norm.key_norm.', ["weight", "scale"]) and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.norms.0.', ["weight", "scale"])): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
dit_config = {}
if '{}double_stream_modulation_img.lin.weight'.format(key_prefix) in state_dict_keys:
dit_config["image_model"] = "flux2"
@@ -241,7 +247,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
- if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma
+
+ if any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.0.norms.0.', ["weight", "scale"]) or any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.norms.0.', ["weight", "scale"]): #Chroma
dit_config["image_model"] = "chroma"
dit_config["in_channels"] = 64
dit_config["out_channels"] = 64
@@ -249,7 +256,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["out_dim"] = 3072
dit_config["hidden_dim"] = 5120
dit_config["n_layers"] = 5
- if f"{key_prefix}nerf_blocks.0.norm.scale" in state_dict_keys: #Chroma Radiance
+
+ if any_suffix_in(state_dict_keys, key_prefix, 'nerf_blocks.0.norm.', ["weight", "scale"]): #Chroma Radiance
dit_config["image_model"] = "chroma_radiance"
dit_config["in_channels"] = 3
dit_config["out_channels"] = 3
@@ -259,7 +267,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["nerf_depth"] = 4
dit_config["nerf_max_freqs"] = 8
dit_config["nerf_tile_size"] = 512
- dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
+ dit_config["nerf_final_head_type"] = "conv" if any_suffix_in(state_dict_keys, key_prefix, 'nerf_final_layer_conv.norm.', ["weight", "scale"]) else "linear"
dit_config["nerf_embedder_dtype"] = torch.float32
if "{}__x0__".format(key_prefix) in state_dict_keys: # x0 pred
dit_config["use_x0"] = True
@@ -268,7 +276,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
else:
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
- dit_config["txt_norm"] = "{}txt_norm.scale".format(key_prefix) in state_dict_keys
+ dit_config["txt_norm"] = any_suffix_in(state_dict_keys, key_prefix, 'txt_norm.', ["weight", "scale"])
if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model
dit_config["txt_ids_dims"] = [1, 2]
diff --git a/comfy/model_management.py b/comfy/model_management.py
index b6291f340..38c3e482b 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -19,7 +19,7 @@
import psutil
import logging
from enum import Enum
-from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
+from comfy.cli_args import args, PerformanceFeature
import threading
import torch
import sys
@@ -55,6 +55,11 @@ cpu_state = CPUState.GPU
total_vram = 0
+
+# Training Related State
+in_training = False
+
+
def get_supported_float8_types():
float8_types = []
try:
@@ -651,7 +656,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
soft_empty_cache()
return unloaded_models
-def load_models_gpu_orig(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
+def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
cleanup_models_gc()
global vram_state
@@ -747,26 +752,6 @@ def load_models_gpu_orig(models, memory_required=0, force_patch_weights=False, m
current_loaded_models.insert(0, loaded_model)
return
-def load_models_gpu_thread(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load):
- with torch.inference_mode():
- load_models_gpu_orig(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
- soft_empty_cache()
-
-def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
- #Deliberately load models outside of the Aimdo mempool so they can be retained accross
- #nodes. Use a dummy thread to do it as pytorch documents that mempool contexts are
- #thread local. So exploit that to escape context
- if enables_dynamic_vram():
- t = threading.Thread(
- target=load_models_gpu_thread,
- args=(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
- )
- t.start()
- t.join()
- else:
- load_models_gpu_orig(models, memory_required=memory_required, force_patch_weights=force_patch_weights,
- minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
-
def load_model_gpu(model):
return load_models_gpu([model])
@@ -1226,21 +1211,20 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
if dtype is None:
dtype = weight._model_dtype
- r = torch.empty_like(weight, dtype=dtype, device=device)
-
signature = comfy_aimdo.model_vbar.vbar_fault(weight._v)
if signature is not None:
- raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
- v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0]
- if not comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
+ if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
+ v_tensor = weight._v_tensor
+ else:
+ raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
+ v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0]
+ weight._v_tensor = v_tensor
weight._v_signature = signature
#Send it over
v_tensor.copy_(weight, non_blocking=non_blocking)
- #always take a deep copy even if _v is good, as we have no reasonable point to unpin
- #a non comfy weight
- r.copy_(v_tensor)
- comfy_aimdo.model_vbar.vbar_unpin(weight._v)
- return r
+ return v_tensor.to(dtype=dtype)
+
+ r = torch.empty_like(weight, dtype=dtype, device=device)
if weight.dtype != r.dtype and weight.dtype != weight._model_dtype:
#Offloaded casting could skip this, however it would make the quantizations
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index d888dbcfb..f01818f50 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -19,7 +19,6 @@
from __future__ import annotations
import collections
-import copy
import inspect
import logging
import math
@@ -317,7 +316,7 @@ class ModelPatcher:
n.object_patches = self.object_patches.copy()
n.weight_wrapper_patches = self.weight_wrapper_patches.copy()
- n.model_options = copy.deepcopy(self.model_options)
+ n.model_options = comfy.utils.deepcopy_list_dict(self.model_options)
n.backup = self.backup
n.object_patches_backup = self.object_patches_backup
n.parent = self
@@ -680,18 +679,19 @@ class ModelPatcher:
for key in list(self.pinned):
self.unpin_weight(key)
- def _load_list(self, prio_comfy_cast_weights=False):
+ def _load_list(self, prio_comfy_cast_weights=False, default_device=None):
loading = []
for n, m in self.model.named_modules():
- params = []
- skip = False
- for name, param in m.named_parameters(recurse=False):
- params.append(name)
+ default = False
+ params = { name: param for name, param in m.named_parameters(recurse=False) }
for name, param in m.named_parameters(recurse=True):
if name not in params:
- skip = True # skip random weights in non leaf modules
+ default = True # default random weights in non leaf modules
break
- if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
+ if default and default_device is not None:
+ for param in params.values():
+ param.data = param.data.to(device=default_device)
+ if not default and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
module_mem = comfy.model_management.module_size(m)
module_offload_mem = module_mem
if hasattr(m, "comfy_cast_weights"):
@@ -1492,9 +1492,11 @@ class ModelPatcherDynamic(ModelPatcher):
if vbar is not None:
vbar.prioritize()
- #We have way more tools for acceleration on comfy weight offloading, so always
+        #We force reserve VRAM for the non comfy weights so we don't have to deal
+        #with pin and unpin synchronization, which can be expensive for small weights
+ #with a high layer rate (e.g. autoregressive LLMs).
#prioritize the non-comfy weights (note the order reverse).
- loading = self._load_list(prio_comfy_cast_weights=True)
+ loading = self._load_list(prio_comfy_cast_weights=True, default_device=device_to)
loading.sort(reverse=True)
for x in loading:
@@ -1512,8 +1514,10 @@ class ModelPatcherDynamic(ModelPatcher):
weight, _, _ = get_key_weight(self.model, key)
if weight is None:
- return 0
+ return (False, 0)
if key in self.patches:
+ if comfy.lora.calculate_shape(self.patches[key], weight, key) != weight.shape:
+ return (True, 0)
setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches))
num_patches += 1
else:
@@ -1524,10 +1528,16 @@ class ModelPatcherDynamic(ModelPatcher):
setattr(m, param_key + "_function", weight_function)
geometry = weight
if not isinstance(weight, QuantizedTensor):
- model_dtype = getattr(m, param_key + "_comfy_model_dtype", weight.dtype)
+ model_dtype = getattr(m, param_key + "_comfy_model_dtype", None) or weight.dtype
weight._model_dtype = model_dtype
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
- return comfy.memory_management.vram_aligned_size(geometry)
+ return (False, comfy.memory_management.vram_aligned_size(geometry))
+
+    def force_load_param(self, n, param_key, device_to):
+        key = key_param_name_to_key(n, param_key)
+ if key in self.backup:
+ comfy.utils.set_attr_param(self.model, key, self.backup[key].weight)
+ self.patch_weight_to_device(key, device_to=device_to)
if hasattr(m, "comfy_cast_weights"):
m.comfy_cast_weights = True
@@ -1535,13 +1545,19 @@ class ModelPatcherDynamic(ModelPatcher):
m.seed_key = n
set_dirty(m, dirty)
- v_weight_size = 0
- v_weight_size += setup_param(self, m, n, "weight")
- v_weight_size += setup_param(self, m, n, "bias")
+ force_load, v_weight_size = setup_param(self, m, n, "weight")
+ force_load_bias, v_weight_bias = setup_param(self, m, n, "bias")
+ force_load = force_load or force_load_bias
+ v_weight_size += v_weight_bias
- if vbar is not None and not hasattr(m, "_v"):
- m._v = vbar.alloc(v_weight_size)
- allocated_size += v_weight_size
+        if force_load:
+            logging.info(f"Module {n} has a shape-changing LoRA, force loading")
+            force_load_param(self, n, "weight", device_to)
+            force_load_param(self, n, "bias", device_to)
+ else:
+ if vbar is not None and not hasattr(m, "_v"):
+ m._v = vbar.alloc(v_weight_size)
+ allocated_size += v_weight_size
else:
for param in params:
@@ -1550,13 +1566,16 @@ class ModelPatcherDynamic(ModelPatcher):
weight.seed_key = key
set_dirty(weight, dirty)
geometry = weight
- model_dtype = getattr(m, param + "_comfy_model_dtype", weight.dtype)
+ model_dtype = getattr(m, param + "_comfy_model_dtype", None) or weight.dtype
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
weight_size = geometry.numel() * geometry.element_size()
if vbar is not None and not hasattr(weight, "_v"):
weight._v = vbar.alloc(weight_size)
weight._model_dtype = model_dtype
allocated_size += weight_size
+        if vbar is not None: vbar.set_watermark_limit(allocated_size)
+
+ move_weight_functions(m, device_to)
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.")
@@ -1577,7 +1596,7 @@ class ModelPatcherDynamic(ModelPatcher):
return 0 if vbar is None else vbar.free_memory(memory_to_free)
def partially_unload_ram(self, ram_to_unload):
- loading = self._load_list(prio_comfy_cast_weights=True)
+ loading = self._load_list(prio_comfy_cast_weights=True, default_device=self.offload_device)
for x in loading:
_, _, _, _, m, _ = x
ram_to_unload -= comfy.pinned_memory.unpin_memory(m)
@@ -1598,6 +1617,13 @@ class ModelPatcherDynamic(ModelPatcher):
if unpatch_weights:
self.partially_unload_ram(1e32)
self.partially_unload(None, 1e32)
+ for m in self.model.modules():
+ move_weight_functions(m, device_to)
+
+ keys = list(self.backup.keys())
+ for k in keys:
+ bk = self.backup[k]
+ comfy.utils.set_attr_param(self.model, k, bk.weight)
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
assert not force_patch_weights #See above
diff --git a/comfy/ops.py b/comfy/ops.py
index 0f4eca7c7..688937e43 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -83,14 +83,18 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype):
offload_stream = None
xfer_dest = None
- cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
- if signature is not None:
- xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
+ if signature is not None:
+ if resident:
+ weight = s._v_weight
+ bias = s._v_bias
+ else:
+ xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
if not resident:
+ cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
cast_dest = None
xfer_source = [ s.weight, s.bias ]
@@ -140,9 +144,13 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
post_cast.copy_(pre_cast)
xfer_dest = cast_dest
- params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
- weight = params[0]
- bias = params[1]
+ params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
+ weight = params[0]
+ bias = params[1]
+ if signature is not None:
+ s._v_weight = weight
+ s._v_bias = bias
+        s._v_signature = signature
def post_cast(s, param_key, x, dtype, resident, update_weight):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
@@ -169,8 +177,8 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
if orig.dtype == dtype and len(fns) == 0:
#The layer actually wants our freshly saved QT
x = y
- else:
- y = x
+ elif update_weight:
+ y = comfy.float.stochastic_rounding(x, orig.dtype, seed = comfy.utils.string_to_seed(s.seed_key))
if update_weight:
orig.copy_(y)
for f in fns:
@@ -182,7 +190,6 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
weight = post_cast(s, "weight", weight, dtype, resident, update_weight)
if s.bias is not None:
bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight)
- s._v_signature=signature
#FIXME: weird offload return protocol
return weight, bias, (offload_stream, device if signature is not None else None, None)
diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py
index 9134e6d71..1f75f2ba7 100644
--- a/comfy/sampler_helpers.py
+++ b/comfy/sampler_helpers.py
@@ -122,20 +122,26 @@ def estimate_memory(model, noise_shape, conds):
minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min)
return memory_required, minimum_memory_required
-def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
+def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
_prepare_sampling,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True)
)
- return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load)
+ return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load, force_offload=force_offload)
-def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
+def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
real_model: BaseModel = None
models, inference_memory = get_additional_models(conds, model.model_dtype())
models += get_additional_models_from_model_options(model_options)
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
- memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
- comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory, force_full_load=force_full_load)
+    if force_offload: # When training with offload enabled, force prepare_sampling to trigger a partial load
+ memory_required = 1e20
+ minimum_memory_required = None
+ else:
+ memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
+ memory_required += inference_memory
+ minimum_memory_required += inference_memory
+ comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
real_model = model.model
return real_model, conds, models
diff --git a/comfy/sd.py b/comfy/sd.py
index bc9407405..f65e7cadd 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -793,8 +793,6 @@ class VAE:
self.first_stage_model = AutoencoderKL(**(config['params']))
self.first_stage_model = self.first_stage_model.eval()
- model_management.archive_model_dtypes(self.first_stage_model)
-
if device is None:
device = model_management.vae_device()
self.device = device
@@ -803,6 +801,7 @@ class VAE:
dtype = model_management.vae_dtype(self.device, self.working_dtypes)
self.vae_dtype = dtype
self.first_stage_model.to(self.vae_dtype)
+ model_management.archive_model_dtypes(self.first_stage_model)
self.output_device = model_management.intermediate_device()
mp = comfy.model_patcher.CoreModelPatcher
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index 4c817d468..b564d1529 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -171,8 +171,9 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def process_tokens(self, tokens, device):
end_token = self.special_tokens.get("end", None)
+ pad_token = self.special_tokens.get("pad", -1)
if end_token is None:
- cmp_token = self.special_tokens.get("pad", -1)
+ cmp_token = pad_token
else:
cmp_token = end_token
@@ -186,15 +187,21 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
other_embeds = []
eos = False
index = 0
+ left_pad = False
for y in x:
if isinstance(y, numbers.Integral):
- if eos:
+ token = int(y)
+ if index == 0 and token == pad_token:
+ left_pad = True
+
+ if eos or (left_pad and token == pad_token):
attention_mask.append(0)
else:
attention_mask.append(1)
- token = int(y)
+ left_pad = False
+
tokens_temp += [token]
- if not eos and token == cmp_token:
+ if not eos and token == cmp_token and not left_pad:
if end_token is None:
attention_mask[-1] = 0
eos = True
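
A reduced sketch of the new left-padding semantics in `process_tokens`, covering integer tokens only (embeds and the surrounding weight handling are omitted):

```python
def build_mask(token_ids, pad_token, end_token=None):
    cmp_token = end_token if end_token is not None else pad_token
    mask, eos, left_pad = [], False, False
    for index, token in enumerate(token_ids):
        if index == 0 and token == pad_token:
            left_pad = True
        if eos or (left_pad and token == pad_token):
            mask.append(0)
        else:
            mask.append(1)
            left_pad = False
        if not eos and token == cmp_token and not left_pad:
            if end_token is None:
                mask[-1] = 0  # the pad used as comparison token is itself masked
            eos = True
    return mask

# Leading pads are masked out and no longer trigger an early end-of-sequence:
print(build_mask([0, 0, 5, 6, 0], pad_token=0))  # [0, 0, 1, 1, 0]
```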
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index d33db7507..c28be1716 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -710,6 +710,15 @@ class Flux(supported_models_base.BASE):
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
+ def process_unet_state_dict(self, state_dict):
+ out_sd = {}
+ for k in list(state_dict.keys()):
+ key_out = k
+ if key_out.endswith("_norm.scale"):
+ key_out = "{}.weight".format(key_out[:-len(".scale")])
+ out_sd[key_out] = state_dict[k]
+ return out_sd
+
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
@@ -898,11 +907,13 @@ class HunyuanVideo(supported_models_base.BASE):
key_out = key_out.replace("txt_in.c_embedder.linear_1.", "txt_in.c_embedder.in_layer.").replace("txt_in.c_embedder.linear_2.", "txt_in.c_embedder.out_layer.")
key_out = key_out.replace("_mod.linear.", "_mod.lin.").replace("_attn_qkv.", "_attn.qkv.")
key_out = key_out.replace("mlp.fc1.", "mlp.0.").replace("mlp.fc2.", "mlp.2.")
- key_out = key_out.replace("_attn_q_norm.weight", "_attn.norm.query_norm.scale").replace("_attn_k_norm.weight", "_attn.norm.key_norm.scale")
- key_out = key_out.replace(".q_norm.weight", ".norm.query_norm.scale").replace(".k_norm.weight", ".norm.key_norm.scale")
+ key_out = key_out.replace("_attn_q_norm.weight", "_attn.norm.query_norm.weight").replace("_attn_k_norm.weight", "_attn.norm.key_norm.weight")
+ key_out = key_out.replace(".q_norm.weight", ".norm.query_norm.weight").replace(".k_norm.weight", ".norm.key_norm.weight")
key_out = key_out.replace("_attn_proj.", "_attn.proj.")
key_out = key_out.replace(".modulation.linear.", ".modulation.lin.")
key_out = key_out.replace("_in.mlp.2.", "_in.out_layer.").replace("_in.mlp.0.", "_in.in_layer.")
+ if key_out.endswith(".scale"):
+ key_out = "{}.weight".format(key_out[:-len(".scale")])
out_sd[key_out] = state_dict[k]
return out_sd
@@ -1264,6 +1275,15 @@ class Hunyuan3Dv2(supported_models_base.BASE):
latent_format = latent_formats.Hunyuan3Dv2
+ def process_unet_state_dict(self, state_dict):
+ out_sd = {}
+ for k in list(state_dict.keys()):
+ key_out = k
+ if key_out.endswith(".scale"):
+ key_out = "{}.weight".format(key_out[:-len(".scale")])
+ out_sd[key_out] = state_dict[k]
+ return out_sd
+
def process_unet_state_dict_for_saving(self, state_dict):
replace_prefix = {"": "model."}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
@@ -1341,6 +1361,14 @@ class Chroma(supported_models_base.BASE):
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
+ def process_unet_state_dict(self, state_dict):
+ out_sd = {}
+ for k in list(state_dict.keys()):
+ key_out = k
+ if key_out.endswith(".scale"):
+ key_out = "{}.weight".format(key_out[:-len(".scale")])
+ out_sd[key_out] = state_dict[k]
+ return out_sd
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Chroma(self, device=device)
diff --git a/comfy/text_encoders/ace15.py b/comfy/text_encoders/ace15.py
index 00dd5ba90..f135d74c1 100644
--- a/comfy/text_encoders/ace15.py
+++ b/comfy/text_encoders/ace15.py
@@ -10,12 +10,12 @@ import comfy.utils
def sample_manual_loop_no_classes(
model,
ids=None,
- paddings=[],
execution_dtype=None,
cfg_scale: float = 2.0,
temperature: float = 0.85,
top_p: float = 0.9,
top_k: int = None,
+ min_p: float = 0.000,
seed: int = 1,
min_tokens: int = 1,
max_new_tokens: int = 2048,
@@ -23,6 +23,8 @@ def sample_manual_loop_no_classes(
audio_end_id: int = 215669,
eos_token_id: int = 151645,
):
+ if ids is None:
+ return []
device = model.execution_device
if execution_dtype is None:
@@ -32,31 +34,34 @@ def sample_manual_loop_no_classes(
execution_dtype = torch.float32
embeds, attention_mask, num_tokens, embeds_info = model.process_tokens(ids, device)
- for i, t in enumerate(paddings):
- attention_mask[i, :t] = 0
- attention_mask[i, t:] = 1
+ embeds_batch = embeds.shape[0]
output_audio_codes = []
past_key_values = []
generator = torch.Generator(device=device)
generator.manual_seed(seed)
model_config = model.transformer.model.config
+ past_kv_shape = [embeds_batch, model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim]
for x in range(model_config.num_hidden_layers):
- past_key_values.append((torch.empty([embeds.shape[0], model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim], device=device, dtype=execution_dtype), torch.empty([embeds.shape[0], model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim], device=device, dtype=execution_dtype), 0))
+ past_key_values.append((torch.empty(past_kv_shape, device=device, dtype=execution_dtype), torch.empty(past_kv_shape, device=device, dtype=execution_dtype), 0))
progress_bar = comfy.utils.ProgressBar(max_new_tokens)
- for step in range(max_new_tokens):
+ for step in comfy.utils.model_trange(max_new_tokens, desc="LM sampling"):
outputs = model.transformer(None, attention_mask, embeds=embeds.to(execution_dtype), num_tokens=num_tokens, intermediate_output=None, dtype=execution_dtype, embeds_info=embeds_info, past_key_values=past_key_values)
next_token_logits = model.transformer.logits(outputs[0])[:, -1]
past_key_values = outputs[2]
- cond_logits = next_token_logits[0:1]
- uncond_logits = next_token_logits[1:2]
- cfg_logits = uncond_logits + cfg_scale * (cond_logits - uncond_logits)
+ if cfg_scale != 1.0:
+ cond_logits = next_token_logits[0:1]
+ uncond_logits = next_token_logits[1:2]
+ cfg_logits = uncond_logits + cfg_scale * (cond_logits - uncond_logits)
+ else:
+ cfg_logits = next_token_logits[0:1]
- if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step:
+ use_eos_score = eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step
+ if use_eos_score:
eos_score = cfg_logits[:, eos_token_id].clone()
remove_logit_value = torch.finfo(cfg_logits.dtype).min
@@ -64,7 +69,7 @@ def sample_manual_loop_no_classes(
cfg_logits[:, :audio_start_id] = remove_logit_value
cfg_logits[:, audio_end_id:] = remove_logit_value
- if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step:
+ if use_eos_score:
cfg_logits[:, eos_token_id] = eos_score
if top_k is not None and top_k > 0:
@@ -72,6 +77,12 @@ def sample_manual_loop_no_classes(
min_val = top_k_vals[..., -1, None]
cfg_logits[cfg_logits < min_val] = remove_logit_value
+ if min_p is not None and min_p > 0:
+ probs = torch.softmax(cfg_logits, dim=-1)
+ p_max = probs.max(dim=-1, keepdim=True).values
+ indices_to_remove = probs < (min_p * p_max)
+ cfg_logits[indices_to_remove] = remove_logit_value
+
if top_p is not None and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(cfg_logits, descending=True)
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
@@ -93,8 +104,8 @@ def sample_manual_loop_no_classes(
break
embed, _, _, _ = model.process_tokens([[token]], device)
- embeds = embed.repeat(2, 1, 1)
- attention_mask = torch.cat([attention_mask, torch.ones((2, 1), device=device, dtype=attention_mask.dtype)], dim=1)
+ embeds = embed.repeat(embeds_batch, 1, 1)
+ attention_mask = torch.cat([attention_mask, torch.ones((embeds_batch, 1), device=device, dtype=attention_mask.dtype)], dim=1)
output_audio_codes.append(token - audio_start_id)
progress_bar.update_absolute(step)
@@ -102,24 +113,29 @@ def sample_manual_loop_no_classes(
return output_audio_codes
-def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0, cfg_scale=2.0, temperature=0.85, top_p=0.9, top_k=0):
+def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0, cfg_scale=2.0, temperature=0.85, top_p=0.9, top_k=0, min_p=0.000):
positive = [[token for token, _ in inner_list] for inner_list in positive]
- negative = [[token for token, _ in inner_list] for inner_list in negative]
positive = positive[0]
- negative = negative[0]
- neg_pad = 0
- if len(negative) < len(positive):
- neg_pad = (len(positive) - len(negative))
- negative = [model.special_tokens["pad"]] * neg_pad + negative
+ if cfg_scale != 1.0:
+ negative = [[token for token, _ in inner_list] for inner_list in negative]
+ negative = negative[0]
- pos_pad = 0
- if len(negative) > len(positive):
- pos_pad = (len(negative) - len(positive))
- positive = [model.special_tokens["pad"]] * pos_pad + positive
+ neg_pad = 0
+ if len(negative) < len(positive):
+ neg_pad = (len(positive) - len(negative))
+ negative = [model.special_tokens["pad"]] * neg_pad + negative
- paddings = [pos_pad, neg_pad]
- return sample_manual_loop_no_classes(model, [positive, negative], paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
+ pos_pad = 0
+ if len(negative) > len(positive):
+ pos_pad = (len(negative) - len(positive))
+ positive = [model.special_tokens["pad"]] * pos_pad + positive
+
+ ids = [positive, negative]
+ else:
+ ids = [positive]
+
+ return sample_manual_loop_no_classes(model, ids, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
@@ -129,12 +145,12 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
def _metas_to_cot(self, *, return_yaml: bool = False, **kwargs) -> str:
user_metas = {
k: kwargs.pop(k)
- for k in ("bpm", "duration", "keyscale", "timesignature", "language", "caption")
+ for k in ("bpm", "duration", "keyscale", "timesignature")
if k in kwargs
}
timesignature = user_metas.get("timesignature")
if isinstance(timesignature, str) and timesignature.endswith("/4"):
- user_metas["timesignature"] = timesignature.rsplit("/", 1)[0]
+ user_metas["timesignature"] = timesignature[:-2]
user_metas = {
k: v if not isinstance(v, str) or not v.isdigit() else int(v)
for k, v in user_metas.items()
@@ -147,8 +163,11 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
return f"\n{meta_yaml}\n" if not return_yaml else meta_yaml
def _metas_to_cap(self, **kwargs) -> str:
- use_keys = ("bpm", "duration", "keyscale", "timesignature")
+ use_keys = ("bpm", "timesignature", "keyscale", "duration")
user_metas = { k: kwargs.pop(k, "N/A") for k in use_keys }
+ timesignature = user_metas.get("timesignature")
+ if isinstance(timesignature, str) and timesignature.endswith("/4"):
+ user_metas["timesignature"] = timesignature[:-2]
duration = user_metas["duration"]
if duration == "N/A":
user_metas["duration"] = "30 seconds"
@@ -159,9 +178,13 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
return "\n".join(f"- {k}: {user_metas[k]}" for k in use_keys)
def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
- out = {}
+ text = text.strip()
+ text_negative = kwargs.get("caption_negative", text).strip()
lyrics = kwargs.get("lyrics", "")
+ lyrics_negative = kwargs.get("lyrics_negative", lyrics)
duration = kwargs.get("duration", 120)
+ if isinstance(duration, str):
+ duration = float(duration.split(None, 1)[0])
language = kwargs.get("language")
seed = kwargs.get("seed", 0)
@@ -170,28 +193,55 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
temperature = kwargs.get("temperature", 0.85)
top_p = kwargs.get("top_p", 0.9)
top_k = kwargs.get("top_k", 0.0)
-
+ min_p = kwargs.get("min_p", 0.000)
duration = math.ceil(duration)
kwargs["duration"] = duration
+ tokens_duration = duration * 5
+ min_tokens = int(kwargs.get("min_tokens", tokens_duration))
+ max_tokens = int(kwargs.get("max_tokens", tokens_duration))
- cot_text = self._metas_to_cot(caption = text, **kwargs)
+ metas_negative = {
+ k.rsplit("_", 1)[0]: kwargs.pop(k)
+ for k in ("bpm_negative", "duration_negative", "keyscale_negative", "timesignature_negative", "language_negative", "caption_negative")
+ if k in kwargs
+ }
+ if not kwargs.get("use_negative_caption"):
+ _ = metas_negative.pop("caption", None)
+
+ cot_text = self._metas_to_cot(caption=text, **kwargs)
+ cot_text_negative = "\n\n" if not metas_negative else self._metas_to_cot(**metas_negative)
meta_cap = self._metas_to_cap(**kwargs)
- lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n# Lyric\n{}\n<|im_end|>\n<|im_start|>assistant\n{}\n<|im_end|>\n"
+ lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n\n# Lyric\n{}\n<|im_end|>\n<|im_start|>assistant\n{}\n\n<|im_end|>\n"
+ lyrics_template = "# Languages\n{}\n\n# Lyric\n{}<|endoftext|><|endoftext|>"
+ qwen3_06b_template = "# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}\n\n# Metas\n{}\n<|endoftext|>\n<|endoftext|>"
- out["lm_prompt"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, cot_text), disable_weights=True)
- out["lm_prompt_negative"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, "\n"), disable_weights=True)
+ llm_prompts = {
+ "lm_prompt": lm_template.format(text, lyrics.strip(), cot_text),
+ "lm_prompt_negative": lm_template.format(text_negative, lyrics_negative.strip(), cot_text_negative),
+ "lyrics": lyrics_template.format(language if language is not None else "", lyrics),
+ "qwen3_06b": qwen3_06b_template.format(text, meta_cap),
+ }
- out["lyrics"] = self.qwen3_06b.tokenize_with_weights("# Languages\n{}\n\n# Lyric\n{}<|endoftext|><|endoftext|>".format(language if language is not None else "", lyrics), return_word_ids, disable_weights=True, **kwargs)
- out["qwen3_06b"] = self.qwen3_06b.tokenize_with_weights("# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}\n# Metas\n{}\n<|endoftext|>\n<|endoftext|>".format(text, meta_cap), return_word_ids, **kwargs)
- out["lm_metadata"] = {"min_tokens": duration * 5,
+ out = {
+ prompt_key: self.qwen3_06b.tokenize_with_weights(
+ prompt,
+ prompt_key == "qwen3_06b" and return_word_ids,
+ disable_weights = True,
+ **kwargs,
+ )
+ for prompt_key, prompt in llm_prompts.items()
+ }
+ out["lm_metadata"] = {"min_tokens": min_tokens,
+ "max_tokens": max_tokens,
"seed": seed,
"generate_audio_codes": generate_audio_codes,
"cfg_scale": cfg_scale,
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
+ "min_p": min_p,
}
return out
@@ -252,7 +302,7 @@ class ACE15TEModel(torch.nn.Module):
lm_metadata = token_weight_pairs["lm_metadata"]
if lm_metadata["generate_audio_codes"]:
- audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"], cfg_scale=lm_metadata["cfg_scale"], temperature=lm_metadata["temperature"], top_p=lm_metadata["top_p"], top_k=lm_metadata["top_k"])
+ audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["max_tokens"], seed=lm_metadata["seed"], cfg_scale=lm_metadata["cfg_scale"], temperature=lm_metadata["temperature"], top_p=lm_metadata["top_p"], top_k=lm_metadata["top_k"], min_p=lm_metadata["min_p"])
out["audio_codes"] = [audio_codes]
return base_out, None, out
diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py
index b6735d210..54f3d5595 100644
--- a/comfy/text_encoders/llama.py
+++ b/comfy/text_encoders/llama.py
@@ -355,13 +355,6 @@ class RMSNorm(nn.Module):
-def rotate_half(x):
- """Rotates half the hidden dims of the input."""
- x1 = x[..., : x.shape[-1] // 2]
- x2 = x[..., x.shape[-1] // 2 :]
- return torch.cat((-x2, x1), dim=-1)
-
-
def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None):
if not isinstance(theta, list):
theta = [theta]
@@ -390,20 +383,30 @@ def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_di
else:
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
- out.append((cos, sin))
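+ # Pre-split sin and negate its upper half so apply_rope can rotate in place without rotate_half()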
+ sin_split = sin.shape[-1] // 2
+ out.append((cos, sin[..., : sin_split], -sin[..., sin_split :]))
if len(out) == 1:
return out[0]
return out
-
def apply_rope(xq, xk, freqs_cis):
org_dtype = xq.dtype
cos = freqs_cis[0]
sin = freqs_cis[1]
- q_embed = (xq * cos) + (rotate_half(xq) * sin)
- k_embed = (xk * cos) + (rotate_half(xk) * sin)
+ nsin = freqs_cis[2]
+
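+ # In-place rotary embedding: equivalent to x*cos + rotate_half(x)*sin, applied
+ # half-by-half with addcmul_ so no rotated copy of x is materialized.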
+ q_embed = (xq * cos)
+ q_split = q_embed.shape[-1] // 2
+ q_embed[..., : q_split].addcmul_(xq[..., q_split :], nsin)
+ q_embed[..., q_split :].addcmul_(xq[..., : q_split], sin)
+
+ k_embed = (xk * cos)
+ k_split = k_embed.shape[-1] // 2
+ k_embed[..., : k_split].addcmul_(xk[..., k_split :], nsin)
+ k_embed[..., k_split :].addcmul_(xk[..., : k_split], sin)
+
return q_embed.to(org_dtype), k_embed.to(org_dtype)
diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py
index 3f87dfd6a..9cf87c0b2 100644
--- a/comfy/text_encoders/lt.py
+++ b/comfy/text_encoders/lt.py
@@ -25,7 +25,7 @@ def ltxv_te(*args, **kwargs):
class Gemma3_12BTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
- super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
+ super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_left=True, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
@@ -97,6 +97,7 @@ class LTXAVTEModel(torch.nn.Module):
token_weight_pairs = token_weight_pairs["gemma3_12b"]
out, pooled, extra = self.gemma3_12b.encode_token_weights(token_weight_pairs)
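+ # With pad_left enabled, real tokens sit at the end of the sequence; keep only that attended tail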
+ out = out[:, :, -torch.sum(extra["attention_mask"]).item():]
out_device = out.device
if comfy.model_management.should_use_bf16(self.execution_device):
out = out.to(device=self.execution_device, dtype=torch.bfloat16)
@@ -138,6 +139,7 @@ class LTXAVTEModel(torch.nn.Module):
token_weight_pairs = token_weight_pairs.get("gemma3_12b", [])
num_tokens = sum(map(lambda a: len(a), token_weight_pairs))
+ num_tokens = max(num_tokens, 64)
return num_tokens * constant * 1024 * 1024
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
diff --git a/comfy/utils.py b/comfy/utils.py
index 1337e2205..c1ce540b5 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -20,13 +20,14 @@
import torch
import math
import struct
-import comfy.checkpoint_pickle
+import comfy.memory_management
import safetensors.torch
import numpy as np
from PIL import Image
import logging
import itertools
from torch.nn.functional import interpolate
+import time
+from tqdm.auto import trange
from einops import rearrange
from comfy.cli_args import args, enables_dynamic_vram
import json
@@ -37,26 +38,26 @@ import warnings
MMAP_TORCH_FILES = args.mmap_torch_files
DISABLE_MMAP = args.disable_mmap
-ALWAYS_SAFE_LOAD = False
-if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated
+
+if True: # ckpt/pt file whitelist for safe loading of old sd files
class ModelCheckpoint:
pass
ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint"
def scalar(*args, **kwargs):
- from numpy.core.multiarray import scalar as sc
- return sc(*args, **kwargs)
+ return None
scalar.__module__ = "numpy.core.multiarray"
from numpy import dtype
from numpy.dtypes import Float64DType
- from _codecs import encode
+
+ def encode(*args, **kwargs): # no longer necessary on newer torch
+ return None
+ encode.__module__ = "_codecs"
torch.serialization.add_safe_globals([ModelCheckpoint, scalar, dtype, Float64DType, encode])
- ALWAYS_SAFE_LOAD = True
logging.info("Checkpoint files will always be loaded safely.")
-else:
- logging.warning("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended as older versions of pytorch are no longer supported.")
+
# Current as of safetensors 0.7.0
_TYPES = {
@@ -139,11 +140,8 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
if MMAP_TORCH_FILES:
torch_args["mmap"] = True
- if safe_load or ALWAYS_SAFE_LOAD:
- pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
- else:
- logging.warning("WARNING: loading {} unsafely, upgrade your pytorch to 2.4 or newer to load this file safely.".format(ckpt))
- pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
+ pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
+
if "state_dict" in pl_sd:
sd = pl_sd["state_dict"]
else:
@@ -674,10 +672,10 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
"ff_context.linear_in.bias": "txt_mlp.0.bias",
"ff_context.linear_out.weight": "txt_mlp.2.weight",
"ff_context.linear_out.bias": "txt_mlp.2.bias",
- "attn.norm_q.weight": "img_attn.norm.query_norm.scale",
- "attn.norm_k.weight": "img_attn.norm.key_norm.scale",
- "attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",
- "attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale",
+ "attn.norm_q.weight": "img_attn.norm.query_norm.weight",
+ "attn.norm_k.weight": "img_attn.norm.key_norm.weight",
+ "attn.norm_added_q.weight": "txt_attn.norm.query_norm.weight",
+ "attn.norm_added_k.weight": "txt_attn.norm.key_norm.weight",
}
for k in block_map:
@@ -700,8 +698,8 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
"norm.linear.bias": "modulation.lin.bias",
"proj_out.weight": "linear2.weight",
"proj_out.bias": "linear2.bias",
- "attn.norm_q.weight": "norm.query_norm.scale",
- "attn.norm_k.weight": "norm.key_norm.scale",
+ "attn.norm_q.weight": "norm.query_norm.weight",
+ "attn.norm_k.weight": "norm.key_norm.weight",
"attn.to_qkv_mlp_proj.weight": "linear1.weight", # Flux 2
"attn.to_out.weight": "linear2.weight", # Flux 2
}
@@ -1155,6 +1153,32 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar)
+def model_trange(*args, **kwargs):
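+ """trange wrapper that excludes model warm-up overhead from the step-rate estimate when the aimdo allocator is active."""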
+ if comfy.memory_management.aimdo_allocator is None:
+ return trange(*args, **kwargs)
+
+ pbar = trange(*args, **kwargs, smoothing=1.0)
+ pbar._i = 0
+ pbar.set_postfix_str(" Model Initializing ... ")
+
+ _update = pbar.update
+
+ def warmup_update(n=1):
+ pbar._i += 1
+ if pbar._i == 1:
+ pbar.i1_time = time.time()
+ pbar.set_postfix_str(" Model Initialization complete! ")
+ elif pbar._i == 2:
+ # Bring forward the effective start time based on the diff between the first and second iterations
+ # to remove model-load overhead from the final step-rate estimate.
+ pbar.start_t = pbar.i1_time - (time.time() - pbar.i1_time)
+ pbar.set_postfix_str("")
+
+ _update(n)
+
+ pbar.update = warmup_update
+ return pbar
+
PROGRESS_BAR_ENABLED = True
def set_progress_bar_enabled(enabled):
global PROGRESS_BAR_ENABLED
@@ -1376,3 +1400,21 @@ def string_to_seed(data):
else:
crc >>= 1
return crc ^ 0xFFFFFFFF
+
+def deepcopy_list_dict(obj, memo=None):
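+ """Deep copy that only recurses into dicts and lists; all other values are shared by reference (cheaper than copy.deepcopy)."""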
+ if memo is None:
+ memo = {}
+
+ obj_id = id(obj)
+ if obj_id in memo:
+ return memo[obj_id]
+
+ if isinstance(obj, dict):
+ res = {deepcopy_list_dict(k, memo): deepcopy_list_dict(v, memo) for k, v in obj.items()}
+ elif isinstance(obj, list):
+ res = [deepcopy_list_dict(i, memo) for i in obj]
+ else:
+ res = obj
+
+ memo[obj_id] = res
+ return res
diff --git a/comfy/weight_adapter/base.py b/comfy/weight_adapter/base.py
index bce89a0e2..d352e066b 100644
--- a/comfy/weight_adapter/base.py
+++ b/comfy/weight_adapter/base.py
@@ -49,6 +49,12 @@ class WeightAdapterBase:
"""
raise NotImplementedError
+ def calculate_shape(
+ self,
+ key
+ ):
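+ """Return the target weight shape for key, or None if the adapter does not specify one."""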
+ return None
+
def calculate_weight(
self,
weight,
diff --git a/comfy/weight_adapter/bypass.py b/comfy/weight_adapter/bypass.py
index d4aaf98ca..b9d5ec7d9 100644
--- a/comfy/weight_adapter/bypass.py
+++ b/comfy/weight_adapter/bypass.py
@@ -21,6 +21,7 @@ from typing import Optional, Union
import torch
import torch.nn as nn
+import comfy.model_management
from .base import WeightAdapterBase, WeightAdapterTrainBase
from comfy.patcher_extension import PatcherInjection
@@ -181,18 +182,21 @@ class BypassForwardHook:
)
return # Already injected
- # Move adapter weights to module's device to avoid CPU-GPU transfer on every forward
- device = None
+ # Move adapter weights to compute device (GPU)
+ # Use get_torch_device() instead of module.weight.device because
+ # with offloading, module weights may be on CPU while compute happens on GPU
+ device = comfy.model_management.get_torch_device()
+
+ # Get dtype from module weight if available
dtype = None
if hasattr(self.module, "weight") and self.module.weight is not None:
- device = self.module.weight.device
dtype = self.module.weight.dtype
- elif hasattr(self.module, "W_q"): # Quantized layers might use different attr
- device = self.module.W_q.device
- dtype = self.module.W_q.dtype
- if device is not None:
- self._move_adapter_weights_to_device(device, dtype)
+ # Only use dtype if it's a standard float type, not quantized
+ if dtype is not None and dtype not in (torch.float32, torch.float16, torch.bfloat16):
+ dtype = None
+
+ self._move_adapter_weights_to_device(device, dtype)
self.original_forward = self.module.forward
self.module.forward = self._bypass_forward
diff --git a/comfy/weight_adapter/lora.py b/comfy/weight_adapter/lora.py
index bc4260a8f..8e1261a12 100644
--- a/comfy/weight_adapter/lora.py
+++ b/comfy/weight_adapter/lora.py
@@ -214,6 +214,13 @@ class LoRAAdapter(WeightAdapterBase):
else:
return None
+ def calculate_shape(
+ self,
+ key
+ ):
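+ # weights[5] holds the optional reshape target recorded when the LoRA was loaded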
+ reshape = self.weights[5]
+ return tuple(reshape) if reshape is not None else None
+
def calculate_weight(
self,
weight,
diff --git a/comfy_api/feature_flags.py b/comfy_api/feature_flags.py
index de167f037..a90a5ca40 100644
--- a/comfy_api/feature_flags.py
+++ b/comfy_api/feature_flags.py
@@ -14,6 +14,7 @@ SERVER_FEATURE_FLAGS: dict[str, Any] = {
"supports_preview_metadata": True,
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
"extension": {"manager": {"supports_v4": True}},
+ "node_replacements": True,
}
diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py
index 8542a1dbc..f2399422b 100644
--- a/comfy_api/latest/__init__.py
+++ b/comfy_api/latest/__init__.py
@@ -21,6 +21,17 @@ class ComfyAPI_latest(ComfyAPIBase):
VERSION = "latest"
STABLE = False
+ def __init__(self):
+ super().__init__()
+ self.node_replacement = self.NodeReplacement()
+ self.execution = self.Execution()
+
+ class NodeReplacement(ProxiedSingleton):
+ async def register(self, node_replace: io.NodeReplace) -> None:
+ """Register a node replacement mapping."""
+ from server import PromptServer
+ PromptServer.instance.node_replace_manager.register(node_replace)
+
class Execution(ProxiedSingleton):
async def set_progress(
self,
@@ -73,8 +84,6 @@ class ComfyAPI_latest(ComfyAPIBase):
image=to_display,
)
- execution: Execution
-
class ComfyExtension(ABC):
async def on_load(self) -> None:
"""
diff --git a/comfy_api/latest/_input/video_types.py b/comfy_api/latest/_input/video_types.py
index e634a0311..451e9526e 100644
--- a/comfy_api/latest/_input/video_types.py
+++ b/comfy_api/latest/_input/video_types.py
@@ -34,6 +34,21 @@ class VideoInput(ABC):
"""
pass
+ @abstractmethod
+ def as_trimmed(
+ self,
+ start_time: float | None = None,
+ duration: float | None = None,
+ strict_duration: bool = True,
+ ) -> VideoInput | None:
+ """
+ Create a new VideoInput trimmed to the given start_time and duration.
+
+ Returns:
+ A new VideoInput, or None if strict_duration is set and the source cannot supply the requested duration
+ """
+ pass
+
def get_stream_source(self) -> Union[str, io.BytesIO]:
"""
Get a streamable source for the video. This allows processing without
diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py
index 1405d0b81..3463ed1c9 100644
--- a/comfy_api/latest/_input_impl/video_types.py
+++ b/comfy_api/latest/_input_impl/video_types.py
@@ -6,6 +6,7 @@ from typing import Optional
from .._input import AudioInput, VideoInput
import av
import io
+import itertools
import json
import numpy as np
import math
@@ -29,7 +30,6 @@ def container_to_output_format(container_format: str | None) -> str | None:
formats = container_format.split(",")
return formats[0]
-
def get_open_write_kwargs(
dest: str | io.BytesIO, container_format: str, to_format: str | None
) -> dict:
@@ -57,12 +57,14 @@ class VideoFromFile(VideoInput):
Class representing video input from a file.
"""
- def __init__(self, file: str | io.BytesIO):
+ def __init__(self, file: str | io.BytesIO, *, start_time: float=0, duration: float=0):
"""
Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object
containing the file contents.
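+ A negative start_time is measured from the end of the file; a duration of 0 means "until the end".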
"""
self.__file = file
+ self.__start_time = start_time
+ self.__duration = duration
def get_stream_source(self) -> str | io.BytesIO:
"""
@@ -96,6 +98,16 @@ class VideoFromFile(VideoInput):
Returns:
Duration in seconds
"""
+ raw_duration = self._get_raw_duration()
+ if self.__start_time < 0:
+ duration_from_start = min(raw_duration, -self.__start_time)
+ else:
+ duration_from_start = raw_duration - self.__start_time
+ if self.__duration:
+ return min(self.__duration, duration_from_start)
+ return duration_from_start
+
+ def _get_raw_duration(self) -> float:
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0)
with av.open(self.__file, mode="r") as container:
@@ -113,9 +125,13 @@ class VideoFromFile(VideoInput):
if video_stream and video_stream.average_rate:
frame_count = 0
container.seek(0)
- for packet in container.demux(video_stream):
- for _ in packet.decode():
- frame_count += 1
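+ # 0x100 is AV_CODEC_CAP_SUBFRAMES: a packet may hold several frames, so decode in
+ # that case; otherwise counting demuxed packets is sufficient and much faster.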
+ frame_iterator = (
+ container.decode(video_stream)
+ if video_stream.codec.capabilities & 0x100
+ else container.demux(video_stream)
+ )
+ for packet in frame_iterator:
+ frame_count += 1
if frame_count > 0:
return float(frame_count / video_stream.average_rate)
@@ -131,36 +147,54 @@ class VideoFromFile(VideoInput):
with av.open(self.__file, mode="r") as container:
video_stream = self._get_first_video_stream(container)
- # 1. Prefer the frames field if available
- if video_stream.frames and video_stream.frames > 0:
+ # 1. Prefer the frames field if available and usable
+ if (
+ video_stream.frames
+ and video_stream.frames > 0
+ and not self.__start_time
+ and not self.__duration
+ ):
return int(video_stream.frames)
# 2. Try to estimate from duration and average_rate using only metadata
- if container.duration is not None and video_stream.average_rate:
- duration_seconds = float(container.duration / av.time_base)
- estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
- if estimated_frames > 0:
- return estimated_frames
-
if (
getattr(video_stream, "duration", None) is not None
and getattr(video_stream, "time_base", None) is not None
and video_stream.average_rate
):
- duration_seconds = float(video_stream.duration * video_stream.time_base)
+ raw_duration = float(video_stream.duration * video_stream.time_base)
+ if self.__start_time < 0:
+ duration_from_start = min(raw_duration, -self.__start_time)
+ else:
+ duration_from_start = raw_duration - self.__start_time
+ duration_seconds = min(self.__duration, duration_from_start) if self.__duration else duration_from_start
estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
if estimated_frames > 0:
return estimated_frames
# 3. Last resort: decode frames and count them (streaming)
- frame_count = 0
- container.seek(0)
- for packet in container.demux(video_stream):
- for _ in packet.decode():
- frame_count += 1
-
- if frame_count == 0:
- raise ValueError(f"Could not determine frame count for file '{self.__file}'")
+ if self.__start_time < 0:
+ start_time = max(self._get_raw_duration() + self.__start_time, 0)
+ else:
+ start_time = self.__start_time
+ frame_count = 1
+ start_pts = int(start_time / video_stream.time_base)
+ end_pts = int((start_time + self.__duration) / video_stream.time_base)
+ container.seek(start_pts, stream=video_stream)
+ frame_iterator = (
+ container.decode(video_stream)
+ if video_stream.codec.capabilities & 0x100
+ else container.demux(video_stream)
+ )
+ for frame in frame_iterator:
+ if frame.pts >= start_pts:
+ break
+ else:
+ raise ValueError(f"Could not determine frame count for file '{self.__file}'\nNo frames exist for start_time {self.__start_time}")
+ for frame in frame_iterator:
+ if self.__duration and frame.pts >= end_pts:
+ break
+ frame_count += 1
return frame_count
def get_frame_rate(self) -> Fraction:
@@ -199,9 +233,21 @@ class VideoFromFile(VideoInput):
return container.format.name
def get_components_internal(self, container: InputContainer) -> VideoComponents:
+ video_stream = self._get_first_video_stream(container)
+ if self.__start_time < 0:
+ start_time = max(self._get_raw_duration() + self.__start_time, 0)
+ else:
+ start_time = self.__start_time
# Get video frames
frames = []
- for frame in container.decode(video=0):
+ start_pts = int(start_time / video_stream.time_base)
+ end_pts = int((start_time + self.__duration) / video_stream.time_base)
+ container.seek(start_pts, stream=video_stream)
+ for frame in container.decode(video_stream):
+ if frame.pts < start_pts:
+ continue
+ if self.__duration and frame.pts >= end_pts:
+ break
img = frame.to_ndarray(format='rgb24') # shape: (H, W, 3)
img = torch.from_numpy(img) / 255.0 # shape: (H, W, 3)
frames.append(img)
@@ -209,31 +255,44 @@ class VideoFromFile(VideoInput):
images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0)
# Get frame rate
- video_stream = next(s for s in container.streams if s.type == 'video')
- frame_rate = Fraction(video_stream.average_rate) if video_stream and video_stream.average_rate else Fraction(1)
+ frame_rate = Fraction(video_stream.average_rate) if video_stream.average_rate else Fraction(1)
# Get audio if available
audio = None
- try:
- container.seek(0) # Reset the container to the beginning
- for stream in container.streams:
- if stream.type != 'audio':
- continue
- assert isinstance(stream, av.AudioStream)
- audio_frames = []
- for packet in container.demux(stream):
- for frame in packet.decode():
- assert isinstance(frame, av.AudioFrame)
- audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
- if len(audio_frames) > 0:
- audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
- audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
- audio = AudioInput({
- "waveform": audio_tensor,
- "sample_rate": int(stream.sample_rate) if stream.sample_rate else 1,
- })
- except StopIteration:
- pass # No audio stream
+ container.seek(start_pts, stream=video_stream)
+ # Use last stream for consistency
+ if len(container.streams.audio):
+ audio_stream = container.streams.audio[-1]
+ audio_frames = []
+ resample = av.audio.resampler.AudioResampler(format='fltp').resample
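+ # Resample every frame to planar float ('fltp') so frames from any codec concatenate cleanly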
+ frames = itertools.chain.from_iterable(
+ map(resample, container.decode(audio_stream))
+ )
+
+ has_first_frame = False
+ for frame in frames:
+ offset_seconds = start_time - frame.pts * audio_stream.time_base
+ to_skip = int(offset_seconds * audio_stream.sample_rate)
+ if to_skip < frame.samples:
+ has_first_frame = True
+ break
+ if has_first_frame:
+ audio_frames.append(frame.to_ndarray()[..., max(to_skip, 0):])
+
+ for frame in frames:
+ if self.__duration and frame.time > start_time + self.__duration:
+ break
+ audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
+ if len(audio_frames) > 0:
+ audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
+ if self.__duration:
+ audio_data = audio_data[..., :int(self.__duration * audio_stream.sample_rate)]
+
+ audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
+ audio = AudioInput({
+ "waveform": audio_tensor,
+ "sample_rate": int(audio_stream.sample_rate) if audio_stream.sample_rate else 1,
+ })
metadata = container.metadata
return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
@@ -250,7 +309,7 @@ class VideoFromFile(VideoInput):
path: str | io.BytesIO,
format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO,
- metadata: Optional[dict] = None
+ metadata: Optional[dict] = None,
):
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0) # Reset the BytesIO object to the beginning
@@ -262,15 +321,14 @@ class VideoFromFile(VideoInput):
reuse_streams = False
if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
reuse_streams = False
+ if self.__start_time or self.__duration:
+ reuse_streams = False
if not reuse_streams:
components = self.get_components_internal(container)
video = VideoFromComponents(components)
return video.save_to(
- path,
- format=format,
- codec=codec,
- metadata=metadata
+ path, format=format, codec=codec, metadata=metadata
)
streams = container.streams
@@ -304,10 +362,21 @@ class VideoFromFile(VideoInput):
output_container.mux(packet)
def _get_first_video_stream(self, container: InputContainer):
- video_stream = next((s for s in container.streams if s.type == "video"), None)
- if video_stream is None:
- raise ValueError(f"No video stream found in file '{self.__file}'")
- return video_stream
+ if len(container.streams.video):
+ return container.streams.video[0]
+ raise ValueError(f"No video stream found in file '{self.__file}'")
+
+ def as_trimmed(
+ self, start_time: float = 0, duration: float = 0, strict_duration: bool = True
+ ) -> VideoInput | None:
+ trimmed = VideoFromFile(
+ self.get_stream_source(),
+ start_time=start_time + self.__start_time,
+ duration=duration,
+ )
+ if trimmed.get_duration() < duration and strict_duration:
+ return None
+ return trimmed
class VideoFromComponents(VideoInput):
@@ -322,7 +391,7 @@ class VideoFromComponents(VideoInput):
return VideoComponents(
images=self.__components.images,
audio=self.__components.audio,
- frame_rate=self.__components.frame_rate
+ frame_rate=self.__components.frame_rate,
)
def save_to(
@@ -330,7 +399,7 @@ class VideoFromComponents(VideoInput):
path: str,
format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO,
- metadata: Optional[dict] = None
+ metadata: Optional[dict] = None,
):
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
raise ValueError("Only MP4 format is supported for now")
@@ -357,7 +426,10 @@ class VideoFromComponents(VideoInput):
audio_stream: Optional[av.AudioStream] = None
if self.__components.audio:
audio_sample_rate = int(self.__components.audio['sample_rate'])
- audio_stream = output.add_stream('aac', rate=audio_sample_rate)
+ waveform = self.__components.audio['waveform']
+ waveform = waveform[0, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])]
+ layout = {1: 'mono', 2: 'stereo', 6: '5.1'}.get(waveform.shape[0], 'stereo')
+ audio_stream = output.add_stream('aac', rate=audio_sample_rate, layout=layout)
# Encode video
for i, frame in enumerate(self.__components.images):
@@ -372,12 +444,21 @@ class VideoFromComponents(VideoInput):
output.mux(packet)
if audio_stream and self.__components.audio:
- waveform = self.__components.audio['waveform']
- waveform = waveform[:, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])]
- frame = av.AudioFrame.from_ndarray(waveform.movedim(2, 1).reshape(1, -1).float().cpu().numpy(), format='flt', layout='mono' if waveform.shape[1] == 1 else 'stereo')
+ frame = av.AudioFrame.from_ndarray(waveform.float().cpu().numpy(), format='fltp', layout=layout)
frame.sample_rate = audio_sample_rate
frame.pts = 0
output.mux(audio_stream.encode(frame))
# Flush encoder
output.mux(audio_stream.encode(None))
+
+ def as_trimmed(
+ self,
+ start_time: float | None = None,
+ duration: float | None = None,
+ strict_duration: bool = True,
+ ) -> VideoInput | None:
+ start_time = start_time or 0
+ duration = duration or 0
+ if strict_duration and self.get_duration() < start_time + duration:
+ return None
+ # TODO: Consider tracking duration and trimming at time of save?
+ return VideoFromFile(self.get_stream_source(), start_time=start_time, duration=duration)
diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py
index 93cf482ca..95d79c035 100644
--- a/comfy_api/latest/_io.py
+++ b/comfy_api/latest/_io.py
@@ -2030,6 +2030,68 @@ class _UIOutput(ABC):
...
+class InputMapOldId(TypedDict):
+ """Map an old node input to a new node input by ID."""
+ new_id: str
+ old_id: str
+
+class InputMapSetValue(TypedDict):
+ """Set a specific value for a new node input."""
+ new_id: str
+ set_value: Any
+
+InputMap = InputMapOldId | InputMapSetValue
+"""
+Input mapping for node replacement. Type is inferred by dictionary keys:
+- {"new_id": str, "old_id": str} - maps old input to new input
+- {"new_id": str, "set_value": Any} - sets a specific value for new input
+"""
+
+class OutputMap(TypedDict):
+ """Map outputs of node replacement via indexes."""
+ new_idx: int
+ old_idx: int
+
+class NodeReplace:
+ """
+ Defines a possible node replacement, mapping inputs and outputs of the old node to the new node.
+
+ Also supports assigning specific values to the input widgets of the new node.
+
+ Args:
+ new_node_id: The class name of the new replacement node.
+ old_node_id: The class name of the deprecated node.
+ old_widget_ids: Ordered list of input IDs for widgets that may not have an input slot
+ connected. The workflow JSON stores widget values by their relative position index,
+ not by ID. This list maps those positional indexes to input IDs, enabling the
+ replacement system to correctly identify widget values during node migration.
+ input_mapping: List of input mappings from old node to new node.
+ output_mapping: List of output mappings from old node to new node.
+ """
+ def __init__(self,
+ new_node_id: str,
+ old_node_id: str,
+ old_widget_ids: list[str] | None=None,
+ input_mapping: list[InputMap] | None=None,
+ output_mapping: list[OutputMap] | None=None,
+ ):
+ self.new_node_id = new_node_id
+ self.old_node_id = old_node_id
+ self.old_widget_ids = old_widget_ids
+ self.input_mapping = input_mapping
+ self.output_mapping = output_mapping
+
+ def as_dict(self):
+ """Create serializable representation of the node replacement."""
+ return {
+ "new_node_id": self.new_node_id,
+ "old_node_id": self.old_node_id,
+ "old_widget_ids": self.old_widget_ids,
+ "input_mapping": list(self.input_mapping) if self.input_mapping else None,
+ "output_mapping": list(self.output_mapping) if self.output_mapping else None,
+ }
+
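+# A minimal usage sketch with hypothetical node IDs: map the old "value" input to
+# the new "strength" input and pin the new "mode" input to a constant.
+#
+# NodeReplace(
+#     new_node_id="ImageScaleV2",
+#     old_node_id="ImageScale",
+#     old_widget_ids=["value", "mode"],
+#     input_mapping=[
+#         {"new_id": "strength", "old_id": "value"},
+#         {"new_id": "mode", "set_value": "bilinear"},
+#     ],
+#     output_mapping=[{"new_idx": 0, "old_idx": 0}],
+# )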
+
__all__ = [
"FolderType",
"UploadType",
@@ -2121,4 +2183,5 @@ __all__ = [
"ImageCompare",
"PriceBadgeDepends",
"PriceBadge",
+ "NodeReplace",
]
diff --git a/comfy_api_nodes/apis/__init__.py b/comfy_api_nodes/apis/__init__.py
index ee2aa1ce6..46a583b5e 100644
--- a/comfy_api_nodes/apis/__init__.py
+++ b/comfy_api_nodes/apis/__init__.py
@@ -1197,12 +1197,6 @@ class KlingImageGenImageReferenceType(str, Enum):
face = 'face'
-class KlingImageGenModelName(str, Enum):
- kling_v1 = 'kling-v1'
- kling_v1_5 = 'kling-v1-5'
- kling_v2 = 'kling-v2'
-
-
class KlingImageGenerationsRequest(BaseModel):
aspect_ratio: Optional[KlingImageGenAspectRatio] = '16:9'
callback_url: Optional[AnyUrl] = Field(
@@ -1218,7 +1212,7 @@ class KlingImageGenerationsRequest(BaseModel):
0.5, description='Reference intensity for user-uploaded images', ge=0.0, le=1.0
)
image_reference: Optional[KlingImageGenImageReferenceType] = None
- model_name: Optional[KlingImageGenModelName] = 'kling-v1'
+ model_name: str = Field(...)
n: Optional[int] = Field(1, description='Number of generated images', ge=1, le=9)
negative_prompt: Optional[str] = Field(
None, description='Negative text prompt', max_length=200
diff --git a/comfy_api_nodes/apis/bria.py b/comfy_api_nodes/apis/bria.py
index 9119cacc6..8c496b56c 100644
--- a/comfy_api_nodes/apis/bria.py
+++ b/comfy_api_nodes/apis/bria.py
@@ -45,17 +45,55 @@ class BriaEditImageRequest(BaseModel):
)
+class BriaRemoveBackgroundRequest(BaseModel):
+ image: str = Field(...)
+ sync: bool = Field(False)
+ visual_input_content_moderation: bool = Field(
+ False, description="If true, returns 422 on input image moderation failure."
+ )
+ visual_output_content_moderation: bool = Field(
+ False, description="If true, returns 422 on visual output moderation failure."
+ )
+ seed: int = Field(...)
+
+
class BriaStatusResponse(BaseModel):
request_id: str = Field(...)
status_url: str = Field(...)
warning: str | None = Field(None)
-class BriaResult(BaseModel):
+class BriaRemoveBackgroundResult(BaseModel):
+ image_url: str = Field(...)
+
+
+class BriaRemoveBackgroundResponse(BaseModel):
+ status: str = Field(...)
+ result: BriaRemoveBackgroundResult | None = Field(None)
+
+
+class BriaImageEditResult(BaseModel):
structured_prompt: str = Field(...)
image_url: str = Field(...)
-class BriaResponse(BaseModel):
+class BriaImageEditResponse(BaseModel):
status: str = Field(...)
- result: BriaResult | None = Field(None)
+ result: BriaImageEditResult | None = Field(None)
+
+
+class BriaRemoveVideoBackgroundRequest(BaseModel):
+ video: str = Field(...)
+ background_color: str = Field(default="transparent", description="Background color for the output video.")
+ output_container_and_codec: str = Field(...)
+ preserve_audio: bool = Field(True)
+ seed: int = Field(...)
+
+
+class BriaRemoveVideoBackgroundResult(BaseModel):
+ video_url: str = Field(...)
+
+
+class BriaRemoveVideoBackgroundResponse(BaseModel):
+ status: str = Field(...)
+ result: BriaRemoveVideoBackgroundResult | None = Field(None)
diff --git a/comfy_api_nodes/apis/hunyuan3d.py b/comfy_api_nodes/apis/hunyuan3d.py
index 6421c9bd5..e84eba31e 100644
--- a/comfy_api_nodes/apis/hunyuan3d.py
+++ b/comfy_api_nodes/apis/hunyuan3d.py
@@ -64,3 +64,23 @@ class To3DProTaskResultResponse(BaseModel):
class To3DProTaskQueryRequest(BaseModel):
JobId: str = Field(...)
+
+
+class To3DUVFileInput(BaseModel):
+ Type: str = Field(..., description="File type: GLB, OBJ, or FBX")
+ Url: str = Field(...)
+
+
+class To3DUVTaskRequest(BaseModel):
+ File: To3DUVFileInput = Field(...)
+
+
+class TextureEditImageInfo(BaseModel):
+ Url: str = Field(...)
+
+
+class TextureEditTaskRequest(BaseModel):
+ File3D: To3DUVFileInput = Field(...)
+ Image: TextureEditImageInfo | None = Field(None)
+ Prompt: str | None = Field(None)
+ EnablePBR: bool | None = Field(None)
diff --git a/comfy_api_nodes/apis/kling.py b/comfy_api_nodes/apis/kling.py
index bf54ede3e..9c0446075 100644
--- a/comfy_api_nodes/apis/kling.py
+++ b/comfy_api_nodes/apis/kling.py
@@ -1,12 +1,22 @@
from pydantic import BaseModel, Field
+class MultiPromptEntry(BaseModel):
+ index: int = Field(...)
+ prompt: str = Field(...)
+ duration: str = Field(...)
+
+
class OmniProText2VideoRequest(BaseModel):
model_name: str = Field(..., description="kling-video-o1")
aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'")
duration: str = Field(..., description="'5' or '10'")
prompt: str = Field(...)
mode: str = Field("pro")
+ multi_shot: bool | None = Field(None)
+ multi_prompt: list[MultiPromptEntry] | None = Field(None)
+ shot_type: str | None = Field(None)
+ sound: str = Field(..., description="'on' or 'off'")
class OmniParamImage(BaseModel):
@@ -26,6 +36,10 @@ class OmniProFirstLastFrameRequest(BaseModel):
duration: str = Field(..., description="'5' or '10'")
prompt: str = Field(...)
mode: str = Field("pro")
+ sound: str | None = Field(None, description="'on' or 'off'")
+ multi_shot: bool | None = Field(None)
+ multi_prompt: list[MultiPromptEntry] | None = Field(None)
+ shot_type: str | None = Field(None)
class OmniProReferences2VideoRequest(BaseModel):
@@ -38,6 +52,10 @@ class OmniProReferences2VideoRequest(BaseModel):
duration: str | None = Field(..., description="From 3 to 10.")
prompt: str = Field(...)
mode: str = Field("pro")
+ sound: str | None = Field(None, description="'on' or 'off'")
+ multi_shot: bool | None = Field(None)
+ multi_prompt: list[MultiPromptEntry] | None = Field(None)
+ shot_type: str | None = Field(None)
class TaskStatusVideoResult(BaseModel):
@@ -54,6 +72,7 @@ class TaskStatusImageResult(BaseModel):
class TaskStatusResults(BaseModel):
videos: list[TaskStatusVideoResult] | None = Field(None)
images: list[TaskStatusImageResult] | None = Field(None)
+ series_images: list[TaskStatusImageResult] | None = Field(None)
class TaskStatusResponseData(BaseModel):
@@ -77,31 +96,42 @@ class OmniImageParamImage(BaseModel):
class OmniProImageRequest(BaseModel):
- model_name: str = Field(..., description="kling-image-o1")
- resolution: str = Field(..., description="'1k' or '2k'")
+ model_name: str = Field(...)
+ resolution: str = Field(...)
aspect_ratio: str | None = Field(...)
prompt: str = Field(...)
mode: str = Field("pro")
n: int | None = Field(1, le=9)
image_list: list[OmniImageParamImage] | None = Field(..., max_length=10)
+ result_type: str | None = Field(None, description="Set to 'series' for series generation")
+ series_amount: int | None = Field(None, ge=2, le=9, description="Number of images in a series")
class TextToVideoWithAudioRequest(BaseModel):
- model_name: str = Field(..., description="kling-v2-6")
+ model_name: str = Field(...)
aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'")
- duration: str = Field(..., description="'5' or '10'")
- prompt: str = Field(...)
+ duration: str = Field(...)
+ prompt: str | None = Field(...)
+ negative_prompt: str | None = Field(None)
mode: str = Field("pro")
sound: str = Field(..., description="'on' or 'off'")
+ multi_shot: bool | None = Field(None)
+ multi_prompt: list[MultiPromptEntry] | None = Field(None)
+ shot_type: str | None = Field(None)
class ImageToVideoWithAudioRequest(BaseModel):
- model_name: str = Field(..., description="kling-v2-6")
+ model_name: str = Field(...)
image: str = Field(...)
- duration: str = Field(..., description="'5' or '10'")
- prompt: str = Field(...)
+ image_tail: str | None = Field(None)
+ duration: str = Field(...)
+ prompt: str | None = Field(...)
+ negative_prompt: str | None = Field(None)
mode: str = Field("pro")
sound: str = Field(..., description="'on' or 'off'")
+ multi_shot: bool | None = Field(None)
+ multi_prompt: list[MultiPromptEntry] | None = Field(None)
+ shot_type: str | None = Field(None)
class MotionControlRequest(BaseModel):
diff --git a/comfy_api_nodes/nodes_bria.py b/comfy_api_nodes/nodes_bria.py
index d3a52bc1b..4044ee3ea 100644
--- a/comfy_api_nodes/nodes_bria.py
+++ b/comfy_api_nodes/nodes_bria.py
@@ -3,7 +3,11 @@ from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.bria import (
BriaEditImageRequest,
- BriaResponse,
+ BriaRemoveBackgroundRequest,
+ BriaRemoveBackgroundResponse,
+ BriaRemoveVideoBackgroundRequest,
+ BriaRemoveVideoBackgroundResponse,
+ BriaImageEditResponse,
BriaStatusResponse,
InputModerationSettings,
)
@@ -11,10 +15,12 @@ from comfy_api_nodes.util import (
ApiEndpoint,
convert_mask_to_image,
download_url_to_image_tensor,
- get_number_of_images,
+ download_url_to_video_output,
poll_op,
sync_op,
- upload_images_to_comfyapi,
+ upload_image_to_comfyapi,
+ upload_video_to_comfyapi,
+ validate_video_duration,
)
@@ -73,21 +79,15 @@ class BriaImageEditNode(IO.ComfyNode):
IO.DynamicCombo.Input(
"moderation",
options=[
+ IO.DynamicCombo.Option("false", []),
IO.DynamicCombo.Option(
"true",
[
- IO.Boolean.Input(
- "prompt_content_moderation", default=False
- ),
- IO.Boolean.Input(
- "visual_input_moderation", default=False
- ),
- IO.Boolean.Input(
- "visual_output_moderation", default=True
- ),
+ IO.Boolean.Input("prompt_content_moderation", default=False),
+ IO.Boolean.Input("visual_input_moderation", default=False),
+ IO.Boolean.Input("visual_output_moderation", default=True),
],
),
- IO.DynamicCombo.Option("false", []),
],
tooltip="Moderation settings",
),
@@ -127,50 +127,26 @@ class BriaImageEditNode(IO.ComfyNode):
mask: Input.Image | None = None,
) -> IO.NodeOutput:
if not prompt and not structured_prompt:
- raise ValueError(
- "One of prompt or structured_prompt is required to be non-empty."
- )
- if get_number_of_images(image) != 1:
- raise ValueError("Exactly one input image is required.")
+ raise ValueError("One of prompt or structured_prompt is required to be non-empty.")
mask_url = None
if mask is not None:
- mask_url = (
- await upload_images_to_comfyapi(
- cls,
- convert_mask_to_image(mask),
- max_images=1,
- mime_type="image/png",
- wait_label="Uploading mask",
- )
- )[0]
+ mask_url = await upload_image_to_comfyapi(cls, convert_mask_to_image(mask), wait_label="Uploading mask")
response = await sync_op(
cls,
ApiEndpoint(path="proxy/bria/v2/image/edit", method="POST"),
data=BriaEditImageRequest(
instruction=prompt if prompt else None,
structured_instruction=structured_prompt if structured_prompt else None,
- images=await upload_images_to_comfyapi(
- cls,
- image,
- max_images=1,
- mime_type="image/png",
- wait_label="Uploading image",
- ),
+ images=[await upload_image_to_comfyapi(cls, image, wait_label="Uploading image")],
mask=mask_url,
negative_prompt=negative_prompt if negative_prompt else None,
guidance_scale=guidance_scale,
seed=seed,
model_version=model,
steps_num=steps,
- prompt_content_moderation=moderation.get(
- "prompt_content_moderation", False
- ),
- visual_input_content_moderation=moderation.get(
- "visual_input_moderation", False
- ),
- visual_output_content_moderation=moderation.get(
- "visual_output_moderation", False
- ),
+ prompt_content_moderation=moderation.get("prompt_content_moderation", False),
+ visual_input_content_moderation=moderation.get("visual_input_moderation", False),
+ visual_output_content_moderation=moderation.get("visual_output_moderation", False),
),
response_model=BriaStatusResponse,
)
@@ -178,7 +154,7 @@ class BriaImageEditNode(IO.ComfyNode):
cls,
ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
status_extractor=lambda r: r.status,
- response_model=BriaResponse,
+ response_model=BriaImageEditResponse,
)
return IO.NodeOutput(
await download_url_to_image_tensor(response.result.image_url),
@@ -186,11 +162,167 @@ class BriaImageEditNode(IO.ComfyNode):
)
+class BriaRemoveImageBackground(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="BriaRemoveImageBackground",
+ display_name="Bria Remove Image Background",
+ category="api node/image/Bria",
+ description="Remove the background from an image using Bria RMBG 2.0.",
+ inputs=[
+ IO.Image.Input("image"),
+ IO.DynamicCombo.Input(
+ "moderation",
+ options=[
+ IO.DynamicCombo.Option("false", []),
+ IO.DynamicCombo.Option(
+ "true",
+ [
+ IO.Boolean.Input("visual_input_moderation", default=False),
+ IO.Boolean.Input("visual_output_moderation", default=True),
+ ],
+ ),
+ ],
+ tooltip="Moderation settings",
+ ),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=2147483647,
+ display_mode=IO.NumberDisplay.number,
+ control_after_generate=True,
+ tooltip="Seed controls whether the node should re-run; "
+ "results are non-deterministic regardless of seed.",
+ ),
+ ],
+ outputs=[IO.Image.Output()],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ price_badge=IO.PriceBadge(
+ expr="""{"type":"usd","usd":0.018}""",
+ ),
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ image: Input.Image,
+ moderation: dict,
+ seed: int,
+ ) -> IO.NodeOutput:
+ response = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/bria/v2/image/edit/remove_background", method="POST"),
+ data=BriaRemoveBackgroundRequest(
+ image=await upload_image_to_comfyapi(cls, image, wait_label="Uploading image"),
+ sync=False,
+ visual_input_content_moderation=moderation.get("visual_input_moderation", False),
+ visual_output_content_moderation=moderation.get("visual_output_moderation", False),
+ seed=seed,
+ ),
+ response_model=BriaStatusResponse,
+ )
+ response = await poll_op(
+ cls,
+ ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
+ status_extractor=lambda r: r.status,
+ response_model=BriaRemoveBackgroundResponse,
+ )
+ return IO.NodeOutput(await download_url_to_image_tensor(response.result.image_url))
+
+
+class BriaRemoveVideoBackground(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="BriaRemoveVideoBackground",
+ display_name="Bria Remove Video Background",
+ category="api node/video/Bria",
+ description="Remove the background from a video using Bria. ",
+ inputs=[
+ IO.Video.Input("video"),
+ IO.Combo.Input(
+ "background_color",
+ options=[
+ "Black",
+ "White",
+ "Gray",
+ "Red",
+ "Green",
+ "Blue",
+ "Yellow",
+ "Cyan",
+ "Magenta",
+ "Orange",
+ ],
+ tooltip="Background color for the output video.",
+ ),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=2147483647,
+ display_mode=IO.NumberDisplay.number,
+ control_after_generate=True,
+ tooltip="Seed controls whether the node should re-run; "
+ "results are non-deterministic regardless of seed.",
+ ),
+ ],
+ outputs=[IO.Video.Output()],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ price_badge=IO.PriceBadge(
+ expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""",
+ ),
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ video: Input.Video,
+ background_color: str,
+ seed: int,
+ ) -> IO.NodeOutput:
+ validate_video_duration(video, max_duration=60.0)
+ response = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/bria/v2/video/edit/remove_background", method="POST"),
+ data=BriaRemoveVideoBackgroundRequest(
+ video=await upload_video_to_comfyapi(cls, video),
+ background_color=background_color,
+ output_container_and_codec="mp4_h264",
+ seed=seed,
+ ),
+ response_model=BriaStatusResponse,
+ )
+ response = await poll_op(
+ cls,
+ ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
+ status_extractor=lambda r: r.status,
+ response_model=BriaRemoveVideoBackgroundResponse,
+ )
+ return IO.NodeOutput(await download_url_to_video_output(response.result.video_url))
+
+
class BriaExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
BriaImageEditNode,
+ BriaRemoveImageBackground,
+ BriaRemoveVideoBackground,
]
diff --git a/comfy_api_nodes/nodes_hunyuan3d.py b/comfy_api_nodes/nodes_hunyuan3d.py
index 813a7c809..ca002cc60 100644
--- a/comfy_api_nodes/nodes_hunyuan3d.py
+++ b/comfy_api_nodes/nodes_hunyuan3d.py
@@ -1,31 +1,48 @@
from typing_extensions import override
-from comfy_api.latest import IO, ComfyExtension, Input
+from comfy_api.latest import IO, ComfyExtension, Input, Types
from comfy_api_nodes.apis.hunyuan3d import (
Hunyuan3DViewImage,
InputGenerateType,
ResultFile3D,
+ TextureEditTaskRequest,
To3DProTaskCreateResponse,
To3DProTaskQueryRequest,
To3DProTaskRequest,
To3DProTaskResultResponse,
+ To3DUVFileInput,
+ To3DUVTaskRequest,
)
from comfy_api_nodes.util import (
ApiEndpoint,
download_url_to_file_3d,
+ download_url_to_image_tensor,
downscale_image_tensor_by_max_side,
poll_op,
sync_op,
+ upload_3d_model_to_comfyapi,
upload_image_to_comfyapi,
validate_image_dimensions,
validate_string,
)
-def get_file_from_response(response_objs: list[ResultFile3D], file_type: str) -> ResultFile3D | None:
+def _is_tencent_rate_limited(status: int, body: object) -> bool:
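+ """Tencent reports rate limiting as HTTP 400 with an Error.Code containing 'RequestLimitExceeded' rather than HTTP 429."""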
+ return (
+ status == 400
+ and isinstance(body, dict)
+ and "RequestLimitExceeded" in str(body.get("Response", {}).get("Error", {}).get("Code", ""))
+ )
+
+
+def get_file_from_response(
+ response_objs: list[ResultFile3D], file_type: str, raise_if_not_found: bool = True
+) -> ResultFile3D | None:
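+ """Return the first result file whose Type matches file_type (case-insensitive)."""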
for i in response_objs:
if i.Type.lower() == file_type.lower():
return i
+ if raise_if_not_found:
+ raise ValueError(f"'{file_type}' file type is not found in the response.")
return None
@@ -35,7 +52,7 @@ class TencentTextToModelNode(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="TencentTextToModelNode",
- display_name="Hunyuan3D: Text to Model (Pro)",
+ display_name="Hunyuan3D: Text to Model",
category="api node/3d/Tencent",
inputs=[
IO.Combo.Input(
@@ -120,6 +137,7 @@ class TencentTextToModelNode(IO.ComfyNode):
EnablePBR=generate_type.get("pbr", None),
PolygonType=generate_type.get("polygon_type", None),
),
+ is_rate_limited=_is_tencent_rate_limited,
)
if response.Error:
raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
@@ -131,11 +149,14 @@ class TencentTextToModelNode(IO.ComfyNode):
response_model=To3DProTaskResultResponse,
status_extractor=lambda r: r.Status,
)
- glb_result = get_file_from_response(result.ResultFile3Ds, "glb")
- obj_result = get_file_from_response(result.ResultFile3Ds, "obj")
- file_glb = await download_url_to_file_3d(glb_result.Url, "glb", task_id=task_id) if glb_result else None
return IO.NodeOutput(
- file_glb, file_glb, await download_url_to_file_3d(obj_result.Url, "obj", task_id=task_id) if obj_result else None
+ f"{task_id}.glb",
+ await download_url_to_file_3d(
+ get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
+ ),
+ await download_url_to_file_3d(
+ get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj", task_id=task_id
+ ),
)
@@ -145,7 +166,7 @@ class TencentImageToModelNode(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="TencentImageToModelNode",
- display_name="Hunyuan3D: Image(s) to Model (Pro)",
+ display_name="Hunyuan3D: Image(s) to Model",
category="api node/3d/Tencent",
inputs=[
IO.Combo.Input(
@@ -268,6 +289,7 @@ class TencentImageToModelNode(IO.ComfyNode):
EnablePBR=generate_type.get("pbr", None),
PolygonType=generate_type.get("polygon_type", None),
),
+ is_rate_limited=_is_tencent_rate_limited,
)
if response.Error:
raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
@@ -279,11 +301,257 @@ class TencentImageToModelNode(IO.ComfyNode):
response_model=To3DProTaskResultResponse,
status_extractor=lambda r: r.Status,
)
- glb_result = get_file_from_response(result.ResultFile3Ds, "glb")
- obj_result = get_file_from_response(result.ResultFile3Ds, "obj")
- file_glb = await download_url_to_file_3d(glb_result.Url, "glb", task_id=task_id) if glb_result else None
return IO.NodeOutput(
- file_glb, file_glb, await download_url_to_file_3d(obj_result.Url, "obj", task_id=task_id) if obj_result else None
+ f"{task_id}.glb",
+ await download_url_to_file_3d(
+ get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
+ ),
+ await download_url_to_file_3d(
+ get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj", task_id=task_id
+ ),
+ )
+
+
+class TencentModelTo3DUVNode(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="TencentModelTo3DUVNode",
+ display_name="Hunyuan3D: Model to UV",
+ category="api node/3d/Tencent",
+ description="Perform UV unwrapping on a 3D model to generate a UV texture. "
+ "The input model must have fewer than 30,000 faces.",
+ inputs=[
+ IO.MultiType.Input(
+ "model_3d",
+ types=[IO.File3DGLB, IO.File3DOBJ, IO.File3DFBX, IO.File3DAny],
+ tooltip="Input 3D model (GLB, OBJ, or FBX)",
+ ),
+ IO.Int.Input(
+ "seed",
+ default=1,
+ min=0,
+ max=2147483647,
+ display_mode=IO.NumberDisplay.number,
+ control_after_generate=True,
+ tooltip="Seed controls whether the node should re-run; "
+ "results are non-deterministic regardless of seed.",
+ ),
+ ],
+ outputs=[
+ IO.File3DOBJ.Output(display_name="OBJ"),
+ IO.File3DFBX.Output(display_name="FBX"),
+ IO.Image.Output(),
+ ],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ price_badge=IO.PriceBadge(expr='{"type":"usd","usd":0.2}'),
+ )
+
+ SUPPORTED_FORMATS = {"glb", "obj", "fbx"}
+
+ @classmethod
+ async def execute(
+ cls,
+ model_3d: Types.File3D,
+ seed: int,
+ ) -> IO.NodeOutput:
+ _ = seed
+ file_format = model_3d.format.lower()
+ if file_format not in cls.SUPPORTED_FORMATS:
+ raise ValueError(
+ f"Unsupported file format: '{file_format}'. "
+ f"Supported formats: {', '.join(sorted(cls.SUPPORTED_FORMATS))}."
+ )
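+ # Upload the model, create the UV-unwrapping task, then poll until the OBJ, FBX, and UV-image results are ready.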
+ response = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/tencent/hunyuan/3d-uv", method="POST"),
+ response_model=To3DProTaskCreateResponse,
+ data=To3DUVTaskRequest(
+ File=To3DUVFileInput(
+ Type=file_format.upper(),
+ Url=await upload_3d_model_to_comfyapi(cls, model_3d, file_format),
+ )
+ ),
+ is_rate_limited=_is_tencent_rate_limited,
+ )
+ if response.Error:
+ raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
+ result = await poll_op(
+ cls,
+ ApiEndpoint(path="/proxy/tencent/hunyuan/3d-uv/query", method="POST"),
+ data=To3DProTaskQueryRequest(JobId=response.JobId),
+ response_model=To3DProTaskResultResponse,
+ status_extractor=lambda r: r.Status,
+ )
+ return IO.NodeOutput(
+ await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj"),
+ await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"),
+ await download_url_to_image_tensor(get_file_from_response(result.ResultFile3Ds, "image").Url),
+ )
+
+
+class Tencent3DTextureEditNode(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="Tencent3DTextureEditNode",
+ display_name="Hunyuan3D: 3D Texture Edit",
+ category="api node/3d/Tencent",
+ description="Redraw the texture of an input 3D model according to a text prompt.",
+ inputs=[
+ IO.MultiType.Input(
+ "model_3d",
+ types=[IO.File3DFBX, IO.File3DAny],
+ tooltip="3D model in FBX format. The model should have fewer than 100,000 faces.",
+ ),
+ IO.String.Input(
+ "prompt",
+ multiline=True,
+ default="",
+ tooltip="Describes the desired texture edit. Up to 1024 UTF-8 characters.",
+ ),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=2147483647,
+ display_mode=IO.NumberDisplay.number,
+ control_after_generate=True,
+ tooltip="Seed controls whether the node should re-run; "
+ "results are non-deterministic regardless of seed.",
+ ),
+ ],
+ outputs=[
+ IO.File3DGLB.Output(display_name="GLB"),
+ IO.File3DFBX.Output(display_name="FBX"),
+ ],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ price_badge=IO.PriceBadge(
+ expr="""{"type":"usd","usd": 0.6}""",
+ ),
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ model_3d: Types.File3D,
+ prompt: str,
+ seed: int,
+ ) -> IO.NodeOutput:
+ _ = seed
+ file_format = model_3d.format.lower()
+ if file_format != "fbx":
+ raise ValueError(f"Unsupported file format: '{file_format}'. Only FBX format is supported.")
+ validate_string(prompt, field_name="prompt", min_length=1, max_length=1024)
+ model_url = await upload_3d_model_to_comfyapi(cls, model_3d, file_format)
+ response = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/tencent/hunyuan/3d-texture-edit", method="POST"),
+ response_model=To3DProTaskCreateResponse,
+ data=TextureEditTaskRequest(
+ File3D=To3DUVFileInput(Type=file_format.upper(), Url=model_url),
+ Prompt=prompt,
+ EnablePBR=True,
+ ),
+ is_rate_limited=_is_tencent_rate_limited,
+ )
+ if response.Error:
+ raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
+
+ result = await poll_op(
+ cls,
+ ApiEndpoint(path="/proxy/tencent/hunyuan/3d-texture-edit/query", method="POST"),
+ data=To3DProTaskQueryRequest(JobId=response.JobId),
+ response_model=To3DProTaskResultResponse,
+ status_extractor=lambda r: r.Status,
+ )
+ return IO.NodeOutput(
+ await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb"),
+ await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"),
+ )
+
+
+class Tencent3DPartNode(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="Tencent3DPartNode",
+ display_name="Hunyuan3D: 3D Part",
+ category="api node/3d/Tencent",
+ description="Automatically identify and generate separate components (parts) based on the model structure.",
+ inputs=[
+ IO.MultiType.Input(
+ "model_3d",
+ types=[IO.File3DFBX, IO.File3DAny],
+ tooltip="3D model in FBX format. The model should have fewer than 30,000 faces.",
+ ),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=2147483647,
+ display_mode=IO.NumberDisplay.number,
+ control_after_generate=True,
+ tooltip="Seed controls whether the node should re-run; "
+ "results are non-deterministic regardless of seed.",
+ ),
+ ],
+ outputs=[
+ IO.File3DFBX.Output(display_name="FBX"),
+ ],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ price_badge=IO.PriceBadge(expr='{"type":"usd","usd":0.6}'),
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ model_3d: Types.File3D,
+ seed: int,
+ ) -> IO.NodeOutput:
+ _ = seed
+ file_format = model_3d.format.lower()
+ if file_format != "fbx":
+ raise ValueError(f"Unsupported file format: '{file_format}'. Only FBX format is supported.")
+ model_url = await upload_3d_model_to_comfyapi(cls, model_3d, file_format)
+ response = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/tencent/hunyuan/3d-part", method="POST"),
+ response_model=To3DProTaskCreateResponse,
+ data=To3DUVTaskRequest(
+ File=To3DUVFileInput(Type=file_format.upper(), Url=model_url),
+ ),
+ is_rate_limited=_is_tencent_rate_limited,
+ )
+ if response.Error:
+ raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
+ result = await poll_op(
+ cls,
+ ApiEndpoint(path="/proxy/tencent/hunyuan/3d-part/query", method="POST"),
+ data=To3DProTaskQueryRequest(JobId=response.JobId),
+ response_model=To3DProTaskResultResponse,
+ status_extractor=lambda r: r.Status,
+ )
+ return IO.NodeOutput(
+ await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"),
)
@@ -293,6 +561,9 @@ class TencentHunyuan3DExtension(ComfyExtension):
return [
TencentTextToModelNode,
TencentImageToModelNode,
+ # TencentModelTo3DUVNode,
+ # Tencent3DTextureEditNode,
+ Tencent3DPartNode,
]
diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py
index 739fe1855..b89c85561 100644
--- a/comfy_api_nodes/nodes_kling.py
+++ b/comfy_api_nodes/nodes_kling.py
@@ -38,7 +38,6 @@ from comfy_api_nodes.apis import (
KlingImageGenerationsRequest,
KlingImageGenerationsResponse,
KlingImageGenImageReferenceType,
- KlingImageGenModelName,
KlingImageGenAspectRatio,
KlingVideoEffectsRequest,
KlingVideoEffectsResponse,
@@ -52,6 +51,7 @@ from comfy_api_nodes.apis import (
from comfy_api_nodes.apis.kling import (
ImageToVideoWithAudioRequest,
MotionControlRequest,
+ MultiPromptEntry,
OmniImageParamImage,
OmniParamImage,
OmniParamVideo,
@@ -71,6 +71,7 @@ from comfy_api_nodes.util import (
sync_op,
tensor_to_base64_string,
upload_audio_to_comfyapi,
+ upload_image_to_comfyapi,
upload_images_to_comfyapi,
upload_video_to_comfyapi,
validate_image_aspect_ratio,
@@ -80,6 +81,31 @@ from comfy_api_nodes.util import (
validate_video_duration,
)
+
+def _generate_storyboard_inputs(count: int) -> list:
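+ """Build prompt and duration inputs for storyboard segments 1..count, for use in DynamicCombo options."""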
+ inputs = []
+ for i in range(1, count + 1):
+ inputs.extend(
+ [
+ IO.String.Input(
+ f"storyboard_{i}_prompt",
+ multiline=True,
+ default="",
+ tooltip=f"Prompt for storyboard segment {i}. Max 512 characters.",
+ ),
+ IO.Int.Input(
+ f"storyboard_{i}_duration",
+ default=4,
+ min=1,
+ max=15,
+ display_mode=IO.NumberDisplay.slider,
+ tooltip=f"Duration for storyboard segment {i} in seconds.",
+ ),
+ ]
+ )
+ return inputs
+
+
KLING_API_VERSION = "v1"
PATH_TEXT_TO_VIDEO = f"/proxy/kling/{KLING_API_VERSION}/videos/text2video"
PATH_IMAGE_TO_VIDEO = f"/proxy/kling/{KLING_API_VERSION}/videos/image2video"
@@ -820,20 +846,48 @@ class OmniProTextToVideoNode(IO.ComfyNode):
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingOmniProTextToVideoNode",
- display_name="Kling Omni Text to Video (Pro)",
+ display_name="Kling 3.0 Omni Text to Video",
category="api node/video/Kling",
description="Use text prompts to generate videos with the latest Kling model.",
inputs=[
- IO.Combo.Input("model_name", options=["kling-video-o1"]),
+ IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),
IO.String.Input(
"prompt",
multiline=True,
tooltip="A text prompt describing the video content. "
- "This can include both positive and negative descriptions.",
+ "This can include both positive and negative descriptions. "
+ "Ignored when storyboards are enabled.",
),
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
- IO.Combo.Input("duration", options=[5, 10]),
+ IO.Int.Input("duration", default=5, min=3, max=15, display_mode=IO.NumberDisplay.slider),
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
+ IO.DynamicCombo.Input(
+ "storyboards",
+ options=[
+ IO.DynamicCombo.Option("disabled", []),
+ IO.DynamicCombo.Option("1 storyboard", _generate_storyboard_inputs(1)),
+ IO.DynamicCombo.Option("2 storyboards", _generate_storyboard_inputs(2)),
+ IO.DynamicCombo.Option("3 storyboards", _generate_storyboard_inputs(3)),
+ IO.DynamicCombo.Option("4 storyboards", _generate_storyboard_inputs(4)),
+ IO.DynamicCombo.Option("5 storyboards", _generate_storyboard_inputs(5)),
+ IO.DynamicCombo.Option("6 storyboards", _generate_storyboard_inputs(6)),
+ ],
+ tooltip="Generate a series of video segments with individual prompts and durations. "
+ "Only supported for kling-v3-omni.",
+ optional=True,
+ ),
+ IO.Boolean.Input("generate_audio", default=False, optional=True),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=2147483647,
+ display_mode=IO.NumberDisplay.number,
+ control_after_generate=True,
+ tooltip="Seed controls whether the node should re-run; "
+ "results are non-deterministic regardless of seed.",
+ optional=True,
+ ),
],
outputs=[
IO.Video.Output(),
@@ -845,11 +899,15 @@ class OmniProTextToVideoNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
- depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]),
+ depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution", "model_name", "generate_audio"]),
expr="""
(
$mode := (widgets.resolution = "720p") ? "std" : "pro";
- $rates := {"std": 0.084, "pro": 0.112};
+ $isV3 := $contains(widgets.model_name, "v3");
+ $audio := $isV3 and widgets.generate_audio;
+ $rates := $audio
+ ? {"std": 0.112, "pro": 0.14}
+ : {"std": 0.084, "pro": 0.112};
{"type":"usd","usd": $lookup($rates, $mode) * widgets.duration}
)
""",
@@ -864,8 +922,45 @@ class OmniProTextToVideoNode(IO.ComfyNode):
aspect_ratio: str,
duration: int,
resolution: str = "1080p",
+ storyboards: dict | None = None,
+ generate_audio: bool = False,
+ seed: int = 0,
) -> IO.NodeOutput:
- validate_string(prompt, min_length=1, max_length=2500)
+ _ = seed
+ if model_name == "kling-video-o1":
+ if duration not in (5, 10):
+ raise ValueError("kling-video-o1 only supports durations of 5 or 10 seconds.")
+ if generate_audio:
+ raise ValueError("kling-video-o1 does not support audio generation.")
+ stories_enabled = storyboards is not None and storyboards["storyboards"] != "disabled"
+ if stories_enabled and model_name == "kling-video-o1":
+ raise ValueError("kling-video-o1 does not support storyboards.")
+ validate_string(prompt, strip_whitespace=True, min_length=0 if stories_enabled else 1, max_length=2500)
+
+ multi_shot = None
+ multi_prompt_list = None
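+ # Parse the "N storyboards" selection into N (prompt, duration) entries; the durations must sum to the overall duration.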
+ if stories_enabled:
+ count = int(storyboards["storyboards"].split()[0])
+ multi_shot = True
+ multi_prompt_list = []
+ for i in range(1, count + 1):
+ sb_prompt = storyboards[f"storyboard_{i}_prompt"]
+ sb_duration = storyboards[f"storyboard_{i}_duration"]
+ validate_string(sb_prompt, field_name=f"storyboard_{i}_prompt", min_length=1, max_length=512)
+ multi_prompt_list.append(
+ MultiPromptEntry(
+ index=i,
+ prompt=sb_prompt,
+ duration=str(sb_duration),
+ )
+ )
+ total_storyboard_duration = sum(int(e.duration) for e in multi_prompt_list)
+ if total_storyboard_duration != duration:
+ raise ValueError(
+ f"Total storyboard duration ({total_storyboard_duration}s) "
+ f"must equal the global duration ({duration}s)."
+ )
+
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
@@ -876,6 +971,10 @@ class OmniProTextToVideoNode(IO.ComfyNode):
aspect_ratio=aspect_ratio,
duration=str(duration),
mode="pro" if resolution == "1080p" else "std",
+ multi_shot=multi_shot,
+ multi_prompt=multi_prompt_list,
+ shot_type="customize" if multi_shot else None,
+ sound="on" if generate_audio else "off",
),
)
return await finish_omni_video_task(cls, response)
@@ -887,24 +986,26 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingOmniProFirstLastFrameNode",
- display_name="Kling Omni First-Last-Frame to Video (Pro)",
+ display_name="Kling 3.0 Omni First-Last-Frame to Video",
category="api node/video/Kling",
description="Use a start frame, an optional end frame, or reference images with the latest Kling model.",
inputs=[
- IO.Combo.Input("model_name", options=["kling-video-o1"]),
+ IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),
IO.String.Input(
"prompt",
multiline=True,
tooltip="A text prompt describing the video content. "
- "This can include both positive and negative descriptions.",
+ "This can include both positive and negative descriptions. "
+ "Ignored when storyboards are enabled.",
),
- IO.Int.Input("duration", default=5, min=3, max=10, display_mode=IO.NumberDisplay.slider),
+ IO.Int.Input("duration", default=5, min=3, max=15, display_mode=IO.NumberDisplay.slider),
IO.Image.Input("first_frame"),
IO.Image.Input(
"end_frame",
optional=True,
tooltip="An optional end frame for the video. "
- "This cannot be used simultaneously with 'reference_images'.",
+ "This cannot be used simultaneously with 'reference_images'. "
+ "Does not work with storyboards.",
),
IO.Image.Input(
"reference_images",
@@ -912,6 +1013,38 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
tooltip="Up to 6 additional reference images.",
),
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
+ IO.DynamicCombo.Input(
+ "storyboards",
+ options=[
+ IO.DynamicCombo.Option("disabled", []),
+ IO.DynamicCombo.Option("1 storyboard", _generate_storyboard_inputs(1)),
+ IO.DynamicCombo.Option("2 storyboards", _generate_storyboard_inputs(2)),
+ IO.DynamicCombo.Option("3 storyboards", _generate_storyboard_inputs(3)),
+ IO.DynamicCombo.Option("4 storyboards", _generate_storyboard_inputs(4)),
+ IO.DynamicCombo.Option("5 storyboards", _generate_storyboard_inputs(5)),
+ IO.DynamicCombo.Option("6 storyboards", _generate_storyboard_inputs(6)),
+ ],
+ tooltip="Generate a series of video segments with individual prompts and durations. "
+ "Only supported for kling-v3-omni.",
+ optional=True,
+ ),
+ IO.Boolean.Input(
+ "generate_audio",
+ default=False,
+ optional=True,
+ tooltip="Generate audio for the video. Only supported for kling-v3-omni.",
+ ),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=2147483647,
+ display_mode=IO.NumberDisplay.number,
+ control_after_generate=True,
+ tooltip="Seed controls whether the node should re-run; "
+ "results are non-deterministic regardless of seed.",
+ optional=True,
+ ),
],
outputs=[
IO.Video.Output(),
@@ -923,11 +1056,15 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
- depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]),
+ depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution", "model_name", "generate_audio"]),
expr="""
(
$mode := (widgets.resolution = "720p") ? "std" : "pro";
- $rates := {"std": 0.084, "pro": 0.112};
+ $isV3 := $contains(widgets.model_name, "v3");
+ $audio := $isV3 and widgets.generate_audio;
+ $rates := $audio
+ ? {"std": 0.112, "pro": 0.14}
+ : {"std": 0.084, "pro": 0.112};
{"type":"usd","usd": $lookup($rates, $mode) * widgets.duration}
)
""",
@@ -944,15 +1081,59 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
end_frame: Input.Image | None = None,
reference_images: Input.Image | None = None,
resolution: str = "1080p",
+ storyboards: dict | None = None,
+ generate_audio: bool = False,
+ seed: int = 0,
) -> IO.NodeOutput:
+ _ = seed
+ if model_name == "kling-video-o1":
+ if duration > 10:
+ raise ValueError("kling-video-o1 does not support durations greater than 10 seconds.")
+ if generate_audio:
+ raise ValueError("kling-video-o1 does not support audio generation.")
+ stories_enabled = storyboards is not None and storyboards["storyboards"] != "disabled"
+ if stories_enabled and model_name == "kling-video-o1":
+ raise ValueError("kling-video-o1 does not support storyboards.")
prompt = normalize_omni_prompt_references(prompt)
- validate_string(prompt, min_length=1, max_length=2500)
+ validate_string(prompt, strip_whitespace=True, min_length=0 if stories_enabled else 1, max_length=2500)
if end_frame is not None and reference_images is not None:
raise ValueError("The 'end_frame' input cannot be used simultaneously with 'reference_images'.")
- if duration not in (5, 10) and end_frame is None and reference_images is None:
+ if end_frame is not None and stories_enabled:
+ raise ValueError("The 'end_frame' input cannot be used simultaneously with storyboards.")
+ if (
+ model_name == "kling-video-o1"
+ and duration not in (5, 10)
+ and end_frame is None
+ and reference_images is None
+ ):
raise ValueError(
"Duration is only supported for 5 or 10 seconds if there is no end frame or reference images."
)
+
+ multi_shot = None
+ multi_prompt_list = None
+ if stories_enabled:
+ count = int(storyboards["storyboards"].split()[0])
+ multi_shot = True
+ multi_prompt_list = []
+ for i in range(1, count + 1):
+ sb_prompt = storyboards[f"storyboard_{i}_prompt"]
+ sb_duration = storyboards[f"storyboard_{i}_duration"]
+ validate_string(sb_prompt, field_name=f"storyboard_{i}_prompt", min_length=1, max_length=512)
+ multi_prompt_list.append(
+ MultiPromptEntry(
+ index=i,
+ prompt=sb_prompt,
+ duration=str(sb_duration),
+ )
+ )
+ total_storyboard_duration = sum(int(e.duration) for e in multi_prompt_list)
+ if total_storyboard_duration != duration:
+ raise ValueError(
+ f"Total storyboard duration ({total_storyboard_duration}s) "
+ f"must equal the global duration ({duration}s)."
+ )
+
validate_image_dimensions(first_frame, min_width=300, min_height=300)
validate_image_aspect_ratio(first_frame, (1, 2.5), (2.5, 1))
image_list: list[OmniParamImage] = [
@@ -988,6 +1169,10 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
duration=str(duration),
image_list=image_list,
mode="pro" if resolution == "1080p" else "std",
+ sound="on" if generate_audio else "off",
+ multi_shot=multi_shot,
+ multi_prompt=multi_prompt_list,
+ shot_type="customize" if multi_shot else None,
),
)
return await finish_omni_video_task(cls, response)
@@ -999,24 +1184,57 @@ class OmniProImageToVideoNode(IO.ComfyNode):
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingOmniProImageToVideoNode",
- display_name="Kling Omni Image to Video (Pro)",
+ display_name="Kling 3.0 Omni Image to Video",
category="api node/video/Kling",
description="Use up to 7 reference images to generate a video with the latest Kling model.",
inputs=[
- IO.Combo.Input("model_name", options=["kling-video-o1"]),
+ IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),
IO.String.Input(
"prompt",
multiline=True,
tooltip="A text prompt describing the video content. "
- "This can include both positive and negative descriptions.",
+ "This can include both positive and negative descriptions. "
+ "Ignored when storyboards are enabled.",
),
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
- IO.Int.Input("duration", default=3, min=3, max=10, display_mode=IO.NumberDisplay.slider),
+ IO.Int.Input("duration", default=5, min=3, max=15, display_mode=IO.NumberDisplay.slider),
IO.Image.Input(
"reference_images",
tooltip="Up to 7 reference images.",
),
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
+ IO.DynamicCombo.Input(
+ "storyboards",
+ options=[
+ IO.DynamicCombo.Option("disabled", []),
+ IO.DynamicCombo.Option("1 storyboard", _generate_storyboard_inputs(1)),
+ IO.DynamicCombo.Option("2 storyboards", _generate_storyboard_inputs(2)),
+ IO.DynamicCombo.Option("3 storyboards", _generate_storyboard_inputs(3)),
+ IO.DynamicCombo.Option("4 storyboards", _generate_storyboard_inputs(4)),
+ IO.DynamicCombo.Option("5 storyboards", _generate_storyboard_inputs(5)),
+ IO.DynamicCombo.Option("6 storyboards", _generate_storyboard_inputs(6)),
+ ],
+ tooltip="Generate a series of video segments with individual prompts and durations. "
+ "Only supported for kling-v3-omni.",
+ optional=True,
+ ),
+ IO.Boolean.Input(
+ "generate_audio",
+ default=False,
+ optional=True,
+ tooltip="Generate audio for the video. Only supported for kling-v3-omni.",
+ ),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=2147483647,
+ display_mode=IO.NumberDisplay.number,
+ control_after_generate=True,
+ tooltip="Seed controls whether the node should re-run; "
+ "results are non-deterministic regardless of seed.",
+ optional=True,
+ ),
],
outputs=[
IO.Video.Output(),
@@ -1028,11 +1246,15 @@ class OmniProImageToVideoNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
- depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]),
+ depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution", "model_name", "generate_audio"]),
expr="""
(
$mode := (widgets.resolution = "720p") ? "std" : "pro";
- $rates := {"std": 0.084, "pro": 0.112};
+ $isV3 := $contains(widgets.model_name, "v3");
+ $audio := $isV3 and widgets.generate_audio;
+ $rates := $audio
+ ? {"std": 0.112, "pro": 0.14}
+ : {"std": 0.084, "pro": 0.112};
{"type":"usd","usd": $lookup($rates, $mode) * widgets.duration}
)
""",
@@ -1048,9 +1270,46 @@ class OmniProImageToVideoNode(IO.ComfyNode):
duration: int,
reference_images: Input.Image,
resolution: str = "1080p",
+ storyboards: dict | None = None,
+ generate_audio: bool = False,
+ seed: int = 0,
) -> IO.NodeOutput:
+ _ = seed
+ if model_name == "kling-video-o1":
+ if duration > 10:
+ raise ValueError("kling-video-o1 does not support durations greater than 10 seconds.")
+ if generate_audio:
+ raise ValueError("kling-video-o1 does not support audio generation.")
+ stories_enabled = storyboards is not None and storyboards["storyboards"] != "disabled"
+ if stories_enabled and model_name == "kling-video-o1":
+ raise ValueError("kling-video-o1 does not support storyboards.")
prompt = normalize_omni_prompt_references(prompt)
- validate_string(prompt, min_length=1, max_length=2500)
+ validate_string(prompt, strip_whitespace=True, min_length=0 if stories_enabled else 1, max_length=2500)
+
+ multi_shot = None
+ multi_prompt_list = None
+ if stories_enabled:
+ count = int(storyboards["storyboards"].split()[0])
+ multi_shot = True
+ multi_prompt_list = []
+ for i in range(1, count + 1):
+ sb_prompt = storyboards[f"storyboard_{i}_prompt"]
+ sb_duration = storyboards[f"storyboard_{i}_duration"]
+ validate_string(sb_prompt, field_name=f"storyboard_{i}_prompt", min_length=1, max_length=512)
+ multi_prompt_list.append(
+ MultiPromptEntry(
+ index=i,
+ prompt=sb_prompt,
+ duration=str(sb_duration),
+ )
+ )
+ total_storyboard_duration = sum(int(e.duration) for e in multi_prompt_list)
+ if total_storyboard_duration != duration:
+ raise ValueError(
+ f"Total storyboard duration ({total_storyboard_duration}s) "
+ f"must equal the global duration ({duration}s)."
+ )
+
if get_number_of_images(reference_images) > 7:
raise ValueError("The maximum number of reference images is 7.")
for i in reference_images:
@@ -1070,6 +1329,10 @@ class OmniProImageToVideoNode(IO.ComfyNode):
duration=str(duration),
image_list=image_list,
mode="pro" if resolution == "1080p" else "std",
+ sound="on" if generate_audio else "off",
+ multi_shot=multi_shot,
+ multi_prompt=multi_prompt_list,
+ shot_type="customize" if multi_shot else None,
),
)
return await finish_omni_video_task(cls, response)
@@ -1081,11 +1344,11 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingOmniProVideoToVideoNode",
- display_name="Kling Omni Video to Video (Pro)",
+ display_name="Kling 3.0 Omni Video to Video",
category="api node/video/Kling",
description="Use a video and up to 4 reference images to generate a video with the latest Kling model.",
inputs=[
- IO.Combo.Input("model_name", options=["kling-video-o1"]),
+ IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),
IO.String.Input(
"prompt",
multiline=True,
@@ -1102,6 +1365,17 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
optional=True,
),
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=2147483647,
+ display_mode=IO.NumberDisplay.number,
+ control_after_generate=True,
+ tooltip="Seed controls whether the node should re-run; "
+ "results are non-deterministic regardless of seed.",
+ optional=True,
+ ),
],
outputs=[
IO.Video.Output(),
@@ -1135,7 +1409,9 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
keep_original_sound: bool,
reference_images: Input.Image | None = None,
resolution: str = "1080p",
+ seed: int = 0,
) -> IO.NodeOutput:
+ _ = seed
prompt = normalize_omni_prompt_references(prompt)
validate_string(prompt, min_length=1, max_length=2500)
validate_video_duration(reference_video, min_duration=3.0, max_duration=10.05)
@@ -1179,11 +1455,11 @@ class OmniProEditVideoNode(IO.ComfyNode):
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingOmniProEditVideoNode",
- display_name="Kling Omni Edit Video (Pro)",
+ display_name="Kling 3.0 Omni Edit Video",
category="api node/video/Kling",
description="Edit an existing video with the latest model from Kling.",
inputs=[
- IO.Combo.Input("model_name", options=["kling-video-o1"]),
+ IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),
IO.String.Input(
"prompt",
multiline=True,
@@ -1198,6 +1474,17 @@ class OmniProEditVideoNode(IO.ComfyNode):
optional=True,
),
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=2147483647,
+ display_mode=IO.NumberDisplay.number,
+ control_after_generate=True,
+ tooltip="Seed controls whether the node should re-run; "
+ "results are non-deterministic regardless of seed.",
+ optional=True,
+ ),
],
outputs=[
IO.Video.Output(),
@@ -1229,7 +1516,9 @@ class OmniProEditVideoNode(IO.ComfyNode):
keep_original_sound: bool,
reference_images: Input.Image | None = None,
resolution: str = "1080p",
+ seed: int = 0,
) -> IO.NodeOutput:
+ _ = seed
prompt = normalize_omni_prompt_references(prompt)
validate_string(prompt, min_length=1, max_length=2500)
validate_video_duration(video, min_duration=3.0, max_duration=10.05)
@@ -1273,27 +1562,43 @@ class OmniProImageNode(IO.ComfyNode):
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingOmniProImageNode",
- display_name="Kling Omni Image (Pro)",
+ display_name="Kling 3.0 Omni Image",
category="api node/image/Kling",
description="Create or edit images with the latest model from Kling.",
inputs=[
- IO.Combo.Input("model_name", options=["kling-image-o1"]),
+ IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-image-o1"]),
IO.String.Input(
"prompt",
multiline=True,
tooltip="A text prompt describing the image content. "
"This can include both positive and negative descriptions.",
),
- IO.Combo.Input("resolution", options=["1K", "2K"]),
+ IO.Combo.Input("resolution", options=["1K", "2K", "4K"]),
IO.Combo.Input(
"aspect_ratio",
options=["16:9", "9:16", "1:1", "4:3", "3:4", "3:2", "2:3", "21:9"],
),
+ IO.Combo.Input(
+ "series_amount",
+ options=["disabled", "2", "3", "4", "5", "6", "7", "8", "9"],
+ tooltip="Generate a series of images. Not supported for kling-image-o1.",
+ ),
IO.Image.Input(
"reference_images",
tooltip="Up to 10 additional reference images.",
optional=True,
),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=2147483647,
+ display_mode=IO.NumberDisplay.number,
+ control_after_generate=True,
+ tooltip="Seed controls whether the node should re-run; "
+ "results are non-deterministic regardless of seed.",
+ optional=True,
+ ),
],
outputs=[
IO.Image.Output(),
@@ -1305,7 +1610,16 @@ class OmniProImageNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
- expr="""{"type":"usd","usd":0.028}""",
+ depends_on=IO.PriceBadgeDepends(widgets=["resolution", "series_amount", "model_name"]),
+ expr="""
+ (
+ $prices := {"1k": 0.028, "2k": 0.028, "4k": 0.056};
+ $base := $lookup($prices, $lowercase(widgets.resolution));
+ $isO1 := widgets.model_name = "kling-image-o1";
+ $mult := ($isO1 or widgets.series_amount = "disabled") ? 1 : $number(widgets.series_amount);
+ {"type":"usd","usd": $base * $mult}
+ )
+ """,
),
)
@@ -1316,8 +1630,13 @@ class OmniProImageNode(IO.ComfyNode):
prompt: str,
resolution: str,
aspect_ratio: str,
+ series_amount: str = "disabled",
reference_images: Input.Image | None = None,
+ seed: int = 0,
) -> IO.NodeOutput:
+ _ = seed
+ if model_name == "kling-image-o1" and resolution == "4K":
+ raise ValueError("4K resolution is not supported for kling-image-o1 model.")
prompt = normalize_omni_prompt_references(prompt)
validate_string(prompt, min_length=1, max_length=2500)
image_list: list[OmniImageParamImage] = []
@@ -1329,6 +1648,9 @@ class OmniProImageNode(IO.ComfyNode):
validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"):
image_list.append(OmniImageParamImage(image=i))
+ use_series = series_amount != "disabled"
+ if use_series and model_name == "kling-image-o1":
+ raise ValueError("kling-image-o1 does not support series generation.")
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/images/omni-image", method="POST"),
@@ -1339,6 +1661,8 @@ class OmniProImageNode(IO.ComfyNode):
resolution=resolution.lower(),
aspect_ratio=aspect_ratio,
image_list=image_list if image_list else None,
+ result_type="series" if use_series else None,
+ series_amount=int(series_amount) if use_series else None,
),
)
if response.code:
@@ -1351,7 +1675,9 @@ class OmniProImageNode(IO.ComfyNode):
response_model=TaskStatusResponse,
status_extractor=lambda r: (r.data.task_status if r.data else None),
)
- return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.task_result.images[0].url))
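+ # Series results arrive in task_result.series_images; single-image results use task_result.images.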
+ images = final_response.data.task_result.series_images or final_response.data.task_result.images
+ tensors = [await download_url_to_image_tensor(img.url) for img in images]
+ return IO.NodeOutput(torch.cat(tensors, dim=0))
class KlingCameraControlT2VNode(IO.ComfyNode):
@@ -2119,7 +2445,7 @@ class KlingImageGenerationNode(IO.ComfyNode):
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingImageGenerationNode",
- display_name="Kling Image Generation",
+ display_name="Kling 3.0 Image",
category="api node/image/Kling",
description="Kling Image Generation Node. Generate an image from a text prompt with an optional reference image.",
inputs=[
@@ -2147,11 +2473,7 @@ class KlingImageGenerationNode(IO.ComfyNode):
display_mode=IO.NumberDisplay.slider,
tooltip="Subject reference similarity",
),
- IO.Combo.Input(
- "model_name",
- options=[i.value for i in KlingImageGenModelName],
- default="kling-v2",
- ),
+ IO.Combo.Input("model_name", options=["kling-v3", "kling-v2", "kling-v1-5"]),
IO.Combo.Input(
"aspect_ratio",
options=[i.value for i in KlingImageGenAspectRatio],
@@ -2165,6 +2487,17 @@ class KlingImageGenerationNode(IO.ComfyNode):
tooltip="Number of generated images",
),
IO.Image.Input("image", optional=True),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=2147483647,
+ display_mode=IO.NumberDisplay.number,
+ control_after_generate=True,
+ tooltip="Seed controls whether the node should re-run; "
+ "results are non-deterministic regardless of seed.",
+ optional=True,
+ ),
],
outputs=[
IO.Image.Output(),
@@ -2183,7 +2516,7 @@ class KlingImageGenerationNode(IO.ComfyNode):
$base :=
$contains($m,"kling-v1-5")
? (inputs.image.connected ? 0.028 : 0.014)
- : ($contains($m,"kling-v1") ? 0.0035 : 0.014);
+ : $contains($m,"kling-v3") ? 0.028 : 0.014;
{"type":"usd","usd": $base * widgets.n}
)
""",
@@ -2193,7 +2526,7 @@ class KlingImageGenerationNode(IO.ComfyNode):
@classmethod
async def execute(
cls,
- model_name: KlingImageGenModelName,
+ model_name: str,
prompt: str,
negative_prompt: str,
image_type: KlingImageGenImageReferenceType,
@@ -2202,17 +2535,11 @@ class KlingImageGenerationNode(IO.ComfyNode):
n: int,
aspect_ratio: KlingImageGenAspectRatio,
image: torch.Tensor | None = None,
+ seed: int = 0,
) -> IO.NodeOutput:
+ _ = seed
validate_string(prompt, field_name="prompt", min_length=1, max_length=MAX_PROMPT_LENGTH_IMAGE_GEN)
validate_string(negative_prompt, field_name="negative_prompt", max_length=MAX_PROMPT_LENGTH_IMAGE_GEN)
-
- if image is None:
- image_type = None
- elif model_name == KlingImageGenModelName.kling_v1:
- raise ValueError(f"The model {KlingImageGenModelName.kling_v1.value} does not support reference images.")
- else:
- image = tensor_to_base64_string(image)
-
task_creation_response = await sync_op(
cls,
ApiEndpoint(path=PATH_IMAGE_GENERATIONS, method="POST"),
@@ -2221,8 +2548,8 @@ class KlingImageGenerationNode(IO.ComfyNode):
model_name=model_name,
prompt=prompt,
negative_prompt=negative_prompt,
- image=image,
- image_reference=image_type,
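+ # Send the reference image inline as base64, and omit both fields when no image is connected.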
+ image=tensor_to_base64_string(image) if image is not None else None,
+ image_reference=image_type if image is not None else None,
image_fidelity=image_fidelity,
human_fidelity=human_fidelity,
n=n,
@@ -2252,7 +2579,7 @@ class TextToVideoWithAudio(IO.ComfyNode):
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingTextToVideoWithAudio",
- display_name="Kling Text to Video with Audio",
+ display_name="Kling 2.6 Text to Video with Audio",
category="api node/video/Kling",
inputs=[
IO.Combo.Input("model_name", options=["kling-v2-6"]),
@@ -2320,7 +2647,7 @@ class ImageToVideoWithAudio(IO.ComfyNode):
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingImageToVideoWithAudio",
- display_name="Kling Image(First Frame) to Video with Audio",
+ display_name="Kling 2.6 Image(First Frame) to Video with Audio",
category="api node/video/Kling",
inputs=[
IO.Combo.Input("model_name", options=["kling-v2-6"]),
@@ -2478,6 +2805,335 @@ class MotionControl(IO.ComfyNode):
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
+class KlingVideoNode(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls) -> IO.Schema:
+ return IO.Schema(
+ node_id="KlingVideoNode",
+ display_name="Kling 3.0 Video",
+ category="api node/video/Kling",
+ description="Generate videos with Kling V3. "
+ "Supports text-to-video and image-to-video with optional storyboard multi-prompt and audio generation.",
+ inputs=[
+ IO.DynamicCombo.Input(
+ "multi_shot",
+ options=[
+ IO.DynamicCombo.Option(
+ "disabled",
+ [
+ IO.String.Input("prompt", multiline=True, default=""),
+ IO.String.Input("negative_prompt", multiline=True, default=""),
+ IO.Int.Input(
+ "duration",
+ default=5,
+ min=3,
+ max=15,
+ display_mode=IO.NumberDisplay.slider,
+ ),
+ ],
+ ),
+ IO.DynamicCombo.Option("1 storyboard", _generate_storyboard_inputs(1)),
+ IO.DynamicCombo.Option("2 storyboards", _generate_storyboard_inputs(2)),
+ IO.DynamicCombo.Option("3 storyboards", _generate_storyboard_inputs(3)),
+ IO.DynamicCombo.Option("4 storyboards", _generate_storyboard_inputs(4)),
+ IO.DynamicCombo.Option("5 storyboards", _generate_storyboard_inputs(5)),
+ IO.DynamicCombo.Option("6 storyboards", _generate_storyboard_inputs(6)),
+ ],
+ tooltip="Generate a series of video segments with individual prompts and durations.",
+ ),
+ IO.Boolean.Input("generate_audio", default=True),
+ IO.DynamicCombo.Input(
+ "model",
+ options=[
+ IO.DynamicCombo.Option(
+ "kling-v3",
+ [
+ IO.Combo.Input("resolution", options=["1080p", "720p"]),
+ IO.Combo.Input(
+ "aspect_ratio",
+ options=["16:9", "9:16", "1:1"],
+ tooltip="Ignored in image-to-video mode.",
+ ),
+ ],
+ ),
+ ],
+ tooltip="Model and generation settings.",
+ ),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=2147483647,
+ display_mode=IO.NumberDisplay.number,
+ control_after_generate=True,
+ tooltip="Seed controls whether the node should re-run; "
+ "results are non-deterministic regardless of seed.",
+ ),
+ IO.Image.Input(
+ "start_frame",
+ optional=True,
+ tooltip="Optional start frame image. When connected, switches to image-to-video mode.",
+ ),
+ ],
+ outputs=[
+ IO.Video.Output(),
+ ],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
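+ # Pricing: per-second rate keyed by resolution and audio; with storyboards, the billed duration is the sum of segment durations.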
+ price_badge=IO.PriceBadge(
+ depends_on=IO.PriceBadgeDepends(
+ widgets=[
+ "model.resolution",
+ "generate_audio",
+ "multi_shot",
+ "multi_shot.duration",
+ "multi_shot.storyboard_1_duration",
+ "multi_shot.storyboard_2_duration",
+ "multi_shot.storyboard_3_duration",
+ "multi_shot.storyboard_4_duration",
+ "multi_shot.storyboard_5_duration",
+ "multi_shot.storyboard_6_duration",
+ ],
+ ),
+ expr="""
+ (
+ $rates := {"1080p": {"off": 0.112, "on": 0.168}, "720p": {"off": 0.084, "on": 0.126}};
+ $res := $lookup(widgets, "model.resolution");
+ $audio := widgets.generate_audio ? "on" : "off";
+ $rate := $lookup($lookup($rates, $res), $audio);
+ $ms := widgets.multi_shot;
+ $isSb := $ms != "disabled";
+ $n := $isSb ? $number($substring($ms, 0, 1)) : 0;
+ $d1 := $lookup(widgets, "multi_shot.storyboard_1_duration");
+ $d2 := $n >= 2 ? $lookup(widgets, "multi_shot.storyboard_2_duration") : 0;
+ $d3 := $n >= 3 ? $lookup(widgets, "multi_shot.storyboard_3_duration") : 0;
+ $d4 := $n >= 4 ? $lookup(widgets, "multi_shot.storyboard_4_duration") : 0;
+ $d5 := $n >= 5 ? $lookup(widgets, "multi_shot.storyboard_5_duration") : 0;
+ $d6 := $n >= 6 ? $lookup(widgets, "multi_shot.storyboard_6_duration") : 0;
+ $dur := $isSb ? $d1 + $d2 + $d3 + $d4 + $d5 + $d6 : $lookup(widgets, "multi_shot.duration");
+ {"type":"usd","usd": $rate * $dur}
+ )
+ """,
+ ),
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ multi_shot: dict,
+ generate_audio: bool,
+ model: dict,
+ seed: int,
+ start_frame: Input.Image | None = None,
+ ) -> IO.NodeOutput:
+ _ = seed
+ mode = "pro" if model["resolution"] == "1080p" else "std"
+ custom_multi_shot = False
+ if multi_shot["multi_shot"] == "disabled":
+ shot_type = None
+ else:
+ shot_type = "customize"
+ custom_multi_shot = True
+
+ multi_prompt_list = None
+ if shot_type == "customize":
+ count = int(multi_shot["multi_shot"].split()[0])
+ multi_prompt_list = []
+ for i in range(1, count + 1):
+ sb_prompt = multi_shot[f"storyboard_{i}_prompt"]
+ sb_duration = multi_shot[f"storyboard_{i}_duration"]
+ validate_string(sb_prompt, field_name=f"storyboard_{i}_prompt", min_length=1, max_length=512)
+ multi_prompt_list.append(
+ MultiPromptEntry(
+ index=i,
+ prompt=sb_prompt,
+ duration=str(sb_duration),
+ )
+ )
+ duration = sum(int(e.duration) for e in multi_prompt_list)
+ if duration < 3 or duration > 15:
+ raise ValueError(
+ f"Total storyboard duration ({duration}s) must be between 3 and 15 seconds."
+ )
+ else:
+ duration = multi_shot["duration"]
+ validate_string(multi_shot["prompt"], min_length=1, max_length=2500)
+
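+ # A connected start frame switches the request to the image2video endpoint; otherwise text2video is used.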
+ if start_frame is not None:
+ validate_image_dimensions(start_frame, min_width=300, min_height=300)
+ validate_image_aspect_ratio(start_frame, (1, 2.5), (2.5, 1))
+ image_url = await upload_image_to_comfyapi(cls, start_frame, wait_label="Uploading start frame")
+ response = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/kling/v1/videos/image2video", method="POST"),
+ response_model=TaskStatusResponse,
+ data=ImageToVideoWithAudioRequest(
+ model_name=model["model"],
+ image=image_url,
+ prompt=None if custom_multi_shot else multi_shot["prompt"],
+ negative_prompt=None if custom_multi_shot else multi_shot["negative_prompt"],
+ mode=mode,
+ duration=str(duration),
+ sound="on" if generate_audio else "off",
+ multi_shot=True if shot_type else None,
+ multi_prompt=multi_prompt_list,
+ shot_type=shot_type,
+ ),
+ )
+ else:
+ response = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/kling/v1/videos/text2video", method="POST"),
+ response_model=TaskStatusResponse,
+ data=TextToVideoWithAudioRequest(
+ model_name=model["model"],
+ aspect_ratio=model["aspect_ratio"],
+ prompt=None if custom_multi_shot else multi_shot["prompt"],
+ negative_prompt=None if custom_multi_shot else multi_shot["negative_prompt"],
+ mode=mode,
+ duration=str(duration),
+ sound="on" if generate_audio else "off",
+ multi_shot=True if shot_type else None,
+ multi_prompt=multi_prompt_list,
+ shot_type=shot_type,
+ ),
+ )
+
+ # Check for an API error before dereferencing response.data, which may be None on failure.
+ if response.code:
+ raise RuntimeError(
+ f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
+ )
+ endpoint = "image2video" if start_frame is not None else "text2video"
+ poll_path = f"/proxy/kling/v1/videos/{endpoint}/{response.data.task_id}"
+ final_response = await poll_op(
+ cls,
+ ApiEndpoint(path=poll_path),
+ response_model=TaskStatusResponse,
+ status_extractor=lambda r: (r.data.task_status if r.data else None),
+ )
+ return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
+
+
+class KlingFirstLastFrameNode(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls) -> IO.Schema:
+ return IO.Schema(
+ node_id="KlingFirstLastFrameNode",
+ display_name="Kling 3.0 First-Last-Frame to Video",
+ category="api node/video/Kling",
+ description="Generate videos with Kling V3 using first and last frames.",
+ inputs=[
+ IO.String.Input("prompt", multiline=True, default=""),
+ IO.Int.Input(
+ "duration",
+ default=5,
+ min=3,
+ max=15,
+ display_mode=IO.NumberDisplay.slider,
+ ),
+ IO.Image.Input("first_frame"),
+ IO.Image.Input("end_frame"),
+ IO.Boolean.Input("generate_audio", default=True),
+ IO.DynamicCombo.Input(
+ "model",
+ options=[
+ IO.DynamicCombo.Option(
+ "kling-v3",
+ [
+ IO.Combo.Input("resolution", options=["1080p", "720p"]),
+ ],
+ ),
+ ],
+ tooltip="Model and generation settings.",
+ ),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=2147483647,
+ display_mode=IO.NumberDisplay.number,
+ control_after_generate=True,
+ tooltip="Seed controls whether the node should re-run; "
+ "results are non-deterministic regardless of seed.",
+ ),
+ ],
+ outputs=[
+ IO.Video.Output(),
+ ],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ price_badge=IO.PriceBadge(
+ depends_on=IO.PriceBadgeDepends(
+ widgets=["model.resolution", "generate_audio", "duration"],
+ ),
+ expr="""
+ (
+ $rates := {"1080p": {"off": 0.112, "on": 0.168}, "720p": {"off": 0.084, "on": 0.126}};
+ $res := $lookup(widgets, "model.resolution");
+ $audio := widgets.generate_audio ? "on" : "off";
+ $rate := $lookup($lookup($rates, $res), $audio);
+ {"type":"usd","usd": $rate * widgets.duration}
+ )
+ """,
+ ),
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ prompt: str,
+ duration: int,
+ first_frame: Input.Image,
+ end_frame: Input.Image,
+ generate_audio: bool,
+ model: dict,
+ seed: int,
+ ) -> IO.NodeOutput:
+ _ = seed
+ validate_string(prompt, min_length=1, max_length=2500)
+ validate_image_dimensions(first_frame, min_width=300, min_height=300)
+ validate_image_aspect_ratio(first_frame, (1, 2.5), (2.5, 1))
+ validate_image_dimensions(end_frame, min_width=300, min_height=300)
+ validate_image_aspect_ratio(end_frame, (1, 2.5), (2.5, 1))
+ image_url = await upload_image_to_comfyapi(cls, first_frame, wait_label="Uploading first frame")
+ image_tail_url = await upload_image_to_comfyapi(cls, end_frame, wait_label="Uploading end frame")
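+ # Supplying image_tail alongside image gives the first/last-frame behavior on the image2video endpoint.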
+ response = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/kling/v1/videos/image2video", method="POST"),
+ response_model=TaskStatusResponse,
+ data=ImageToVideoWithAudioRequest(
+ model_name=model["model"],
+ image=image_url,
+ image_tail=image_tail_url,
+ prompt=prompt,
+ mode="pro" if model["resolution"] == "1080p" else "std",
+ duration=str(duration),
+ sound="on" if generate_audio else "off",
+ ),
+ )
+ if response.code:
+ raise RuntimeError(
+ f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
+ )
+ final_response = await poll_op(
+ cls,
+ ApiEndpoint(path=f"/proxy/kling/v1/videos/image2video/{response.data.task_id}"),
+ response_model=TaskStatusResponse,
+ status_extractor=lambda r: (r.data.task_status if r.data else None),
+ )
+ return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
+
+
class KlingExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@@ -2504,6 +3160,8 @@ class KlingExtension(ComfyExtension):
TextToVideoWithAudio,
ImageToVideoWithAudio,
MotionControl,
+ KlingVideoNode,
+ KlingFirstLastFrameNode,
]
diff --git a/comfy_api_nodes/nodes_magnific.py b/comfy_api_nodes/nodes_magnific.py
index 013e71cc8..83a581c5d 100644
--- a/comfy_api_nodes/nodes_magnific.py
+++ b/comfy_api_nodes/nodes_magnific.py
@@ -30,6 +30,30 @@ from comfy_api_nodes.util import (
validate_image_dimensions,
)
+import math  # needed for math.log2 below; a duplicate import is harmless if math is already imported
+
+_EUR_TO_USD = 1.19  # approximate EUR-to-USD rate baked into these price estimates
+
+
+def _tier_price_eur(megapixels: float) -> float:
+ """Price in EUR for a single Magnific upscaling step based on input megapixels."""
+ if megapixels <= 1.3:
+ return 0.143
+ if megapixels <= 3.0:
+ return 0.286
+ if megapixels <= 6.4:
+ return 0.429
+ return 1.716
+
+
+def _calculate_magnific_upscale_price_usd(width: int, height: int, scale: int) -> float:
+ """Calculate total Magnific upscale price in USD for given input dimensions and scale factor."""
+ num_steps = int(math.log2(scale))
+ total_eur = 0.0
+ pixels = width * height
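+ # Each 2x step quadruples the pixel count, so later steps can land in higher price tiers.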
+ for _ in range(num_steps):
+ total_eur += _tier_price_eur(pixels / 1_000_000)
+ pixels *= 4
+ return round(total_eur * _EUR_TO_USD, 2)
+
class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
@classmethod
@@ -103,11 +127,20 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
- depends_on=IO.PriceBadgeDepends(widgets=["scale_factor"]),
+ depends_on=IO.PriceBadgeDepends(widgets=["scale_factor", "auto_downscale"]),
expr="""
(
- $max := widgets.scale_factor = "2x" ? 1.326 : 1.657;
- {"type": "range_usd", "min_usd": 0.11, "max_usd": $max}
+ $ad := widgets.auto_downscale;
+ $mins := $ad
+ ? {"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.515}
+ : {"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.844};
+ $maxs := {"2x": 0.515, "4x": 0.844, "8x": 1.015, "16x": 1.187};
+ {
+ "type": "range_usd",
+ "min_usd": $lookup($mins, widgets.scale_factor),
+ "max_usd": $lookup($maxs, widgets.scale_factor),
+ "format": { "approximate": true }
+ }
)
""",
),
@@ -168,6 +201,10 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
f"Use a smaller input image or lower scale factor."
)
+ final_height, final_width = get_image_dimensions(image)
+ actual_scale = int(scale_factor.rstrip("x"))
+ price_usd = _calculate_magnific_upscale_price_usd(final_width, final_height, actual_scale)
+
initial_res = await sync_op(
cls,
ApiEndpoint(path="/proxy/freepik/v1/ai/image-upscaler", method="POST"),
@@ -189,6 +226,7 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-upscaler/{initial_res.task_id}"),
response_model=TaskResponse,
status_extractor=lambda x: x.status,
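+ # Report the locally estimated price instead of extracting one from the API response.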
+ price_extractor=lambda _: price_usd,
poll_interval=10.0,
max_poll_attempts=480,
)
@@ -257,8 +295,14 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
depends_on=IO.PriceBadgeDepends(widgets=["scale_factor"]),
expr="""
(
- $max := widgets.scale_factor = "2x" ? 1.326 : 1.657;
- {"type": "range_usd", "min_usd": 0.11, "max_usd": $max}
+ $mins := {"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.844};
+ $maxs := {"2x": 2.045, "4x": 2.545, "8x": 2.889, "16x": 3.06};
+ {
+ "type": "range_usd",
+ "min_usd": $lookup($mins, widgets.scale_factor),
+ "max_usd": $lookup($maxs, widgets.scale_factor),
+ "format": { "approximate": true }
+ }
)
""",
),
@@ -321,6 +365,9 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
f"Use a smaller input image or lower scale factor."
)
+ final_height, final_width = get_image_dimensions(image)
+ price_usd = _calculate_magnific_upscale_price_usd(final_width, final_height, requested_scale)
+
initial_res = await sync_op(
cls,
ApiEndpoint(path="/proxy/freepik/v1/ai/image-upscaler-precision-v2", method="POST"),
@@ -339,6 +386,7 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-upscaler-precision-v2/{initial_res.task_id}"),
response_model=TaskResponse,
status_extractor=lambda x: x.status,
+ price_extractor=lambda _: price_usd,
poll_interval=10.0,
max_poll_attempts=480,
)
@@ -877,8 +925,8 @@ class MagnificExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
- # MagnificImageUpscalerCreativeNode,
- # MagnificImageUpscalerPreciseV2Node,
+ MagnificImageUpscalerCreativeNode,
+ MagnificImageUpscalerPreciseV2Node,
MagnificImageStyleTransferNode,
MagnificImageRelightNode,
MagnificImageSkinEnhancerNode,
diff --git a/comfy_api_nodes/nodes_moonvalley.py b/comfy_api_nodes/nodes_moonvalley.py
index 08315fa2b..78a230529 100644
--- a/comfy_api_nodes/nodes_moonvalley.py
+++ b/comfy_api_nodes/nodes_moonvalley.py
@@ -219,8 +219,8 @@ class MoonvalleyImg2VideoNode(IO.ComfyNode):
),
IO.Int.Input(
"steps",
- default=33,
- min=1,
+ default=80,
+ min=75,  # steps must be >= cooldown_steps(75) + warmup_steps(0)
max=100,
step=1,
tooltip="Number of denoising steps",
@@ -340,8 +340,8 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode):
),
IO.Int.Input(
"steps",
- default=33,
- min=1,
+ default=60,
+ min=60,  # steps must be >= cooldown_steps(36) + warmup_steps(24)
max=100,
step=1,
display_mode=IO.NumberDisplay.number,
@@ -370,7 +370,7 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode):
video: Input.Video | None = None,
control_type: str = "Motion Transfer",
motion_intensity: int | None = 100,
- steps=33,
+ steps=60,
prompt_adherence=4.5,
) -> IO.NodeOutput:
validated_video = validate_video_to_video_input(video)
@@ -465,8 +465,8 @@ class MoonvalleyTxt2VideoNode(IO.ComfyNode):
),
IO.Int.Input(
"steps",
- default=33,
- min=1,
+ default=80,
+ min=75, # steps must be greater than or equal to cooldown_steps(75) + warmup_steps(0)
max=100,
step=1,
tooltip="Inference steps",
diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py
index f05aaab7b..332107a82 100644
--- a/comfy_api_nodes/nodes_openai.py
+++ b/comfy_api_nodes/nodes_openai.py
@@ -43,7 +43,6 @@ class SupportedOpenAIModel(str, Enum):
o1 = "o1"
o3 = "o3"
o1_pro = "o1-pro"
- gpt_4o = "gpt-4o"
gpt_4_1 = "gpt-4.1"
gpt_4_1_mini = "gpt-4.1-mini"
gpt_4_1_nano = "gpt-4.1-nano"
@@ -649,11 +648,6 @@ class OpenAIChatNode(IO.ComfyNode):
"usd": [0.01, 0.04],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
- : $contains($m, "gpt-4o") ? {
- "type": "list_usd",
- "usd": [0.0025, 0.01],
- "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
- }
: $contains($m, "gpt-4.1-nano") ? {
"type": "list_usd",
"usd": [0.0001, 0.0004],
diff --git a/comfy_api_nodes/util/__init__.py b/comfy_api_nodes/util/__init__.py
index 18b020eef..f8a0ba8af 100644
--- a/comfy_api_nodes/util/__init__.py
+++ b/comfy_api_nodes/util/__init__.py
@@ -33,6 +33,7 @@ from .download_helpers import (
download_url_to_video_output,
)
from .upload_helpers import (
+ upload_3d_model_to_comfyapi,
upload_audio_to_comfyapi,
upload_file_to_comfyapi,
upload_image_to_comfyapi,
@@ -62,6 +63,7 @@ __all__ = [
"sync_op",
"sync_op_raw",
# Upload helpers
+ "upload_3d_model_to_comfyapi",
"upload_audio_to_comfyapi",
"upload_file_to_comfyapi",
"upload_image_to_comfyapi",
diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py
index 8a1259506..94886af7b 100644
--- a/comfy_api_nodes/util/client.py
+++ b/comfy_api_nodes/util/client.py
@@ -57,6 +57,7 @@ class _RequestConfig:
files: dict[str, Any] | list[tuple[str, Any]] | None
multipart_parser: Callable | None
max_retries: int
+ max_retries_on_rate_limit: int
retry_delay: float
retry_backoff: float
wait_label: str = "Waiting"
@@ -65,6 +66,7 @@ class _RequestConfig:
final_label_on_success: str | None = "Completed"
progress_origin_ts: float | None = None
price_extractor: Callable[[dict[str, Any]], float | None] | None = None
+ is_rate_limited: Callable[[int, Any], bool] | None = None
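+ # Optional provider-specific rate-limit check; illustrative example:
+ # is_rate_limited=lambda status, body: status == 503 and "rate" in str(body).lower()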
@dataclass
@@ -78,7 +80,7 @@ class _PollUIState:
active_since: float | None = None # start time of current active interval (None if queued)
-_RETRY_STATUS = {408, 429, 500, 502, 503, 504}
+_RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"]
FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"]
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing"]
@@ -103,6 +105,8 @@ async def sync_op(
final_label_on_success: str | None = "Completed",
progress_origin_ts: float | None = None,
monitor_progress: bool = True,
+ max_retries_on_rate_limit: int = 16,
+ is_rate_limited: Callable[[int, Any], bool] | None = None,
) -> M:
raw = await sync_op_raw(
cls,
@@ -122,6 +126,8 @@ async def sync_op(
final_label_on_success=final_label_on_success,
progress_origin_ts=progress_origin_ts,
monitor_progress=monitor_progress,
+ max_retries_on_rate_limit=max_retries_on_rate_limit,
+ is_rate_limited=is_rate_limited,
)
if not isinstance(raw, dict):
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
@@ -143,9 +149,9 @@ async def poll_op(
poll_interval: float = 5.0,
max_poll_attempts: int = 160,
timeout_per_poll: float = 120.0,
- max_retries_per_poll: int = 3,
+ max_retries_per_poll: int = 10,
retry_delay_per_poll: float = 1.0,
- retry_backoff_per_poll: float = 2.0,
+ retry_backoff_per_poll: float = 1.4,
estimated_duration: int | None = None,
cancel_endpoint: ApiEndpoint | None = None,
cancel_timeout: float = 10.0,
@@ -194,6 +200,8 @@ async def sync_op_raw(
final_label_on_success: str | None = "Completed",
progress_origin_ts: float | None = None,
monitor_progress: bool = True,
+ max_retries_on_rate_limit: int = 16,
+ is_rate_limited: Callable[[int, Any], bool] | None = None,
) -> dict[str, Any] | bytes:
"""
Make a single network request.
@@ -222,6 +230,8 @@ async def sync_op_raw(
final_label_on_success=final_label_on_success,
progress_origin_ts=progress_origin_ts,
price_extractor=price_extractor,
+ max_retries_on_rate_limit=max_retries_on_rate_limit,
+ is_rate_limited=is_rate_limited,
)
return await _request_base(cfg, expect_binary=as_binary)
@@ -240,9 +250,9 @@ async def poll_op_raw(
poll_interval: float = 5.0,
max_poll_attempts: int = 160,
timeout_per_poll: float = 120.0,
- max_retries_per_poll: int = 3,
+ max_retries_per_poll: int = 10,
retry_delay_per_poll: float = 1.0,
- retry_backoff_per_poll: float = 2.0,
+ retry_backoff_per_poll: float = 1.4,
estimated_duration: int | None = None,
cancel_endpoint: ApiEndpoint | None = None,
cancel_timeout: float = 10.0,
@@ -506,7 +516,7 @@ def _friendly_http_message(status: int, body: Any) -> str:
if status == 409:
return "There is a problem with your account. Please contact support@comfy.org."
if status == 429:
- return "Rate Limit Exceeded: Please try again later."
+ return "Rate Limit Exceeded: The server returned 429 after all retry attempts. Please wait and try again."
try:
if isinstance(body, dict):
err = body.get("error")
@@ -586,6 +596,8 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
start_time = cfg.progress_origin_ts if cfg.progress_origin_ts is not None else time.monotonic()
attempt = 0
delay = cfg.retry_delay
+ rate_limit_attempts = 0
+ rate_limit_delay = cfg.retry_delay
operation_succeeded: bool = False
final_elapsed_seconds: int | None = None
extracted_price: float | None = None
@@ -653,17 +665,14 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
payload_headers["Content-Type"] = "application/json"
payload_kw["json"] = cfg.data or {}
- try:
- request_logger.log_request_response(
- operation_id=operation_id,
- request_method=method,
- request_url=url,
- request_headers=dict(payload_headers) if payload_headers else None,
- request_params=dict(params) if params else None,
- request_data=request_body_log,
- )
- except Exception as _log_e:
- logging.debug("[DEBUG] request logging failed: %s", _log_e)
+ request_logger.log_request_response(
+ operation_id=operation_id,
+ request_method=method,
+ request_url=url,
+ request_headers=dict(payload_headers) if payload_headers else None,
+ request_params=dict(params) if params else None,
+ request_data=request_body_log,
+ )
req_coro = sess.request(method, url, params=params, **payload_kw)
req_task = asyncio.create_task(req_coro)
@@ -688,41 +697,33 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
body = await resp.json()
except (ContentTypeError, json.JSONDecodeError):
body = await resp.text()
- if resp.status in _RETRY_STATUS and attempt <= cfg.max_retries:
+ should_retry = False
+ wait_time = 0.0
+ retry_label = ""
+ is_rl = resp.status == 429 or (
+ cfg.is_rate_limited is not None and cfg.is_rate_limited(resp.status, body)
+ )
+ if is_rl and rate_limit_attempts < cfg.max_retries_on_rate_limit:
+ rate_limit_attempts += 1
+ wait_time = min(rate_limit_delay, 30.0)
+ rate_limit_delay *= cfg.retry_backoff
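+ # waits grow geometrically (retry_delay * retry_backoff**n), capped at 30s per attempt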
+ retry_label = f"rate-limit retry {rate_limit_attempts} of {cfg.max_retries_on_rate_limit}"
+ should_retry = True
+ elif resp.status in _RETRY_STATUS and (attempt - rate_limit_attempts) <= cfg.max_retries:
+ wait_time = delay
+ delay *= cfg.retry_backoff
+ retry_label = f"retry {attempt - rate_limit_attempts} of {cfg.max_retries}"
+ should_retry = True
+
+ if should_retry:
logging.warning(
- "HTTP %s %s -> %s. Retrying in %.2fs (retry %d of %d).",
+ "HTTP %s %s -> %s. Waiting %.2fs (%s).",
method,
url,
resp.status,
- delay,
- attempt,
- cfg.max_retries,
+ wait_time,
+ retry_label,
)
- try:
- request_logger.log_request_response(
- operation_id=operation_id,
- request_method=method,
- request_url=url,
- response_status_code=resp.status,
- response_headers=dict(resp.headers),
- response_content=body,
- error_message=_friendly_http_message(resp.status, body),
- )
- except Exception as _log_e:
- logging.debug("[DEBUG] response logging failed: %s", _log_e)
-
- await sleep_with_interrupt(
- delay,
- cfg.node_cls,
- cfg.wait_label if cfg.monitor_progress else None,
- start_time if cfg.monitor_progress else None,
- cfg.estimated_total,
- display_callback=_display_time_progress if cfg.monitor_progress else None,
- )
- delay *= cfg.retry_backoff
- continue
- msg = _friendly_http_message(resp.status, body)
- try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
@@ -730,10 +731,27 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=body,
- error_message=msg,
+ error_message=f"HTTP {resp.status} ({retry_label}, will retry in {wait_time:.1f}s)",
)
- except Exception as _log_e:
- logging.debug("[DEBUG] response logging failed: %s", _log_e)
+ await sleep_with_interrupt(
+ wait_time,
+ cfg.node_cls,
+ cfg.wait_label if cfg.monitor_progress else None,
+ start_time if cfg.monitor_progress else None,
+ cfg.estimated_total,
+ display_callback=_display_time_progress if cfg.monitor_progress else None,
+ )
+ continue
+ msg = _friendly_http_message(resp.status, body)
+ request_logger.log_request_response(
+ operation_id=operation_id,
+ request_method=method,
+ request_url=url,
+ response_status_code=resp.status,
+ response_headers=dict(resp.headers),
+ response_content=body,
+ error_message=msg,
+ )
raise Exception(msg)
if expect_binary:
@@ -753,17 +771,14 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
bytes_payload = bytes(buff)
operation_succeeded = True
final_elapsed_seconds = int(time.monotonic() - start_time)
- try:
- request_logger.log_request_response(
- operation_id=operation_id,
- request_method=method,
- request_url=url,
- response_status_code=resp.status,
- response_headers=dict(resp.headers),
- response_content=bytes_payload,
- )
- except Exception as _log_e:
- logging.debug("[DEBUG] response logging failed: %s", _log_e)
+ request_logger.log_request_response(
+ operation_id=operation_id,
+ request_method=method,
+ request_url=url,
+ response_status_code=resp.status,
+ response_headers=dict(resp.headers),
+ response_content=bytes_payload,
+ )
return bytes_payload
else:
try:
@@ -780,45 +795,39 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
extracted_price = cfg.price_extractor(payload) if cfg.price_extractor else None
operation_succeeded = True
final_elapsed_seconds = int(time.monotonic() - start_time)
- try:
- request_logger.log_request_response(
- operation_id=operation_id,
- request_method=method,
- request_url=url,
- response_status_code=resp.status,
- response_headers=dict(resp.headers),
- response_content=response_content_to_log,
- )
- except Exception as _log_e:
- logging.debug("[DEBUG] response logging failed: %s", _log_e)
+ request_logger.log_request_response(
+ operation_id=operation_id,
+ request_method=method,
+ request_url=url,
+ response_status_code=resp.status,
+ response_headers=dict(resp.headers),
+ response_content=response_content_to_log,
+ )
return payload
except ProcessingInterrupted:
logging.debug("Polling was interrupted by user")
raise
except (ClientError, OSError) as e:
- if attempt <= cfg.max_retries:
+ if (attempt - rate_limit_attempts) <= cfg.max_retries:
logging.warning(
"Connection error calling %s %s. Retrying in %.2fs (%d/%d): %s",
method,
url,
delay,
- attempt,
+ attempt - rate_limit_attempts,
cfg.max_retries,
str(e),
)
- try:
- request_logger.log_request_response(
- operation_id=operation_id,
- request_method=method,
- request_url=url,
- request_headers=dict(payload_headers) if payload_headers else None,
- request_params=dict(params) if params else None,
- request_data=request_body_log,
- error_message=f"{type(e).__name__}: {str(e)} (will retry)",
- )
- except Exception as _log_e:
- logging.debug("[DEBUG] request error logging failed: %s", _log_e)
+ request_logger.log_request_response(
+ operation_id=operation_id,
+ request_method=method,
+ request_url=url,
+ request_headers=dict(payload_headers) if payload_headers else None,
+ request_params=dict(params) if params else None,
+ request_data=request_body_log,
+ error_message=f"{type(e).__name__}: {str(e)} (will retry)",
+ )
await sleep_with_interrupt(
delay,
cfg.node_cls,
@@ -831,23 +840,6 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
continue
diag = await _diagnose_connectivity()
if not diag["internet_accessible"]:
- try:
- request_logger.log_request_response(
- operation_id=operation_id,
- request_method=method,
- request_url=url,
- request_headers=dict(payload_headers) if payload_headers else None,
- request_params=dict(params) if params else None,
- request_data=request_body_log,
- error_message=f"LocalNetworkError: {str(e)}",
- )
- except Exception as _log_e:
- logging.debug("[DEBUG] final error logging failed: %s", _log_e)
- raise LocalNetworkError(
- "Unable to connect to the API server due to local network issues. "
- "Please check your internet connection and try again."
- ) from e
- try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
@@ -855,10 +847,21 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
request_headers=dict(payload_headers) if payload_headers else None,
request_params=dict(params) if params else None,
request_data=request_body_log,
- error_message=f"ApiServerError: {str(e)}",
+ error_message=f"LocalNetworkError: {str(e)}",
)
- except Exception as _log_e:
- logging.debug("[DEBUG] final error logging failed: %s", _log_e)
+ raise LocalNetworkError(
+ "Unable to connect to the API server due to local network issues. "
+ "Please check your internet connection and try again."
+ ) from e
+ request_logger.log_request_response(
+ operation_id=operation_id,
+ request_method=method,
+ request_url=url,
+ request_headers=dict(payload_headers) if payload_headers else None,
+ request_params=dict(params) if params else None,
+ request_data=request_body_log,
+ error_message=f"ApiServerError: {str(e)}",
+ )
raise ApiServerError(
f"The API server at {default_base_url()} is currently unreachable. "
f"The service may be experiencing issues."
diff --git a/comfy_api_nodes/util/conversions.py b/comfy_api_nodes/util/conversions.py
index 3e37e8a8c..82b6d22a5 100644
--- a/comfy_api_nodes/util/conversions.py
+++ b/comfy_api_nodes/util/conversions.py
@@ -57,7 +57,7 @@ def tensor_to_bytesio(
image: torch.Tensor,
*,
total_pixels: int | None = 2048 * 2048,
- mime_type: str = "image/png",
+ mime_type: str | None = "image/png",
) -> BytesIO:
"""Converts a torch.Tensor image to a named BytesIO object.
diff --git a/comfy_api_nodes/util/download_helpers.py b/comfy_api_nodes/util/download_helpers.py
index 78bcf1fa1..aa588d038 100644
--- a/comfy_api_nodes/util/download_helpers.py
+++ b/comfy_api_nodes/util/download_helpers.py
@@ -167,27 +167,25 @@ async def download_url_to_bytesio(
with contextlib.suppress(Exception):
dest.seek(0)
- with contextlib.suppress(Exception):
- request_logger.log_request_response(
- operation_id=op_id,
- request_method="GET",
- request_url=url,
- response_status_code=resp.status,
- response_headers=dict(resp.headers),
- response_content=f"[streamed {written} bytes to dest]",
- )
+ request_logger.log_request_response(
+ operation_id=op_id,
+ request_method="GET",
+ request_url=url,
+ response_status_code=resp.status,
+ response_headers=dict(resp.headers),
+ response_content=f"[streamed {written} bytes to dest]",
+ )
return
except asyncio.CancelledError:
raise ProcessingInterrupted("Task cancelled") from None
except (ClientError, OSError) as e:
if attempt <= max_retries:
- with contextlib.suppress(Exception):
- request_logger.log_request_response(
- operation_id=op_id,
- request_method="GET",
- request_url=url,
- error_message=f"{type(e).__name__}: {str(e)} (will retry)",
- )
+ request_logger.log_request_response(
+ operation_id=op_id,
+ request_method="GET",
+ request_url=url,
+ error_message=f"{type(e).__name__}: {str(e)} (will retry)",
+ )
await sleep_with_interrupt(delay, cls, None, None, None)
delay *= retry_backoff
continue
diff --git a/comfy_api_nodes/util/request_logger.py b/comfy_api_nodes/util/request_logger.py
index e0cb4428d..fe0543d9b 100644
--- a/comfy_api_nodes/util/request_logger.py
+++ b/comfy_api_nodes/util/request_logger.py
@@ -8,7 +8,6 @@ from typing import Any
import folder_paths
-# Get the logger instance
logger = logging.getLogger(__name__)
@@ -91,38 +90,41 @@ def log_request_response(
Filenames are sanitized and length-limited for cross-platform safety.
If we still fail to write, we fall back to appending into api.log.
"""
- log_dir = get_log_directory()
- filepath = _build_log_filepath(log_dir, operation_id, request_url)
-
- log_content: list[str] = []
- log_content.append(f"Timestamp: {datetime.datetime.now().isoformat()}")
- log_content.append(f"Operation ID: {operation_id}")
- log_content.append("-" * 30 + " REQUEST " + "-" * 30)
- log_content.append(f"Method: {request_method}")
- log_content.append(f"URL: {request_url}")
- if request_headers:
- log_content.append(f"Headers:\n{_format_data_for_logging(request_headers)}")
- if request_params:
- log_content.append(f"Params:\n{_format_data_for_logging(request_params)}")
- if request_data is not None:
- log_content.append(f"Data/Body:\n{_format_data_for_logging(request_data)}")
-
- log_content.append("\n" + "-" * 30 + " RESPONSE " + "-" * 30)
- if response_status_code is not None:
- log_content.append(f"Status Code: {response_status_code}")
- if response_headers:
- log_content.append(f"Headers:\n{_format_data_for_logging(response_headers)}")
- if response_content is not None:
- log_content.append(f"Content:\n{_format_data_for_logging(response_content)}")
- if error_message:
- log_content.append(f"Error:\n{error_message}")
-
try:
- with open(filepath, "w", encoding="utf-8") as f:
- f.write("\n".join(log_content))
- logger.debug("API log saved to: %s", filepath)
- except Exception as e:
- logger.error("Error writing API log to %s: %s", filepath, str(e))
+ log_dir = get_log_directory()
+ filepath = _build_log_filepath(log_dir, operation_id, request_url)
+
+ log_content: list[str] = []
+ log_content.append(f"Timestamp: {datetime.datetime.now().isoformat()}")
+ log_content.append(f"Operation ID: {operation_id}")
+ log_content.append("-" * 30 + " REQUEST " + "-" * 30)
+ log_content.append(f"Method: {request_method}")
+ log_content.append(f"URL: {request_url}")
+ if request_headers:
+ log_content.append(f"Headers:\n{_format_data_for_logging(request_headers)}")
+ if request_params:
+ log_content.append(f"Params:\n{_format_data_for_logging(request_params)}")
+ if request_data is not None:
+ log_content.append(f"Data/Body:\n{_format_data_for_logging(request_data)}")
+
+ log_content.append("\n" + "-" * 30 + " RESPONSE " + "-" * 30)
+ if response_status_code is not None:
+ log_content.append(f"Status Code: {response_status_code}")
+ if response_headers:
+ log_content.append(f"Headers:\n{_format_data_for_logging(response_headers)}")
+ if response_content is not None:
+ log_content.append(f"Content:\n{_format_data_for_logging(response_content)}")
+ if error_message:
+ log_content.append(f"Error:\n{error_message}")
+
+ try:
+ with open(filepath, "w", encoding="utf-8") as f:
+ f.write("\n".join(log_content))
+ logger.debug("API log saved to: %s", filepath)
+ except Exception as e:
+ logger.error("Error writing API log to %s: %s", filepath, str(e))
+ except Exception as _log_e:
+ logging.debug("[DEBUG] log_request_response failed: %s", _log_e)
if __name__ == '__main__':
diff --git a/comfy_api_nodes/util/upload_helpers.py b/comfy_api_nodes/util/upload_helpers.py
index 83d936ce1..6d1d107a1 100644
--- a/comfy_api_nodes/util/upload_helpers.py
+++ b/comfy_api_nodes/util/upload_helpers.py
@@ -164,6 +164,27 @@ async def upload_video_to_comfyapi(
return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type, wait_label)
+_3D_MIME_TYPES = {
+ "glb": "model/gltf-binary",
+ "obj": "model/obj",
+ "fbx": "application/octet-stream",
+}
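+# Formats not listed above fall back to application/octet-stream (see .get() below).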
+
+
+async def upload_3d_model_to_comfyapi(
+ cls: type[IO.ComfyNode],
+ model_3d: Types.File3D,
+ file_format: str,
+) -> str:
+ """Uploads a 3D model file to ComfyUI API and returns its download URL."""
+ return await upload_file_to_comfyapi(
+ cls,
+ model_3d.get_data(),
+ f"{uuid.uuid4()}.{file_format}",
+ _3D_MIME_TYPES.get(file_format, "application/octet-stream"),
+ )
+
+
async def upload_file_to_comfyapi(
cls: type[IO.ComfyNode],
file_bytes_io: BytesIO,
@@ -255,17 +276,14 @@ async def upload_file(
monitor_task = asyncio.create_task(_monitor())
sess: aiohttp.ClientSession | None = None
try:
- try:
- request_logger.log_request_response(
- operation_id=operation_id,
- request_method="PUT",
- request_url=upload_url,
- request_headers=headers or None,
- request_params=None,
- request_data=f"[File data {len(data)} bytes]",
- )
- except Exception as e:
- logging.debug("[DEBUG] upload request logging failed: %s", e)
+ request_logger.log_request_response(
+ operation_id=operation_id,
+ request_method="PUT",
+ request_url=upload_url,
+ request_headers=headers or None,
+ request_params=None,
+ request_data=f"[File data {len(data)} bytes]",
+ )
sess = aiohttp.ClientSession(timeout=timeout)
req = sess.put(upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers)
@@ -311,31 +329,27 @@ async def upload_file(
delay *= retry_backoff
continue
raise Exception(f"Failed to upload (HTTP {resp.status}).")
- try:
- request_logger.log_request_response(
- operation_id=operation_id,
- request_method="PUT",
- request_url=upload_url,
- response_status_code=resp.status,
- response_headers=dict(resp.headers),
- response_content="File uploaded successfully.",
- )
- except Exception as e:
- logging.debug("[DEBUG] upload response logging failed: %s", e)
+ request_logger.log_request_response(
+ operation_id=operation_id,
+ request_method="PUT",
+ request_url=upload_url,
+ response_status_code=resp.status,
+ response_headers=dict(resp.headers),
+ response_content="File uploaded successfully.",
+ )
return
except asyncio.CancelledError:
raise ProcessingInterrupted("Task cancelled") from None
except (aiohttp.ClientError, OSError) as e:
if attempt <= max_retries:
- with contextlib.suppress(Exception):
- request_logger.log_request_response(
- operation_id=operation_id,
- request_method="PUT",
- request_url=upload_url,
- request_headers=headers or None,
- request_data=f"[File data {len(data)} bytes]",
- error_message=f"{type(e).__name__}: {str(e)} (will retry)",
- )
+ request_logger.log_request_response(
+ operation_id=operation_id,
+ request_method="PUT",
+ request_url=upload_url,
+ request_headers=headers or None,
+ request_data=f"[File data {len(data)} bytes]",
+ error_message=f"{type(e).__name__}: {str(e)} (will retry)",
+ )
await sleep_with_interrupt(
delay,
cls,
diff --git a/comfy_execution/jobs.py b/comfy_execution/jobs.py
index bf091a448..370014fb6 100644
--- a/comfy_execution/jobs.py
+++ b/comfy_execution/jobs.py
@@ -20,10 +20,60 @@ class JobStatus:
# Media types that can be previewed in the frontend
-PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio'})
+PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio', '3d'})
# 3D file extensions for preview fallback (no dedicated media_type exists)
-THREE_D_EXTENSIONS = frozenset({'.obj', '.fbx', '.gltf', '.glb'})
+THREE_D_EXTENSIONS = frozenset({'.obj', '.fbx', '.gltf', '.glb', '.usdz'})
+
+
+def has_3d_extension(filename: str) -> bool:
+ lower = filename.lower()
+ return any(lower.endswith(ext) for ext in THREE_D_EXTENSIONS)
+
+
+def normalize_output_item(item):
+ """Normalize a single output list item for the jobs API.
+
+ Returns the normalized item, or None to exclude it.
+ String items with 3D extensions become {filename, type, subfolder, mediaType} dicts.
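+
+ Example (illustrative):
+
+ >>> normalize_output_item('mesh.glb')
+ {'filename': 'mesh.glb', 'type': 'output', 'subfolder': '', 'mediaType': '3d'}
+ >>> normalize_output_item('data.json') is None
+ True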
+ """
+ if item is None:
+ return None
+ if isinstance(item, str):
+ if has_3d_extension(item):
+ return {'filename': item, 'type': 'output', 'subfolder': '', 'mediaType': '3d'}
+ return None
+ if isinstance(item, dict):
+ return item
+ return None
+
+
+def normalize_outputs(outputs: dict) -> dict:
+ """Normalize raw node outputs for the jobs API.
+
+ Transforms string 3D filenames into file output dicts and removes
+ None items. All other items (non-3D strings, dicts, etc.) are
+ preserved as-is.
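+
+ For example (illustrative), {'n1': {'result': ['m.glb', None]}} becomes
+ {'n1': {'result': [{'filename': 'm.glb', 'type': 'output',
+ 'subfolder': '', 'mediaType': '3d'}]}}.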
+ """
+ normalized = {}
+ for node_id, node_outputs in outputs.items():
+ if not isinstance(node_outputs, dict):
+ normalized[node_id] = node_outputs
+ continue
+ normalized_node = {}
+ for media_type, items in node_outputs.items():
+ if media_type == 'animated' or not isinstance(items, list):
+ normalized_node[media_type] = items
+ continue
+ normalized_items = []
+ for item in items:
+ if item is None:
+ continue
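+ # synthesize file dicts for 3D filename strings; keep everything else as-is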
+ norm = normalize_output_item(item)
+ normalized_items.append(norm if norm is not None else item)
+ normalized_node[media_type] = normalized_items
+ normalized[node_id] = normalized_node
+ return normalized
def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]:
@@ -45,9 +95,9 @@ def is_previewable(media_type: str, item: dict) -> bool:
Maintains backwards compatibility with existing logic.
Priority:
- 1. media_type is 'images', 'video', or 'audio'
+ 1. media_type is 'images', 'video', 'audio', or '3d'
2. format field starts with 'video/' or 'audio/'
- 3. filename has a 3D extension (.obj, .fbx, .gltf, .glb)
+ 3. filename has a 3D extension (.obj, .fbx, .gltf, .glb, .usdz)
"""
if media_type in PREVIEWABLE_MEDIA_TYPES:
return True
@@ -139,7 +189,7 @@ def normalize_history_item(prompt_id: str, history_item: dict, include_outputs:
})
if include_outputs:
- job['outputs'] = outputs
+ job['outputs'] = normalize_outputs(outputs)
job['execution_status'] = status_info
job['workflow'] = {
'prompt': prompt,
@@ -171,18 +221,23 @@ def get_outputs_summary(outputs: dict) -> tuple[int, Optional[dict]]:
continue
for item in items:
- count += 1
-
- if not isinstance(item, dict):
+ normalized = normalize_output_item(item)
+ if normalized is None:
continue
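+ # only dict items and synthesized 3D entries count toward outputs_count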
- if preview_output is None and is_previewable(media_type, item):
+ count += 1
+
+ if preview_output is not None:
+ continue
+
+ if isinstance(normalized, dict) and is_previewable(media_type, normalized):
enriched = {
- **item,
+ **normalized,
'nodeId': node_id,
- 'mediaType': media_type
}
- if item.get('type') == 'output':
+ if 'mediaType' not in normalized:
+ enriched['mediaType'] = media_type
+ if normalized.get('type') == 'output':
preview_output = enriched
elif fallback_preview is None:
fallback_preview = enriched
diff --git a/comfy_extras/nodes_ace.py b/comfy_extras/nodes_ace.py
index dde5bbd2a..9cf84ab4d 100644
--- a/comfy_extras/nodes_ace.py
+++ b/comfy_extras/nodes_ace.py
@@ -49,13 +49,14 @@ class TextEncodeAceStepAudio15(io.ComfyNode):
io.Float.Input("temperature", default=0.85, min=0.0, max=2.0, step=0.01, advanced=True),
io.Float.Input("top_p", default=0.9, min=0.0, max=2000.0, step=0.01, advanced=True),
io.Int.Input("top_k", default=0, min=0, max=100, advanced=True),
+ io.Float.Input("min_p", default=0.000, min=0.0, max=1.0, step=0.001, advanced=True),
],
outputs=[io.Conditioning.Output()],
)
@classmethod
- def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k) -> io.NodeOutput:
- tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k)
+ def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k, min_p) -> io.NodeOutput:
+ tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p)
conditioning = clip.encode_from_tokens_scheduled(tokens)
return io.NodeOutput(conditioning)
diff --git a/comfy_extras/nodes_lora_extract.py b/comfy_extras/nodes_lora_extract.py
index fb89e03f4..1542d0a88 100644
--- a/comfy_extras/nodes_lora_extract.py
+++ b/comfy_extras/nodes_lora_extract.py
@@ -7,6 +7,7 @@ import logging
from enum import Enum
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
+from tqdm.auto import trange
CLAMP_QUANTILE = 0.99
@@ -49,12 +50,22 @@ LORA_TYPES = {"standard": LORAType.STANDARD,
"full_diff": LORAType.FULL_DIFF}
def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora_type, bias_diff=False):
- comfy.model_management.load_models_gpu([model_diff], force_patch_weights=True)
+ comfy.model_management.load_models_gpu([model_diff])
sd = model_diff.model_state_dict(filter_prefix=prefix_model)
- for k in sd:
- if k.endswith(".weight"):
+ sd_keys = list(sd.keys())
+ for index in trange(len(sd_keys), unit="weight"):
+ k = sd_keys[index]
+ op_keys = k.rsplit('.', 1)
+ if len(op_keys) < 2 or op_keys[1] not in ["weight", "bias"] or (op_keys[1] == "bias" and not bias_diff):
+ continue
+ op = comfy.utils.get_attr(model_diff.model, op_keys[0])
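+ # cast-weight ops keep their patches unbaked, so materialize the patched
+ # weight on the load device instead of reading the raw state-dict entry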
+ if hasattr(op, "comfy_cast_weights") and not getattr(op, "comfy_patched_weights", False):
+ weight_diff = model_diff.patch_weight_to_device(k, model_diff.load_device, return_weight=True)
+ else:
weight_diff = sd[k]
+
+ if op_keys[1] == "weight":
if lora_type == LORAType.STANDARD:
if weight_diff.ndim < 2:
if bias_diff:
@@ -69,8 +80,8 @@ def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora
elif lora_type == LORAType.FULL_DIFF:
output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu()
- elif bias_diff and k.endswith(".bias"):
- output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().half().cpu()
+ elif bias_diff and op_keys[1] == "bias":
+ output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = weight_diff.contiguous().half().cpu()
return output_sd
class LoraSave(io.ComfyNode):
diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py
index a52a90e2c..66dac10b1 100644
--- a/comfy_extras/nodes_post_processing.py
+++ b/comfy_extras/nodes_post_processing.py
@@ -655,6 +655,7 @@ class BatchImagesMasksLatentsNode(io.ComfyNode):
batched = batch_masks(values)
return io.NodeOutput(batched)
+
class PostProcessingExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
diff --git a/comfy_extras/nodes_replacements.py b/comfy_extras/nodes_replacements.py
new file mode 100644
index 000000000..7684e854c
--- /dev/null
+++ b/comfy_extras/nodes_replacements.py
@@ -0,0 +1,103 @@
+from comfy_api.latest import ComfyExtension, io, ComfyAPI
+
+api = ComfyAPI()
+
+
+async def register_replacements():
+ """Register all built-in node replacements."""
+ await register_replacements_longeredge()
+ await register_replacements_batchimages()
+ await register_replacements_upscaleimage()
+ await register_replacements_controlnet()
+ await register_replacements_load3d()
+ await register_replacements_preview3d()
+ await register_replacements_svdimg2vid()
+ await register_replacements_conditioningavg()
+
+async def register_replacements_longeredge():
+ # No dynamic inputs here
+ await api.node_replacement.register(io.NodeReplace(
+ new_node_id="ImageScaleToMaxDimension",
+ old_node_id="ResizeImagesByLongerEdge",
+ old_widget_ids=["longer_edge"],
+ input_mapping=[
+ {"new_id": "image", "old_id": "images"},
+ {"new_id": "largest_size", "old_id": "longer_edge"},
+ {"new_id": "upscale_method", "set_value": "lanczos"},
+ ],
+ # exercises the frontend output_mapping code path; functionally a no-op here
+ output_mapping=[{"new_idx": 0, "old_idx": 0}],
+ ))
+
+async def register_replacements_batchimages():
+ # BatchImages node uses Autogrow
+ await api.node_replacement.register(io.NodeReplace(
+ new_node_id="BatchImagesNode",
+ old_node_id="ImageBatch",
+ input_mapping=[
+ {"new_id": "images.image0", "old_id": "image1"},
+ {"new_id": "images.image1", "old_id": "image2"},
+ ],
+ ))
+
+async def register_replacements_upscaleimage():
+ # ResizeImageMaskNode uses DynamicCombo
+ await api.node_replacement.register(io.NodeReplace(
+ new_node_id="ResizeImageMaskNode",
+ old_node_id="ImageScaleBy",
+ old_widget_ids=["upscale_method", "scale_by"],
+ input_mapping=[
+ {"new_id": "input", "old_id": "image"},
+ {"new_id": "resize_type", "set_value": "scale by multiplier"},
+ {"new_id": "resize_type.multiplier", "old_id": "scale_by"},
+ {"new_id": "scale_method", "old_id": "upscale_method"},
+ ],
+ ))
+
+async def register_replacements_controlnet():
+ # T2IAdapterLoader → ControlNetLoader
+ await api.node_replacement.register(io.NodeReplace(
+ new_node_id="ControlNetLoader",
+ old_node_id="T2IAdapterLoader",
+ input_mapping=[
+ {"new_id": "control_net_name", "old_id": "t2i_adapter_name"},
+ ],
+ ))
+
+async def register_replacements_load3d():
+ # Load3DAnimation merged into Load3D
+ await api.node_replacement.register(io.NodeReplace(
+ new_node_id="Load3D",
+ old_node_id="Load3DAnimation",
+ ))
+
+async def register_replacements_preview3d():
+ # Preview3DAnimation merged into Preview3D
+ await api.node_replacement.register(io.NodeReplace(
+ new_node_id="Preview3D",
+ old_node_id="Preview3DAnimation",
+ ))
+
+async def register_replacements_svdimg2vid():
+ # Typo fix: SDV → SVD
+ await api.node_replacement.register(io.NodeReplace(
+ new_node_id="SVD_img2vid_Conditioning",
+ old_node_id="SDV_img2vid_Conditioning",
+ ))
+
+async def register_replacements_conditioningavg():
+ # Typo fix: trailing space in node name
+ await api.node_replacement.register(io.NodeReplace(
+ new_node_id="ConditioningAverage",
+ old_node_id="ConditioningAverage ",
+ ))
+
+class NodeReplacementsExtension(ComfyExtension):
+ async def on_load(self) -> None:
+ await register_replacements()
+
+ async def get_node_list(self) -> list[type[io.ComfyNode]]:
+ return []
+
+async def comfy_entrypoint() -> NodeReplacementsExtension:
+ return NodeReplacementsExtension()
diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py
index 024a89391..aa2d88673 100644
--- a/comfy_extras/nodes_train.py
+++ b/comfy_extras/nodes_train.py
@@ -4,6 +4,7 @@ import os
import numpy as np
import safetensors
import torch
+import torch.nn as nn
import torch.utils.checkpoint
from tqdm.auto import trange
from PIL import Image, ImageDraw, ImageFont
@@ -27,6 +28,11 @@ class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
"""
CFGGuider with modifications for training specific logic
"""
+
+ def __init__(self, *args, offloading=False, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.offloading = offloading
+
def outer_sample(
self,
noise,
@@ -45,9 +51,11 @@ class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
noise.shape,
self.conds,
self.model_options,
- force_full_load=True, # mirror behavior in TrainLoraNode.execute() to keep model loaded
+ force_full_load=not self.offloading,
+ force_offload=self.offloading,
)
)
+ torch.cuda.empty_cache()
device = self.model_patcher.load_device
if denoise_mask is not None:
@@ -404,16 +412,97 @@ def find_all_highest_child_module_with_forward(
return result
-def patch(m):
+def find_modules_at_depth(
+ model: nn.Module, depth: int = 1, result=None, current_depth=0, name=None
+) -> list[nn.Module]:
+ """
+ Find modules at a specific depth level for gradient checkpointing.
+
+ Args:
+ model: The model to search
+ depth: Target depth level (1 = top-level blocks, 2 = their children, etc.)
+ result: Accumulator for results
+ current_depth: Current recursion depth
+ name: Current module name for logging
+
+ Returns:
+ List of modules at the target depth
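+
+ Example (illustrative):
+
+ >>> import torch.nn as nn
+ >>> mods = find_modules_at_depth(nn.Sequential(nn.Linear(4, 4), nn.ReLU()), depth=1)
+ >>> [type(m).__name__ for m in mods]
+ ['Linear', 'ReLU']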
+ """
+ if result is None:
+ result = []
+ name = name or "root"
+
+ # Skip container modules (they don't have meaningful forward)
+ is_container = isinstance(model, (nn.ModuleList, nn.Sequential, nn.ModuleDict))
+ has_forward = hasattr(model, "forward") and not is_container
+
+ if has_forward:
+ current_depth += 1
+ if current_depth == depth:
+ result.append(model)
+ logging.debug(f"Found module at depth {depth}: {name} ({model.__class__.__name__})")
+ return result
+
+ # Recurse into children
+ for next_name, child in model.named_children():
+ find_modules_at_depth(child, depth, result, current_depth, f"{name}.{next_name}")
+
+ return result
+
+
+class OffloadCheckpointFunction(torch.autograd.Function):
+ """
+ Gradient checkpointing that works with weight offloading.
+
+ Forward: no_grad -> compute -> weights can be freed
+ Backward: enable_grad -> recompute -> backward -> weights can be freed
+
+ For single input, single output modules (Linear, Conv*).
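+
+ Illustrative usage: y = OffloadCheckpointFunction.apply(x, module.forward)
+ behaves like y = module(x), but only x is saved for backward; the forward
+ pass is recomputed during the backward pass.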
+ """
+
+ @staticmethod
+ def forward(ctx, x: torch.Tensor, forward_fn):
+ ctx.save_for_backward(x)
+ ctx.forward_fn = forward_fn
+ with torch.no_grad():
+ return forward_fn(x)
+
+ @staticmethod
+ def backward(ctx, grad_out: torch.Tensor):
+ x, = ctx.saved_tensors
+ forward_fn = ctx.forward_fn
+
+ # Clear context early
+ ctx.forward_fn = None
+
+ with torch.enable_grad():
+ x_detached = x.detach().requires_grad_(True)
+ y = forward_fn(x_detached)
+ y.backward(grad_out)
+ grad_x = x_detached.grad
+
+ # Explicit cleanup
+ del y, x_detached, forward_fn
+
+ return grad_x, None
+
+
+def patch(m, offloading=False):
if not hasattr(m, "forward"):
return
org_forward = m.forward
- def fwd(args, kwargs):
- return org_forward(*args, **kwargs)
+ # Branch 1: Linear/Conv* -> offload-compatible checkpoint (single input/output)
+ if offloading and isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+ def checkpointing_fwd(x):
+ return OffloadCheckpointFunction.apply(x, org_forward)
+ # Branch 2: Others -> standard checkpoint
+ else:
+ def fwd(args, kwargs):
+ return org_forward(*args, **kwargs)
- def checkpointing_fwd(*args, **kwargs):
- return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False)
+ def checkpointing_fwd(*args, **kwargs):
+ return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False)
m.org_forward = org_forward
m.forward = checkpointing_fwd
@@ -936,6 +1025,18 @@ class TrainLoraNode(io.ComfyNode):
default=True,
tooltip="Use gradient checkpointing for training.",
),
+ io.Int.Input(
+ "checkpoint_depth",
+ default=1,
+ min=1,
+ max=5,
+ tooltip="Depth level for gradient checkpointing.",
+ ),
+ io.Boolean.Input(
+ "offloading",
+ default=False,
+ tooltip="Offload the Model to RAM. Requires Bypass Mode.",
+ ),
io.Combo.Input(
"existing_lora",
options=folder_paths.get_filename_list("loras") + ["[None]"],
@@ -982,6 +1083,8 @@ class TrainLoraNode(io.ComfyNode):
lora_dtype,
algorithm,
gradient_checkpointing,
+ checkpoint_depth,
+ offloading,
existing_lora,
bucket_mode,
bypass_mode,
@@ -1000,6 +1103,8 @@ class TrainLoraNode(io.ComfyNode):
lora_dtype = lora_dtype[0]
algorithm = algorithm[0]
gradient_checkpointing = gradient_checkpointing[0]
+ offloading = offloading[0]
+ checkpoint_depth = checkpoint_depth[0]
existing_lora = existing_lora[0]
bucket_mode = bucket_mode[0]
bypass_mode = bypass_mode[0]
@@ -1019,6 +1124,15 @@ class TrainLoraNode(io.ComfyNode):
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
mp.set_model_compute_dtype(dtype)
+ if mp.is_dynamic():
+ if not bypass_mode:
+ logging.info("Training MP is Dynamic - forcing bypass mode. Start comfy with --highvram to force weight diff mode")
+ bypass_mode = True
+ offloading = True
+ elif offloading:
+ if not bypass_mode:
+ logging.info("Training offload selected - forcing bypass mode. Set bypass_mode = True to remove this message")
+ bypass_mode = True
+
# Prepare latents and compute counts
latents, num_images, multi_res = _prepare_latents_and_count(
latents, dtype, bucket_mode
@@ -1054,16 +1168,18 @@ class TrainLoraNode(io.ComfyNode):
# Setup gradient checkpointing
if gradient_checkpointing:
- for m in find_all_highest_child_module_with_forward(
- mp.model.diffusion_model
- ):
- patch(m)
+ modules_to_patch = find_modules_at_depth(
+ mp.model.diffusion_model, depth=checkpoint_depth
+ )
+ logging.info(f"Gradient checkpointing: patching {len(modules_to_patch)} modules at depth {checkpoint_depth}")
+ for m in modules_to_patch:
+ patch(m, offloading=offloading)
torch.cuda.empty_cache()
# With force_full_load=False we should be able to have offloading
# But for offloading in training we need custom AutoGrad hooks for fwd/bwd
comfy.model_management.load_models_gpu(
- [mp], memory_required=1e20, force_full_load=True
+ [mp], memory_required=1e20, force_full_load=not offloading
)
torch.cuda.empty_cache()
@@ -1100,7 +1216,7 @@ class TrainLoraNode(io.ComfyNode):
)
# Setup guider
- guider = TrainGuider(mp)
+ guider = TrainGuider(mp, offloading=offloading)
guider.set_conds(positive)
# Inject bypass hooks if bypass mode is enabled
@@ -1113,6 +1229,7 @@ class TrainLoraNode(io.ComfyNode):
# Run training loop
try:
+ comfy.model_management.in_training = True
_run_training_loop(
guider,
train_sampler,
@@ -1123,6 +1240,7 @@ class TrainLoraNode(io.ComfyNode):
multi_res,
)
finally:
+ comfy.model_management.in_training = False
# Eject bypass hooks if they were injected
if bypass_injections is not None:
for injection in bypass_injections:
@@ -1132,19 +1250,20 @@ class TrainLoraNode(io.ComfyNode):
unpatch(m)
del train_sampler, optimizer
- # Finalize adapters
+ for param in lora_sd:
+ lora_sd[param] = lora_sd[param].to(lora_dtype).detach()
+
for adapter in all_weight_adapters:
adapter.requires_grad_(False)
-
- for param in lora_sd:
- lora_sd[param] = lora_sd[param].to(lora_dtype)
+ del adapter
+ del all_weight_adapters
# mp in train node is highly specialized for training
# use it in inference will result in bad behavior so we don't return it
return io.NodeOutput(lora_sd, loss_map, steps + existing_steps)
-class LoraModelLoader(io.ComfyNode):#
+class LoraModelLoader(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
@@ -1166,6 +1285,11 @@ class LoraModelLoader(io.ComfyNode):#
max=100.0,
tooltip="How strongly to modify the diffusion model. This value can be negative.",
),
+ io.Boolean.Input(
+ "bypass",
+ default=False,
+ tooltip="When enabled, applies LoRA in bypass mode without modifying base model weights. Useful for training and when model weights are offloaded.",
+ ),
],
outputs=[
io.Model.Output(
@@ -1175,13 +1299,18 @@ class LoraModelLoader(io.ComfyNode):#
)
@classmethod
- def execute(cls, model, lora, strength_model):
+ def execute(cls, model, lora, strength_model, bypass=False):
if strength_model == 0:
return io.NodeOutput(model)
- model_lora, _ = comfy.sd.load_lora_for_models(
- model, None, lora, strength_model, 0
- )
+ if bypass:
+ model_lora, _ = comfy.sd.load_bypass_lora_for_models(
+ model, None, lora, strength_model, 0
+ )
+ else:
+ model_lora, _ = comfy.sd.load_lora_for_models(
+ model, None, lora, strength_model, 0
+ )
return io.NodeOutput(model_lora)
diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py
index ccf7b63d3..cd765a7c1 100644
--- a/comfy_extras/nodes_video.py
+++ b/comfy_extras/nodes_video.py
@@ -202,6 +202,56 @@ class LoadVideo(io.ComfyNode):
return True
+class VideoSlice(io.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return io.Schema(
+ node_id="Video Slice",
+ display_name="Video Slice",
+ search_aliases=[
+ "trim video duration",
+ "skip first frames",
+ "frame load cap",
+ "start time",
+ ],
+ category="image/video",
+ inputs=[
+ io.Video.Input("video"),
+ io.Float.Input(
+ "start_time",
+ default=0.0,
+ max=1e5,
+ min=-1e5,
+ step=0.001,
+ tooltip="Start time in seconds",
+ ),
+ io.Float.Input(
+ "duration",
+ default=0.0,
+ min=0.0,
+ step=0.001,
+ tooltip="Duration in seconds, or 0 for unlimited duration",
+ ),
+ io.Boolean.Input(
+ "strict_duration",
+ default=False,
+ tooltip="If True, when the specified duration is not possible, an error will be raised.",
+ ),
+ ],
+ outputs=[
+ io.Video.Output(),
+ ],
+ )
+
+ @classmethod
+ def execute(cls, video: io.Video.Type, start_time: float, duration: float, strict_duration: bool) -> io.NodeOutput:
+ trimmed = video.as_trimmed(start_time, duration, strict_duration=strict_duration)
+ if trimmed is not None:
+ return io.NodeOutput(trimmed)
+ raise ValueError(
+ f"Failed to slice video:\nSource duration: {video.get_duration()}\nStart time: {start_time}\nTarget duration: {duration}"
+ )
+
class VideoExtension(ComfyExtension):
@override
@@ -212,6 +262,7 @@ class VideoExtension(ComfyExtension):
CreateVideo,
GetVideoComponents,
LoadVideo,
+ VideoSlice,
]
async def comfy_entrypoint() -> VideoExtension:
diff --git a/comfyui_version.py b/comfyui_version.py
index 706b37763..cf4e89816 100644
--- a/comfyui_version.py
+++ b/comfyui_version.py
@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
-__version__ = "0.12.3"
+__version__ = "0.13.0"
diff --git a/execution.py b/execution.py
index 3dbab82e6..f549a2f0f 100644
--- a/execution.py
+++ b/execution.py
@@ -13,8 +13,11 @@ from contextlib import nullcontext
import torch
+from comfy.cli_args import args
import comfy.memory_management
import comfy.model_management
+import comfy_aimdo.model_vbar
+
from latent_preview import set_preview_method
import nodes
from comfy_execution.caching import (
@@ -527,8 +530,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
finally:
if allocator is not None:
+ if args.verbose == "DEBUG":
+ comfy_aimdo.model_vbar.vbars_analyze()
comfy.model_management.reset_cast_buffers()
- torch.cuda.synchronize()
+ comfy_aimdo.model_vbar.vbars_reset_watermark_limits()
if has_pending_tasks:
pending_async_nodes[unique_id] = output_data
@@ -618,6 +623,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
logging.info("Memory summary: {}".format(comfy.model_management.debug_memory_summary()))
logging.error("Got an OOM, unloading all loaded models.")
comfy.model_management.unload_all_models()
+ elif isinstance(ex, RuntimeError) and ("mat1 and mat2 shapes" in str(ex)) and "Sampler" in class_type:
+ tips = "\n\nTIP: If you have any \"Load CLIP\" or \"*CLIP Loader\" nodes in your workflow connected to this sampler node, make sure the correct file(s) and type are selected."
error_details = {
"node_id": real_node_id,
diff --git a/nodes.py b/nodes.py
index 91de7a9d7..db5f98408 100644
--- a/nodes.py
+++ b/nodes.py
@@ -2264,6 +2264,7 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom
if not isinstance(extension, ComfyExtension):
logging.warning(f"comfy_entrypoint in {module_path} did not return a ComfyExtension, skipping.")
return False
+ await extension.on_load()
node_list = await extension.get_node_list()
if not isinstance(node_list, list):
logging.warning(f"comfy_entrypoint in {module_path} did not return a list of nodes, skipping.")
@@ -2435,6 +2436,7 @@ async def init_builtin_extra_nodes():
"nodes_lora_debug.py",
"nodes_color.py",
"nodes_toolkit.py",
+ "nodes_replacements.py",
]
import_failed = []
diff --git a/pyproject.toml b/pyproject.toml
index f7925b92a..9dab9a50c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
-version = "0.12.3"
+version = "0.13.0"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.10"
diff --git a/requirements.txt b/requirements.txt
index 5e34a2a49..e939e486a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
-comfyui-frontend-package==1.38.13
-comfyui-workflow-templates==0.8.31
+comfyui-frontend-package==1.38.14
+comfyui-workflow-templates==0.8.38
comfyui-embedded-docs==0.4.1
torch
torchsde
@@ -22,7 +22,7 @@ alembic
SQLAlchemy
av>=14.2.0
comfy-kitchen>=0.2.7
-comfy-aimdo>=0.1.7
+comfy-aimdo>=0.1.8
requests
#non essential dependencies:
diff --git a/server.py b/server.py
index 2300393b2..8882e43c4 100644
--- a/server.py
+++ b/server.py
@@ -40,6 +40,7 @@ from app.user_manager import UserManager
from app.model_manager import ModelFileManager
from app.custom_node_manager import CustomNodeManager
from app.subgraph_manager import SubgraphManager
+from app.node_replace_manager import NodeReplaceManager
from typing import Optional, Union
from api_server.routes.internal.internal_routes import InternalRoutes
from protocol import BinaryEventTypes
@@ -204,6 +205,7 @@ class PromptServer():
self.model_file_manager = ModelFileManager()
self.custom_node_manager = CustomNodeManager()
self.subgraph_manager = SubgraphManager()
+ self.node_replace_manager = NodeReplaceManager()
self.internal_routes = InternalRoutes(self)
self.supports = ["custom_nodes_from_web"]
self.prompt_queue = execution.PromptQueue(self)
@@ -887,6 +889,8 @@ class PromptServer():
if "partial_execution_targets" in json_data:
partial_execution_targets = json_data["partial_execution_targets"]
+ self.node_replace_manager.apply_replacements(prompt)
+
valid = await execution.validate_prompt(prompt_id, prompt, partial_execution_targets)
extra_data = {}
if "extra_data" in json_data:
@@ -995,6 +999,7 @@ class PromptServer():
self.model_file_manager.add_routes(self.routes)
self.custom_node_manager.add_routes(self.routes, self.app, nodes.LOADED_MODULE_DIRS.items())
self.subgraph_manager.add_routes(self.routes, nodes.LOADED_MODULE_DIRS.items())
+ self.node_replace_manager.add_routes(self.routes)
self.app.add_subapp('/internal', self.internal_routes.get_app())
# Prefix every route with /api for easier matching for delegation.
diff --git a/tests/execution/test_jobs.py b/tests/execution/test_jobs.py
index 4d2f9ed36..83c36fe48 100644
--- a/tests/execution/test_jobs.py
+++ b/tests/execution/test_jobs.py
@@ -5,8 +5,11 @@ from comfy_execution.jobs import (
is_previewable,
normalize_queue_item,
normalize_history_item,
+ normalize_output_item,
+ normalize_outputs,
get_outputs_summary,
apply_sorting,
+ has_3d_extension,
)
@@ -35,8 +38,8 @@ class TestIsPreviewable:
"""Unit tests for is_previewable()"""
def test_previewable_media_types(self):
- """Images, video, audio media types should be previewable."""
- for media_type in ['images', 'video', 'audio']:
+ """Images, video, audio, 3d media types should be previewable."""
+ for media_type in ['images', 'video', 'audio', '3d']:
assert is_previewable(media_type, {}) is True
def test_non_previewable_media_types(self):
@@ -46,7 +49,7 @@ class TestIsPreviewable:
def test_3d_extensions_previewable(self):
"""3D file extensions should be previewable regardless of media_type."""
- for ext in ['.obj', '.fbx', '.gltf', '.glb']:
+ for ext in ['.obj', '.fbx', '.gltf', '.glb', '.usdz']:
item = {'filename': f'model{ext}'}
assert is_previewable('files', item) is True
@@ -160,7 +163,7 @@ class TestGetOutputsSummary:
def test_3d_files_previewable(self):
"""3D file extensions should be previewable."""
- for ext in ['.obj', '.fbx', '.gltf', '.glb']:
+ for ext in ['.obj', '.fbx', '.gltf', '.glb', '.usdz']:
outputs = {
'node1': {
'files': [{'filename': f'model{ext}', 'type': 'output'}]
@@ -192,6 +195,64 @@ class TestGetOutputsSummary:
assert preview['mediaType'] == 'images'
assert preview['subfolder'] == 'outputs'
+ def test_string_3d_filename_creates_preview(self):
+ """String items with 3D extensions should synthesize a preview (Preview3D node output).
+ Only the .glb counts; the None items are excluded."""
+ outputs = {
+ 'node1': {
+ 'result': ['preview3d_abc123.glb', None, None]
+ }
+ }
+ count, preview = get_outputs_summary(outputs)
+ assert count == 1
+ assert preview is not None
+ assert preview['filename'] == 'preview3d_abc123.glb'
+ assert preview['mediaType'] == '3d'
+ assert preview['nodeId'] == 'node1'
+ assert preview['type'] == 'output'
+
+ def test_string_non_3d_filename_no_preview(self):
+ """String items without 3D extensions should not create a preview."""
+ outputs = {
+ 'node1': {
+ 'result': ['data.json', None]
+ }
+ }
+ count, preview = get_outputs_summary(outputs)
+ assert count == 0
+ assert preview is None
+
+ def test_string_3d_filename_used_as_fallback(self):
+ """String 3D preview should be used when no dict items are previewable."""
+ outputs = {
+ 'node1': {
+ 'latents': [{'filename': 'latent.safetensors'}],
+ },
+ 'node2': {
+ 'result': ['model.glb', None]
+ }
+ }
+ count, preview = get_outputs_summary(outputs)
+ assert preview is not None
+ assert preview['filename'] == 'model.glb'
+ assert preview['mediaType'] == '3d'
+
+
+class TestHas3DExtension:
+ """Unit tests for has_3d_extension()"""
+
+ def test_recognized_extensions(self):
+ for ext in ['.obj', '.fbx', '.gltf', '.glb', '.usdz']:
+ assert has_3d_extension(f'model{ext}') is True
+
+ def test_case_insensitive(self):
+ assert has_3d_extension('MODEL.GLB') is True
+ assert has_3d_extension('Scene.GLTF') is True
+
+ def test_non_3d_extensions(self):
+ for name in ['photo.png', 'video.mp4', 'data.json', 'model']:
+ assert has_3d_extension(name) is False
+
class TestApplySorting:
"""Unit tests for apply_sorting()"""
@@ -395,3 +456,142 @@ class TestNormalizeHistoryItem:
'prompt': {'nodes': {'1': {}}},
'extra_data': {'create_time': 1234567890, 'client_id': 'abc'},
}
+
+ def test_include_outputs_normalizes_3d_strings(self):
+ """Detail view should transform string 3D filenames into file output dicts."""
+ history_item = {
+ 'prompt': (
+ 5,
+ 'prompt-3d',
+ {'nodes': {}},
+ {'create_time': 1234567890},
+ ['node1'],
+ ),
+ 'status': {'status_str': 'success', 'completed': True, 'messages': []},
+ 'outputs': {
+ 'node1': {
+ 'result': ['preview3d_abc123.glb', None, None]
+ }
+ },
+ }
+ job = normalize_history_item('prompt-3d', history_item, include_outputs=True)
+
+ assert job['outputs_count'] == 1
+ result_items = job['outputs']['node1']['result']
+ assert len(result_items) == 1
+ assert result_items[0] == {
+ 'filename': 'preview3d_abc123.glb',
+ 'type': 'output',
+ 'subfolder': '',
+ 'mediaType': '3d',
+ }
+
+ def test_include_outputs_preserves_dict_items(self):
+ """Detail view normalization should pass dict items through unchanged."""
+ history_item = {
+ 'prompt': (
+ 5,
+ 'prompt-img',
+ {'nodes': {}},
+ {'create_time': 1234567890},
+ ['node1'],
+ ),
+ 'status': {'status_str': 'success', 'completed': True, 'messages': []},
+ 'outputs': {
+ 'node1': {
+ 'images': [
+ {'filename': 'photo.png', 'type': 'output', 'subfolder': ''},
+ ]
+ }
+ },
+ }
+ job = normalize_history_item('prompt-img', history_item, include_outputs=True)
+
+ assert job['outputs_count'] == 1
+ assert job['outputs']['node1']['images'] == [
+ {'filename': 'photo.png', 'type': 'output', 'subfolder': ''},
+ ]
+
+
+class TestNormalizeOutputItem:
+ """Unit tests for normalize_output_item()"""
+
+ def test_none_returns_none(self):
+ assert normalize_output_item(None) is None
+
+ def test_string_3d_extension_synthesizes_dict(self):
+ result = normalize_output_item('model.glb')
+ assert result == {'filename': 'model.glb', 'type': 'output', 'subfolder': '', 'mediaType': '3d'}
+
+ def test_string_non_3d_extension_returns_none(self):
+ assert normalize_output_item('data.json') is None
+
+ def test_string_no_extension_returns_none(self):
+ assert normalize_output_item('camera_info_string') is None
+
+ def test_dict_passes_through(self):
+ item = {'filename': 'test.png', 'type': 'output'}
+ assert normalize_output_item(item) is item
+
+ def test_other_types_return_none(self):
+ assert normalize_output_item(42) is None
+ assert normalize_output_item(True) is None
+
+
+class TestNormalizeOutputs:
+ """Unit tests for normalize_outputs()"""
+
+ def test_empty_outputs(self):
+ assert normalize_outputs({}) == {}
+
+ def test_dict_items_pass_through(self):
+ outputs = {
+ 'node1': {
+ 'images': [{'filename': 'a.png', 'type': 'output'}],
+ }
+ }
+ result = normalize_outputs(outputs)
+ assert result == outputs
+
+ def test_3d_string_synthesized(self):
+ outputs = {
+ 'node1': {
+ 'result': ['model.glb', None, None],
+ }
+ }
+ result = normalize_outputs(outputs)
+ assert result == {
+ 'node1': {
+ 'result': [
+ {'filename': 'model.glb', 'type': 'output', 'subfolder': '', 'mediaType': '3d'},
+ ],
+ }
+ }
+
+ def test_animated_key_preserved(self):
+ outputs = {
+ 'node1': {
+ 'images': [{'filename': 'a.png', 'type': 'output'}],
+ 'animated': [True],
+ }
+ }
+ result = normalize_outputs(outputs)
+ assert result['node1']['animated'] == [True]
+
+ def test_non_dict_node_outputs_preserved(self):
+ outputs = {'node1': 'unexpected_value'}
+ result = normalize_outputs(outputs)
+ assert result == {'node1': 'unexpected_value'}
+
+ def test_none_items_filtered_but_other_types_preserved(self):
+ outputs = {
+ 'node1': {
+ 'result': ['data.json', None, [1, 2, 3]],
+ }
+ }
+ result = normalize_outputs(outputs)
+ assert result == {
+ 'node1': {
+ 'result': ['data.json', [1, 2, 3]],
+ }
+ }