Merge branch 'Comfy-Org:master' into master

This commit is contained in:
azazeal04 2026-06-02 16:14:22 +02:00 committed by GitHub
commit 54e1802b0e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
69 changed files with 3891 additions and 386 deletions

View File

@ -105,7 +105,7 @@ class WindowAttention(nn.Module):
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.long().view(-1)].view( relative_position_bias = self.relative_position_bias_table[self.relative_position_index.long().view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww relative_position_bias = comfy.ops.cast_to_input(relative_position_bias.permute(2, 0, 1).contiguous(), attn) # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0) attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None: if mask is not None:

View File

@ -9,6 +9,7 @@ import comfy.model_management
import comfy.utils import comfy.utils
import comfy.clip_model import comfy.clip_model
import comfy.image_encoders.dino2 import comfy.image_encoders.dino2
import comfy.image_encoders.dino3
class Output: class Output:
def __getitem__(self, key): def __getitem__(self, key):
@ -23,12 +24,16 @@ IMAGE_ENCODERS = {
"siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection, "siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
"siglip2_vision_model": comfy.clip_model.CLIPVisionModelProjection, "siglip2_vision_model": comfy.clip_model.CLIPVisionModelProjection,
"dinov2": comfy.image_encoders.dino2.Dinov2Model, "dinov2": comfy.image_encoders.dino2.Dinov2Model,
"dinov3": comfy.image_encoders.dino3.DINOv3ViTModel,
} }
class ClipVisionModel(): class ClipVisionModel():
def __init__(self, json_config): def __init__(self, json_config):
with open(json_config) as f: if isinstance(json_config, dict):
config = json.load(f) config = json_config
else:
with open(json_config) as f:
config = json.load(f)
self.image_size = config.get("image_size", 224) self.image_size = config.get("image_size", 224)
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073]) self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
@ -134,6 +139,8 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json") json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
elif 'encoder.layer.23.layer_scale2.lambda1' in sd: elif 'encoder.layer.23.layer_scale2.lambda1' in sd:
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_large.json") json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_large.json")
elif 'layer.0.mlp.gate_proj.weight' in sd and 'layer.31.norm1.weight' in sd: # Dinov3 ViT-H/16+ (SwiGLU gated MLP, 32 layers)
json_config = comfy.image_encoders.dino3.DINOV3_VITH_CONFIG
else: else:
return None return None

View File

@ -0,0 +1,259 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.ops
from comfy.ldm.modules.attention import optimized_attention_for_device
from comfy.image_encoders.dino2 import LayerScale as DINOv3ViTLayerScale
# DINOv3 ViT-H/16+ (SwiGLU)
DINOV3_VITH_CONFIG = {
"model_type": "dinov3",
"num_hidden_layers": 32,
"hidden_size": 1280,
"num_attention_heads": 20,
"num_register_tokens": 4,
"intermediate_size": 5120,
"layer_norm_eps": 1e-5,
"num_channels": 3,
"patch_size": 16,
"rope_theta": 100.0,
"use_gated_mlp": True,
"gated_mlp_act": "silu",
"image_size": 1024,
"image_mean": [0.485, 0.456, 0.406],
"image_std": [0.229, 0.224, 0.225],
}
class DINOv3ViTMLP(nn.Module):
def __init__(self, hidden_size, intermediate_size, mlp_bias, device, dtype, operations):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype)
self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=mlp_bias, device=device, dtype=dtype)
self.act_fn = torch.nn.GELU()
def forward(self, x):
return self.down_proj(self.act_fn(self.up_proj(x)))
def rotate_half(x):
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, **kwargs):
num_tokens = q.shape[-2]
num_patches = sin.shape[-2]
num_prefix_tokens = num_tokens - num_patches
q_prefix_tokens, q_patches = q.split((num_prefix_tokens, num_patches), dim=-2)
k_prefix_tokens, k_patches = k.split((num_prefix_tokens, num_patches), dim=-2)
q_patches = (q_patches * cos) + (rotate_half(q_patches) * sin)
k_patches = (k_patches * cos) + (rotate_half(k_patches) * sin)
q = torch.cat((q_prefix_tokens, q_patches), dim=-2)
k = torch.cat((k_prefix_tokens, k_patches), dim=-2)
return q, k
class DINOv3ViTAttention(nn.Module):
def __init__(self, hidden_size, num_attention_heads, device, dtype, operations):
super().__init__()
self.embed_dim = hidden_size
self.num_heads = num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
self.k_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=False, device=device, dtype=dtype) # key_bias = False
self.v_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype)
self.q_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype)
self.o_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype)
def forward(self, hidden_states, attention_mask=None, position_embeddings=None, **kwargs):
batch_size, patches, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
if position_embeddings is not None:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
attn = optimized_attention_for_device(query_states.device, mask=False)
attn_output = attn(
query_states, key_states, value_states, self.num_heads, attention_mask,
skip_reshape=True, skip_output_reshape=True, low_precision_attention=False,
)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(batch_size, patches, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output
class DINOv3ViTGatedMLP(nn.Module):
def __init__(self, hidden_size, intermediate_size, mlp_bias, device, dtype, operations, act="silu"):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.gate_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype)
self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype)
self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=mlp_bias, device=device, dtype=dtype)
self.act_fn = torch.nn.SiLU() if act == "silu" else torch.nn.GELU()
def forward(self, x):
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
def get_patches_center_coordinates(num_patches_h, num_patches_w, dtype, device):
coords_h = torch.arange(0.5, num_patches_h, dtype=dtype, device=device)
coords_w = torch.arange(0.5, num_patches_w, dtype=dtype, device=device)
coords_h = coords_h / num_patches_h
coords_w = coords_w / num_patches_w
coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1)
coords = coords.flatten(0, 1)
coords = 2.0 * coords - 1.0
return coords
class DINOv3ViTRopePositionEmbedding(nn.Module):
inv_freq: torch.Tensor
def __init__(self, rope_theta, hidden_size, num_attention_heads, patch_size, device, dtype):
super().__init__()
self.base = rope_theta
self.head_dim = hidden_size // num_attention_heads
self.patch_size = patch_size
inv_freq = 1 / self.base ** torch.arange(0, 1, 4 / self.head_dim, dtype=torch.float32, device=device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
def forward(self, pixel_values):
_, _, height, width = pixel_values.shape
num_patches_h = height // self.patch_size
num_patches_w = width // self.patch_size
patch_coords = get_patches_center_coordinates(num_patches_h, num_patches_w, dtype=torch.float32, device=pixel_values.device)
self.inv_freq = self.inv_freq.to(pixel_values.device)
angles = 2 * math.pi * patch_coords[:, :, None] * self.inv_freq[None, None, :]
angles = angles.flatten(1, 2)
angles = angles.tile(2)
cos = torch.cos(angles).to(dtype=pixel_values.dtype)
sin = torch.sin(angles).to(dtype=pixel_values.dtype)
return cos, sin
class DINOv3ViTEmbeddings(nn.Module):
def __init__(self, hidden_size, num_register_tokens, num_channels, patch_size, dtype, device, operations):
super().__init__()
self.cls_token = nn.Parameter(torch.empty(1, 1, hidden_size, device=device, dtype=dtype))
self.mask_token = nn.Parameter(torch.empty(1, 1, hidden_size, device=device, dtype=dtype))
self.register_tokens = nn.Parameter(torch.empty(1, num_register_tokens, hidden_size, device=device, dtype=dtype))
self.patch_embeddings = operations.Conv2d(
num_channels, hidden_size, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype
)
def forward(self, pixel_values, bool_masked_pos=None):
batch_size = pixel_values.shape[0]
patch_embeddings = self.patch_embeddings(pixel_values)
patch_embeddings = patch_embeddings.flatten(2).transpose(1, 2)
if bool_masked_pos is not None:
mask_token = comfy.ops.cast_to_input(self.mask_token, patch_embeddings)
patch_embeddings = torch.where(bool_masked_pos.unsqueeze(-1), mask_token, patch_embeddings)
cls_token = comfy.ops.cast_to_input(self.cls_token.expand(batch_size, -1, -1), patch_embeddings)
register_tokens = comfy.ops.cast_to_input(self.register_tokens.expand(batch_size, -1, -1), patch_embeddings)
embeddings = torch.cat([cls_token, register_tokens, patch_embeddings], dim=1)
return embeddings
class DINOv3ViTLayer(nn.Module):
def __init__(self, hidden_size, layer_norm_eps, use_gated_mlp, mlp_bias, intermediate_size,
num_attention_heads, device, dtype, operations, gated_mlp_act="silu"):
super().__init__()
self.norm1 = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype)
self.attention = DINOv3ViTAttention(hidden_size, num_attention_heads, device=device, dtype=dtype, operations=operations)
self.layer_scale1 = DINOv3ViTLayerScale(hidden_size, device=device, dtype=dtype, operations=None)
self.norm2 = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype)
if use_gated_mlp:
self.mlp = DINOv3ViTGatedMLP(hidden_size, intermediate_size, mlp_bias, device=device, dtype=dtype, operations=operations, act=gated_mlp_act)
else:
self.mlp = DINOv3ViTMLP(hidden_size, intermediate_size=intermediate_size, mlp_bias=mlp_bias, device=device, dtype=dtype, operations=operations)
self.layer_scale2 = DINOv3ViTLayerScale(hidden_size, device=device, dtype=dtype, operations=None)
def forward(self, hidden_states, attention_mask=None, position_embeddings=None):
residual = hidden_states
hidden_states = self.norm1(hidden_states)
hidden_states = self.attention(hidden_states, attention_mask=attention_mask, position_embeddings=position_embeddings)
hidden_states = self.layer_scale1(hidden_states)
hidden_states = hidden_states + residual
residual = hidden_states
hidden_states = self.norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = self.layer_scale2(hidden_states)
hidden_states = hidden_states + residual
return hidden_states
class DINOv3ViTModel(nn.Module):
def __init__(self, config, dtype, device, operations):
super().__init__()
num_hidden_layers = config["num_hidden_layers"]
hidden_size = config["hidden_size"]
num_attention_heads = config["num_attention_heads"]
num_register_tokens = config["num_register_tokens"]
intermediate_size = config["intermediate_size"]
layer_norm_eps = config["layer_norm_eps"]
num_channels = config["num_channels"]
patch_size = config["patch_size"]
rope_theta = config["rope_theta"]
use_gated_mlp = config.get("use_gated_mlp", False)
gated_mlp_act = config.get("gated_mlp_act", "silu")
self.embeddings = DINOv3ViTEmbeddings(
hidden_size, num_register_tokens, num_channels=num_channels, patch_size=patch_size,
dtype=dtype, device=device, operations=operations
)
self.rope_embeddings = DINOv3ViTRopePositionEmbedding(
rope_theta, hidden_size, num_attention_heads, patch_size=patch_size, dtype=dtype, device=device
)
self.layer = nn.ModuleList([
DINOv3ViTLayer(hidden_size, layer_norm_eps, use_gated_mlp=use_gated_mlp, mlp_bias=True,
intermediate_size=intermediate_size, num_attention_heads=num_attention_heads,
dtype=dtype, device=device, operations=operations, gated_mlp_act=gated_mlp_act)
for _ in range(num_hidden_layers)])
self.norm = operations.LayerNorm(hidden_size, eps=layer_norm_eps, dtype=dtype, device=device)
def get_input_embeddings(self):
return self.embeddings.patch_embeddings
def forward(self, pixel_values, bool_masked_pos=None, **kwargs):
hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
position_embeddings = self.rope_embeddings(pixel_values)
for layer_module in self.layer:
hidden_states = layer_module(hidden_states, position_embeddings=position_embeddings)
if kwargs.get("skip_norm_elementwise", False):
sequence_output = F.layer_norm(hidden_states, hidden_states.shape[-1:])
else:
norm = self.norm.to(hidden_states.device)
sequence_output = norm(hidden_states)
pooled_output = sequence_output[:, 0, :]
return sequence_output, None, pooled_output, None

View File

@ -239,6 +239,16 @@ class Flux2(LatentFormat):
def process_out(self, latent): def process_out(self, latent):
return latent return latent
class TripoSplat(LatentFormat):
# Sequence latent (B, 8192, 16) the camera token rides alongside as a second nested latent
latent_channels = 16
def process_in(self, latent):
return latent
def process_out(self, latent):
return latent
class Mochi(LatentFormat): class Mochi(LatentFormat):
latent_channels = 12 latent_channels = 12
latent_dimensions = 3 latent_dimensions = 3

View File

@ -38,6 +38,8 @@ class ChromaRadianceParams(ChromaParams):
# None means use the same dtype as the model. # None means use the same dtype as the model.
nerf_embedder_dtype: Optional[torch.dtype] nerf_embedder_dtype: Optional[torch.dtype]
use_x0: bool use_x0: bool
# Use sequential txt_ids instead of zeros
use_sequential_txt_ids: bool
class ChromaRadiance(Chroma): class ChromaRadiance(Chroma):
""" """
@ -162,6 +164,9 @@ class ChromaRadiance(Chroma):
if params.use_x0: if params.use_x0:
self.register_buffer("__x0__", torch.tensor([])) self.register_buffer("__x0__", torch.tensor([]))
if params.use_sequential_txt_ids:
self.register_buffer("__sequential__", torch.tensor([]))
@property @property
def _nerf_final_layer(self) -> nn.Module: def _nerf_final_layer(self) -> nn.Module:
if self.params.nerf_final_head_type == "linear": if self.params.nerf_final_head_type == "linear":
@ -313,6 +318,9 @@ class ChromaRadiance(Chroma):
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
# Radiance after 2026-05-22 uses sequential txt_ids instead of zeros
if params.use_sequential_txt_ids:
txt_ids[:, :, 0] = torch.arange(context.shape[1], device=x.device, dtype=x.dtype).unsqueeze(0).expand(bs, -1)
img_out = self.forward_orig( img_out = self.forward_orig(
img, img,

View File

@ -0,0 +1,199 @@
# TripoSplat 3D gaussian container. Operates on already-decoded
# tensors and exposes them as render-ready tensors (render_tensors) for the generic SPLAT type.
import torch
import torch.nn.functional as F
import comfy.model_management
class GaussianModel:
def __init__(self, aabb: list, sh_degree: int = 0, mininum_kernel_size: float = 0.0,
scaling_bias: float = 0.01, opacity_bias: float = 0.1,
scaling_activation: str = "exp", device=None):
self.sh_degree = sh_degree
self.mininum_kernel_size = mininum_kernel_size
self.scaling_bias = scaling_bias
self.opacity_bias = opacity_bias
self.device = device
self.aabb = torch.tensor(aabb, dtype=torch.float32, device=device)
if scaling_activation == "exp":
self._scaling_activation = torch.exp
self._inverse_scaling_activation = torch.log
elif scaling_activation == "softplus":
self._scaling_activation = F.softplus
self._inverse_scaling_activation = lambda x: x + torch.log(-torch.expm1(-x))
self._opacity_activation = torch.sigmoid
self._inverse_opacity_activation = lambda x: torch.log(x / (1 - x))
self.scale_bias = self._inverse_scaling_activation(torch.tensor(self.scaling_bias)).to(self.device)
self.rots_bias = torch.zeros(4, device=self.device)
self.rots_bias[0] = 1
self.opacity_bias_val = self._inverse_opacity_activation(torch.tensor(self.opacity_bias)).to(self.device)
self._storage = {}
def _get_store(self, name):
return self._storage.get(name)
def _set_store(self, name, value):
self._storage[name] = value
@property
def _xyz(self):
return self._get_store("_xyz")
@_xyz.setter
def _xyz(self, value):
if value is None:
self._set_store("_xyz", None)
self._set_store("xyz", None)
return
self._set_store("_xyz", value)
self._set_store("xyz", value * self.aabb[None, 3:] + self.aabb[None, :3])
@property
def get_xyz(self):
return self._get_store("xyz")
@property
def _features_dc(self):
return self._get_store("_features_dc")
@_features_dc.setter
def _features_dc(self, value):
self._set_store("_features_dc", value)
@property
def _opacity(self):
return self._get_store("_opacity")
@_opacity.setter
def _opacity(self, value):
if value is None:
self._set_store("_opacity", None)
self._set_store("opacity", None)
return
self._set_store("_opacity", value)
self._set_store("opacity", self._opacity_activation(value + self.opacity_bias_val))
@property
def get_opacity(self):
return self._get_store("opacity")
@property
def _scaling(self):
return self._get_store("_scaling")
@_scaling.setter
def _scaling(self, value):
if value is None:
self._set_store("_scaling", None)
self._set_store("scaling", None)
return
self._set_store("_scaling", value)
s = self._scaling_activation(value + self.scale_bias)
s = torch.square(s) + self.mininum_kernel_size ** 2
self._set_store("scaling", torch.sqrt(s))
@property
def get_scaling(self):
return self._get_store("scaling")
@property
def _rotation(self):
return self._get_store("_rotation")
@_rotation.setter
def _rotation(self, value):
self._set_store("_rotation", value)
_DEFAULT_TRANSFORM = [[1, 0, 0], [0, 0, -1], [0, 1, 0]]
def render_tensors(self):
# Render-ready (activated, world-space) tensors for the generic SPLAT type. The axis transform
# (a 3x3 rotation, object frame -> viewer Y-up) is baked into positions and rotations.
# Returns float tensors on the intermediate device: positions (N,3), scales (N,3) linear,
# rotations (N,4) wxyz, opacities (N,1) in [0,1], sh (N,K,3) coefficients.
xyz = self.get_xyz.float()
scaling = self.get_scaling.float()
opacity = self.get_opacity.float()
rotation = (self._rotation + self.rots_bias[None, :]).float()
sh = self._features_dc.float() # (N, K, 3)
T = torch.as_tensor(self._DEFAULT_TRANSFORM, dtype=torch.float32, device=xyz.device)
xyz = xyz @ T.T
rotation = _matrix_to_quat(torch.matmul(T, _quat_to_matrix(rotation)))
rotation = rotation / torch.linalg.norm(rotation, dim=-1, keepdim=True)
out_device = comfy.model_management.intermediate_device()
return (
xyz.to(out_device).contiguous(), scaling.to(out_device).contiguous(),
rotation.to(out_device).contiguous(), opacity.to(out_device).contiguous(),
sh.to(out_device).contiguous(),
)
def _quat_to_matrix(q):
q = q / torch.linalg.norm(q, dim=-1, keepdim=True)
w, x, y, z = q[:, 0], q[:, 1], q[:, 2], q[:, 3]
R = torch.stack([
1 - 2*(y*y + z*z), 2*(x*y - w*z), 2*(x*z + w*y),
2*(x*y + w*z), 1 - 2*(x*x + z*z), 2*(y*z - w*x),
2*(x*z - w*y), 2*(y*z + w*x), 1 - 2*(x*x + y*y),
], dim=-1).reshape(-1, 3, 3)
return R
def _matrix_to_quat(R):
trace = R[:, 0, 0] + R[:, 1, 1] + R[:, 2, 2]
q = torch.zeros((R.shape[0], 4), dtype=R.dtype, device=R.device)
s = torch.sqrt(torch.clamp(trace + 1, min=0)) * 2
q[:, 0] = 0.25 * s
denom = torch.where(s != 0, s, torch.ones_like(s))
q[:, 1] = (R[:, 2, 1] - R[:, 1, 2]) / denom
q[:, 2] = (R[:, 0, 2] - R[:, 2, 0]) / denom
q[:, 3] = (R[:, 1, 0] - R[:, 0, 1]) / denom
m01 = (R[:, 0, 0] >= R[:, 1, 1]) & (R[:, 0, 0] >= R[:, 2, 2]) & (s == 0)
s1 = torch.sqrt(torch.clamp(1 + R[:, 0, 0] - R[:, 1, 1] - R[:, 2, 2], min=0)) * 2
q[m01, 0] = (R[m01, 2, 1] - R[m01, 1, 2]) / s1[m01]
q[m01, 1] = 0.25 * s1[m01]
q[m01, 2] = (R[m01, 0, 1] + R[m01, 1, 0]) / s1[m01]
q[m01, 3] = (R[m01, 0, 2] + R[m01, 2, 0]) / s1[m01]
m11 = (R[:, 1, 1] > R[:, 0, 0]) & (R[:, 1, 1] >= R[:, 2, 2]) & (s == 0)
s2 = torch.sqrt(torch.clamp(1 + R[:, 1, 1] - R[:, 0, 0] - R[:, 2, 2], min=0)) * 2
q[m11, 0] = (R[m11, 0, 2] - R[m11, 2, 0]) / s2[m11]
q[m11, 1] = (R[m11, 0, 1] + R[m11, 1, 0]) / s2[m11]
q[m11, 2] = 0.25 * s2[m11]
q[m11, 3] = (R[m11, 1, 2] + R[m11, 2, 1]) / s2[m11]
m21 = (R[:, 2, 2] > R[:, 0, 0]) & (R[:, 2, 2] > R[:, 1, 1]) & (s == 0)
s3 = torch.sqrt(torch.clamp(1 + R[:, 2, 2] - R[:, 0, 0] - R[:, 1, 1], min=0)) * 2
q[m21, 0] = (R[m21, 1, 0] - R[m21, 0, 1]) / s3[m21]
q[m21, 1] = (R[m21, 0, 2] + R[m21, 2, 0]) / s3[m21]
q[m21, 2] = (R[m21, 1, 2] + R[m21, 2, 1]) / s3[m21]
q[m21, 3] = 0.25 * s3[m21]
return q / torch.linalg.norm(q, dim=-1, keepdim=True)
def build_gaussian_models(decoder, points_pred: dict, pred: dict):
# Assemble GaussianModels from the elastic decoder layout. decoder is the ElasticGaussianFixedlenDecoder
# (carries layout / rep_config / _get_offset)
x = points_pred
offset = decoder._get_offset(pred['features'])
h = pred["features"]
ret = []
for i in range(h.shape[0]):
g = GaussianModel(
sh_degree=0,
aabb=[-0.5, -0.5, -0.5, 1.0, 1.0, 1.0],
mininum_kernel_size=decoder.rep_config['filter_kernel_size_3d'],
scaling_bias=decoder.rep_config['scaling_bias'],
opacity_bias=decoder.rep_config['opacity_bias'],
scaling_activation=decoder.rep_config['scaling_activation'],
device=h.device,
)
_x = x["points"][i, :, None, :]
for k, v in decoder.layout.items():
if k == '_xyz':
setattr(g, k, (offset[i] + _x).flatten(0, 1))
elif k in ('_xyz_center', '_offset_scale'):
continue
else:
feats = h[i][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape']).flatten(0, 1)
setattr(g, k, feats * decoder.rep_config['lr'][k])
ret.append(g)
return ret

View File

@ -0,0 +1,326 @@
# TripoSplat flow-matching denoiser (LatentSeqMMFlowModel). Registered as a ModelType.FLOW arch and
# driven by the standard KSampler; jointly denoises the (B, 8192, 16) latent and a (B, 1, 5) camera token
# carried as a 2-element nested latent.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.model_management
import comfy.patcher_extension
import comfy.rmsnorm
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.flux.math import apply_rope
class MultiHeadRMSNorm(nn.Module):
def __init__(self, dim, heads, dtype=None, device=None):
super().__init__()
self.gamma = nn.Parameter(torch.empty(heads, dim, dtype=dtype, device=device))
def forward(self, x):
x = comfy.rmsnorm.rms_norm(x)
return x * comfy.model_management.cast_to(self.gamma, x.dtype, x.device)
# Positional embeddings
class RePo3DRotaryEmbedding(nn.Module):
def __init__(self, model_channels, num_heads, head_dim, repo_hidden_ratio=0.125, max_freq=16.0,
dtype=None, device=None, operations=None):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
repo_hidden_size = int(model_channels * repo_hidden_ratio)
self.norm = operations.LayerNorm(model_channels, dtype=dtype, device=device)
self.gate_map = operations.Linear(model_channels, repo_hidden_size, bias=False, dtype=dtype, device=device)
self.content_map = operations.Linear(model_channels, repo_hidden_size, bias=False, dtype=dtype, device=device)
self.act = nn.SiLU()
self.final_map = operations.Linear(repo_hidden_size, 3 * num_heads, bias=False, dtype=dtype, device=device)
self.dim_0 = 2 * (head_dim // 6)
self.dim_1 = 2 * (head_dim // 6)
self.dim_2 = head_dim - self.dim_0 - self.dim_1
dims = [self.dim_0, self.dim_1, self.dim_2]
freqs_list = []
for d in dims:
freq_dim = d // 2
freqs_list.append(torch.linspace(1.0, float(max_freq), steps=freq_dim, dtype=torch.float32))
self.freqs_0 = nn.Parameter(freqs_list[0])
self.freqs_1 = nn.Parameter(freqs_list[1])
self.freqs_2 = nn.Parameter(freqs_list[2])
def forward(self, hidden_states):
h = self.norm(hidden_states)
feat = self.act(self.gate_map(h)) * self.content_map(h)
out = self.final_map(feat)
B, L, _ = out.shape
delta_pos = out.reshape(B, L, self.num_heads, 3)
f0 = comfy.model_management.cast_to(self.freqs_0, torch.float32, out.device)
f1 = comfy.model_management.cast_to(self.freqs_1, torch.float32, out.device)
f2 = comfy.model_management.cast_to(self.freqs_2, torch.float32, out.device)
ang_0 = delta_pos[..., 0].unsqueeze(-1) * f0 * torch.pi
ang_1 = delta_pos[..., 1].unsqueeze(-1) * f1 * torch.pi
ang_2 = delta_pos[..., 2].unsqueeze(-1) * f2 * torch.pi
ang = torch.cat([ang_0, ang_1, ang_2], dim=-1).float() # (B, L, heads, head_dim/2)
cos, sin = ang.cos(), ang.sin()
return torch.stack([cos, -sin, sin, cos], dim=-1).reshape(*ang.shape, 2, 2)
class PcdAbsolutePositionEmbedder(nn.Module):
# Sinusoidal absolute position embedding. Two fixed schedules are used in TripoSplat:
# "pow2" (flow-model latent anchors) and "log2" (octree / gaussian decoders).
def __init__(self, channels: int, in_channels: int = 3, max_res: int = 16, schedule: str = "pow2"):
super().__init__()
self.channels = channels
self.in_channels = in_channels
self.max_res = max_res
self.schedule = schedule
self.freq_dim = channels // in_channels // 2
def _freqs(self, device):
if self.schedule == "pow2":
freqs_2exp = torch.arange(self.max_res, dtype=torch.float32, device=device)
res_dim = max(0, self.freq_dim - self.max_res)
freqs_res = (torch.arange(res_dim, dtype=torch.float32, device=device) / max(res_dim, 1) * self.max_res
if res_dim > 0 else torch.empty(0, device=device))
freqs = torch.cat([freqs_2exp, freqs_res], dim=0)[:self.freq_dim]
return torch.pow(2.0, freqs) * 2.0 # *2 folds this schedule's 2*pi into the shared *pi below
logs = torch.linspace(0.0, float(self.max_res), steps=self.freq_dim, dtype=torch.float32, device=device)
return torch.pow(2.0, logs)
def forward(self, x: torch.Tensor) -> torch.Tensor:
orig_dtype = x.dtype
x = x.float()
*dims, D = x.shape
out = torch.outer(x.reshape(-1), self._freqs(x.device)) * torch.pi
out = torch.cat([out.sin(), out.cos()], dim=-1).reshape(*dims, -1)
if out.shape[-1] < self.channels:
out = torch.cat([out, torch.zeros(*dims, self.channels - out.shape[-1],
device=out.device, dtype=out.dtype)], dim=-1)
return out.to(orig_dtype)
def attention(q, k, v, transformer_options=None):
# q, k, v: (B, L, heads, dim) -> (B, L, heads, dim). Shared optimized_attention call convention.
out = optimized_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), heads=q.shape[2],
skip_reshape=True, skip_output_reshape=True, low_precision_attention=False,
transformer_options=transformer_options)
return out.transpose(1, 2)
# Transformer building blocks
class MLP(nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, dtype=None, device=None, operations=None):
super().__init__()
self.mlp = nn.Sequential(
operations.Linear(in_channels, hidden_channels, dtype=dtype, device=device),
nn.GELU(approximate="tanh"),
operations.Linear(hidden_channels, out_channels, dtype=dtype, device=device),
)
def forward(self, x):
return self.mlp(x)
class RopeMultiHeadAttention(nn.Module):
def __init__(self, channels, num_heads, qkv_bias=True, qk_rms_norm=False, use_rope=False,
dtype=None, device=None, operations=None):
super().__init__()
self.channels = channels
self.num_heads = num_heads
self.head_dim = channels // num_heads
self.qk_rms_norm = qk_rms_norm
self.use_rope = use_rope
self.qkv = operations.Linear(channels, channels * 3, bias=qkv_bias, dtype=dtype, device=device)
if self.qk_rms_norm:
self.q_norm = MultiHeadRMSNorm(self.head_dim, num_heads, dtype=dtype, device=device)
self.k_norm = MultiHeadRMSNorm(self.head_dim, num_heads, dtype=dtype, device=device)
self.out = operations.Linear(channels, channels, dtype=dtype, device=device)
def forward(self, x, rope_emb=None, transformer_options=None):
B, L, C = x.shape
qkv = self.qkv(x).reshape(B, L, 3, self.num_heads, self.head_dim)
q, k, v = qkv.unbind(2)
if self.use_rope:
q, k = apply_rope(q, k, rope_emb)
if self.qk_rms_norm:
q = self.q_norm(q)
k = self.k_norm(k)
h = attention(q, k, v, transformer_options) # (B, L, heads, dim)
return self.out(h.reshape(B, L, C))
class UnifiedTransformerBlock(nn.Module):
def __init__(self, channels, num_heads, mlp_ratio=4.0,
use_rope=False, qk_rms_norm=False, qkv_bias=True,
modulation=True, share_mod=False,
dtype=None, device=None, operations=None):
super().__init__()
self.modulation = modulation
self.share_mod = share_mod
self.norm1 = operations.LayerNorm(channels, elementwise_affine=not modulation, eps=1e-6, dtype=dtype, device=device)
self.norm2 = operations.LayerNorm(channels, elementwise_affine=not modulation, eps=1e-6, dtype=dtype, device=device)
self.attn = RopeMultiHeadAttention(channels, num_heads=num_heads,
qkv_bias=qkv_bias, use_rope=use_rope, qk_rms_norm=qk_rms_norm,
dtype=dtype, device=device, operations=operations)
self.mlp = MLP(channels, int(channels * mlp_ratio), channels, dtype=dtype, device=device, operations=operations)
if modulation:
if not share_mod:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(), operations.Linear(channels, 6 * channels, bias=True, dtype=dtype, device=device))
self.shift_table = nn.Parameter(torch.empty(1, 6 * channels, dtype=dtype, device=device))
def forward(self, x, mod=None, rotary_emb=None, transformer_options=None):
if self.modulation:
if not self.share_mod:
mod = self.adaLN_modulation(mod)
mod = mod + comfy.model_management.cast_to(self.shift_table, mod.dtype, mod.device)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
h = torch.addcmul(shift_msa.unsqueeze(1), self.norm1(x), 1 + scale_msa.unsqueeze(1))
x = torch.addcmul(x, self.attn(h, rope_emb=rotary_emb, transformer_options=transformer_options), gate_msa.unsqueeze(1))
h = torch.addcmul(shift_mlp.unsqueeze(1), self.norm2(x), 1 + scale_mlp.unsqueeze(1))
x = torch.addcmul(x, self.mlp(h), gate_mlp.unsqueeze(1))
else:
x = x + self.attn(self.norm1(x), rope_emb=rotary_emb, transformer_options=transformer_options)
x = x + self.mlp(self.norm2(x))
return x
class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
super().__init__()
self.mlp = nn.Sequential(
operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
half = dim // 2
freqs = torch.exp(-np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
emb = self.timestep_embedding(t, self.frequency_embedding_size)
return self.mlp(emb.to(self.mlp[0].weight.dtype))
class LatentSeqMMFlowModel(nn.Module):
def __init__(self, image_model=None, q_token_length=8192, in_channels=16, model_channels=1024,
cond_channels=1280, out_channels=16, num_blocks=24, num_refiner_blocks=2,
num_heads=None, num_head_channels=64, cam_channels=5, cond2_channels=128,
mlp_ratio=4, share_mod=True, qk_rms_norm=True,
dtype=None, device=None, operations=None, **kwargs):
super().__init__()
self.dtype = dtype
self.q_token_length = q_token_length
self.in_channels = in_channels
self.cam_channels = cam_channels
self.model_channels = model_channels
self.cond_channels = cond_channels
self.cond2_channels = cond2_channels
self.out_channels = out_channels
self.num_blocks = num_blocks
self.num_refiner_blocks = num_refiner_blocks
self.num_heads = num_heads or model_channels // num_head_channels
self.mlp_ratio = mlp_ratio
self.share_mod = share_mod
self.qk_rms_norm = qk_rms_norm
factory_kwargs = dict(dtype=dtype, device=device)
op_kwargs = dict(operations=operations, **factory_kwargs)
self.t_embedder = TimestepEmbedder(model_channels, **op_kwargs)
if share_mod:
self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(model_channels, 6 * model_channels, bias=True, **factory_kwargs))
self.input_layer = operations.Linear(in_channels, model_channels, **factory_kwargs)
self.cond_embedder = operations.Linear(cond_channels, model_channels, **factory_kwargs)
self.cond_embedder2 = operations.Linear(cond2_channels, model_channels, **factory_kwargs) if cond2_channels is not None else None
# Fixed Sobol (low-discrepancy) 3D anchor positions for the latent tokens, used as positional encoding.
# The embedder is parameter-free and the anchors are fixed, precompute once.
sobol_seq = torch.quasirandom.SobolEngine(dimension=3, scramble=True, seed=123).draw(q_token_length)
pos_emb = PcdAbsolutePositionEmbedder(model_channels)(sobol_seq.unsqueeze(0))
self.register_buffer("pos_emb", pos_emb, persistent=False)
# RePo3DRotaryEmbedding layers for the refiner and main blocks
repo_kwargs = dict(num_heads=self.num_heads, head_dim=num_head_channels, **op_kwargs)
self.noise_repo_layers = nn.ModuleList(
[RePo3DRotaryEmbedding(model_channels, **repo_kwargs) for _ in range(num_refiner_blocks)])
self.context_repo_layers = nn.ModuleList(
[RePo3DRotaryEmbedding(model_channels, **repo_kwargs) for _ in range(num_refiner_blocks)])
self.repo_layers = nn.ModuleList(
[RePo3DRotaryEmbedding(model_channels, **repo_kwargs) for _ in range(num_blocks)])
# Refiner blocks
block_kwargs = dict(num_heads=self.num_heads, mlp_ratio=self.mlp_ratio, use_rope=True, qk_rms_norm=self.qk_rms_norm, **op_kwargs)
self.noise_refiner = nn.ModuleList(
[UnifiedTransformerBlock(model_channels, modulation=True, share_mod=self.share_mod, **block_kwargs) for _ in range(num_refiner_blocks)])
self.context_refiner = nn.ModuleList(
[UnifiedTransformerBlock(model_channels, modulation=False, **block_kwargs) for _ in range(num_refiner_blocks)])
self.cam_refiner = MLP(self.cam_channels, model_channels, model_channels, **op_kwargs)
self.blocks = nn.ModuleList(
[UnifiedTransformerBlock(model_channels, modulation=True, share_mod=self.share_mod, **block_kwargs) for _ in range(num_blocks)])
self.shift_table = nn.Parameter(torch.empty(1, 2, model_channels, **factory_kwargs))
self.out_layer = operations.Linear(model_channels, out_channels, **factory_kwargs)
self.cam_out_layer = operations.Linear(model_channels, cam_channels, **factory_kwargs)
def forward(self, x, t, context=None, ref_latents=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, t, context, ref_latents, transformer_options, **kwargs)
def _forward(self, x, t, context=None, ref_latents=None, transformer_options={}, **kwargs):
# x is the unpacked nested latent: [latent (B,8192,in_channels), camera (B,1,cam_channels)].
# context == feature1.
z, camera = x[0], x[1]
feat1 = context
h_x = self.input_layer(z)
h_cond = self.cond_embedder(feat1)
if ref_latents is not None and self.cond_embedder2 is not None:
# Flatten the Flux2 VAE latent (B,128,h,w) to a token sequence and front-pad to feat1's length
# (the pad count = feat1's prefix tokens: DINOv3 cls + registers), then add to the context.
feat2 = ref_latents[0].flatten(2).transpose(1, 2)
feat2 = F.pad(feat2, (0, 0, feat1.shape[1] - feat2.shape[1], 0))
h_cond = h_cond + self.cond_embedder2(feat2.to(h_cond.dtype))
t_emb = self.t_embedder(t)
t_mod = self.adaLN_modulation(t_emb) if self.share_mod else t_emb
h_x = h_x + self.pos_emb.to(z)
for i, block in enumerate(self.noise_refiner):
h_x = block(h_x, mod=t_mod, rotary_emb=self.noise_repo_layers[i](h_x), transformer_options=transformer_options)
for i, block in enumerate(self.context_refiner):
h_cond = block(h_cond, mod=None, rotary_emb=self.context_repo_layers[i](h_cond), transformer_options=transformer_options)
cam = camera.to(z)
h_cam = self.cam_refiner(cam)
h = torch.cat([h_x, h_cond, h_cam], dim=1)
for i, block in enumerate(self.blocks):
h = block(h, mod=t_mod, rotary_emb=self.repo_layers[i](h), transformer_options=transformer_options)
h_x = F.layer_norm(h[:, :z.shape[1]].float(), h.shape[-1:]).to(z)
h_cam = F.layer_norm(h[:, -cam.shape[1]:].float(), h.shape[-1:]).to(z)
shift, scale = (comfy.model_management.cast_to(self.shift_table, t_emb.dtype, t_emb.device) + t_emb.unsqueeze(1)).chunk(2, dim=1)
scale = 1 + scale
h_x = torch.addcmul(shift, h_x, scale)
h_cam = torch.addcmul(shift, h_cam, scale)
return self.out_layer(h_x), self.cam_out_layer(h_cam)

View File

@ -0,0 +1,91 @@
# Live preview for TripoSplat: decode an x0 estimate into a coarse gaussian splat and render it with a perspective orbit camera.
import numpy as np
from PIL import Image
_C0 = 0.28209479177387814
_LATENT_TOKENS = 8192 # q_token_length
_LATENT_CH = 16 # in_channels
_OBJECT_TO_VIEWER = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]], np.float32) # object frame -> viewer Y-up frame
def _view_matrix(yaw_deg, pitch_deg):
y, p = np.radians(yaw_deg), np.radians(pitch_deg)
Ry = np.array([[np.cos(y), 0, np.sin(y)], [0, 1, 0], [-np.sin(y), 0, np.cos(y)]], np.float32)
Rx = np.array([[1, 0, 0], [0, np.cos(p), -np.sin(p)], [0, np.sin(p), np.cos(p)]], np.float32)
return Rx @ Ry
def render_splat(xyz, rgb, scale, opacity=None, yaw=35.0, pitch=30.0, size=320, min_px=2, gain=1.0,
max_px=9, min_opacity=0.0, fov=35.0, dist=2.2):
# Project gaussian centers with a perspective camera and paint each as a filled disk whose screen
# radius follows the gaussian's world-space scale, composited with a nearest-wins z-buffer.
# gain scales the footprint (≈ std spanned), `min_px`/`max_px` clamp the on-screen radius.
pts = xyz.astype(np.float32) @ _OBJECT_TO_VIEWER.T
v = pts @ _view_matrix(yaw, pitch).T
zc = v[:, 2] + dist
keep = zc > 1e-2
if opacity is not None and min_opacity > 0.0: # culls gaussians with very low opacity
keep = keep & (opacity > min_opacity)
v, zc, scale = v[keep], zc[keep], scale[keep]
col = (np.clip(rgb, 0, 1)[:, :3] * 255).astype(np.uint8)[keep]
if v.shape[0] == 0:
return Image.fromarray(np.zeros((size, size, 3), np.uint8))
f = (size / 2) / np.tan(np.radians(fov) / 2)
cx = size / 2 + f * v[:, 0] / zc
cy = size / 2 + f * v[:, 1] / zc
radius = np.clip(np.round(f * scale / zc * gain), min_px, max_px).astype(np.int32)
# Expand each splat to its disk pixels, bucketed by integer radius so it stays vectorized.
px, py, pz, pc = [], [], [], []
for r in range(int(radius.min()), int(radius.max()) + 1):
m = radius == r
if not m.any():
continue
dy, dx = np.mgrid[-r:r + 1, -r:r + 1]
disk = (dx * dx + dy * dy) <= r * r
ox, oy = dx[disk], dy[disk]
px.append((cx[m, None] + ox).ravel())
py.append((cy[m, None] + oy).ravel())
pz.append(np.repeat(zc[m], ox.size))
pc.append(np.repeat(col[m], ox.size, axis=0))
px, py = np.concatenate(px), np.concatenate(py)
pz, pc = np.concatenate(pz), np.concatenate(pc)
xi = np.clip(px, 0, size - 1).astype(np.int64)
yi = np.clip(py, 0, size - 1).astype(np.int64)
# Nearest-wins z-buffer: pack (quantized depth, source index), per-pixel min picks the closest
# splat, then decode the winning index back to its color.
pid = yi * size + xi
q = np.clip((pz * 1024.0).astype(np.int64), 0, (1 << 20) - 1) # near = small
key = (q << 32) | np.arange(pid.size, dtype=np.int64)
buf = np.full(size * size, 1 << 62, np.int64)
np.minimum.at(buf, pid, key)
img = np.zeros((size * size, 3), np.uint8)
hit = buf < (1 << 62)
img[hit] = pc[buf[hit] & 0xFFFFFFFF]
return Image.fromarray(img.reshape(size, size, 3))
def _extract_latent(x0):
# x0 from the sampler callback is the nested latent packed to (B, 1, TOKENS*CH + 1*5);
# the plain single-latent case is (B, TOKENS, CH). Return the (B, TOKENS, CH) latent stream.
if x0.ndim == 3 and x0.shape[1] == _LATENT_TOKENS and x0.shape[2] == _LATENT_CH:
return x0
flat = x0.reshape(x0.shape[0], -1)
return flat[:, :_LATENT_TOKENS * _LATENT_CH].reshape(x0.shape[0], _LATENT_TOKENS, _LATENT_CH)
def decode_x0_to_image(decoder, x0, cfg):
# Decode x0 at a coarse octree level / few gaussians and render a preview image.
latent = _extract_latent(x0)
fsm = decoder.first_stage_model
gaussian = fsm.decode(latent.to(decoder.device, decoder.vae_dtype),
num_gaussians=cfg.get("gaussians", 16384), level=cfg.get("level", 5))[0]
xyz = gaussian.get_xyz.float().cpu().numpy()
rgb = gaussian._features_dc.float().cpu().numpy()[:, 0, :] * _C0 + 0.5
scale = gaussian.get_scaling.float().cpu().numpy().max(axis=1) # per-splat world radius (largest axis)
opacity = gaussian.get_opacity.float().cpu().numpy()[:, 0]
return render_splat(xyz, rgb, scale, opacity=opacity, yaw=cfg.get("yaw", 35.0), pitch=cfg.get("pitch", 30.0),
size=cfg.get("size", 320), min_px=1, gain=1.0, max_px=cfg.get("point_size", 3),
min_opacity=0.01)

382
comfy/ldm/triposplat/vae.py Normal file
View File

@ -0,0 +1,382 @@
# TripoSplat gaussian decoder ("VAE"): an octree probability decoder picks point coords, then an
# elastic-gaussian decoder predicts per-point gaussian params. OctreeGaussianDecoder.decode() returns
# a Gaussian. The octree sampler uses the global torch RNG (no generator) like upstream, so seed it for repeatable decodes.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.model_management
import comfy.ops
from .gaussian import build_gaussian_models
from .model import MultiHeadRMSNorm, MLP, PcdAbsolutePositionEmbedder, attention
# Quasi-random sampling utilities (pure functions, dtype/device-agnostic)
PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53]
def radical_inverse(base, n):
val = 0
inv_base = 1.0 / base
inv_base_n = inv_base
while n > 0:
digit = n % base
val += digit * inv_base_n
n //= base
inv_base_n *= inv_base
return val
def halton_sequence(dim, n):
return [radical_inverse(PRIMES[i], n) for i in range(dim)]
def hammersley_sequence(dim, n, num_samples):
return [n / num_samples] + halton_sequence(dim - 1, n)
def sample_probs(probs, counts, generator=None):
# Systematic resampling: distribute counts[r] draws across the P bins of row r
batch_shape = counts.shape
R = counts.numel()
P = probs.size(-1)
device = probs.device
probs = probs.reshape(R, P).to(torch.float32).clamp_min(0)
counts = counts.reshape(R).to(device=device, dtype=torch.long)
row_sums = probs.sum(1, keepdim=True)
probs = torch.where(row_sums == 0, probs.new_tensor(1.0 / P), probs / row_sums.clamp_min(1))
cdf = probs.cumsum(dim=1).clamp(max=1.0 - 1e-12)
Nmax = int(counts.max())
if Nmax == 0:
return counts.new_zeros(*batch_shape, P)
cnt = counts.clamp_min(1).float().unsqueeze(1) # (R, 1)
grid = torch.arange(Nmax, device=device, dtype=torch.float32).unsqueeze(0) # (1, Nmax)
u = (torch.rand(R, 1, generator=generator).to(device) + grid) / cnt # (R, Nmax) systematic samples (CPU-seeded)
idx = torch.searchsorted(cdf, u.clamp(max=1.0 - 1e-12)).clamp_max(P - 1)
weight = (grid < counts.unsqueeze(1)).to(cdf.dtype) # mask out j >= counts[r]
out = torch.zeros(R, P, dtype=torch.float32, device=device)
out.scatter_add_(1, idx, weight)
return out.to(torch.long).view(*batch_shape, P)
class MultiHeadAttention(nn.Module):
def __init__(self, channels, num_heads, ctx_channels=None, type="self", qkv_bias=True, qk_rms_norm=False,
dtype=None, device=None, operations=None):
super().__init__()
assert channels % num_heads == 0
self.channels = channels
self.head_dim = channels // num_heads
self.ctx_channels = ctx_channels if ctx_channels is not None else channels
self.num_heads = num_heads
self._type = type
self.qk_rms_norm = qk_rms_norm
if self._type == "self":
self.to_qkv = operations.Linear(channels, channels * 3, bias=qkv_bias, dtype=dtype, device=device)
else:
self.to_q = operations.Linear(channels, channels, bias=qkv_bias, dtype=dtype, device=device)
self.to_kv = operations.Linear(self.ctx_channels, channels * 2, bias=qkv_bias, dtype=dtype, device=device)
if self.qk_rms_norm:
self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, dtype=dtype, device=device)
self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, dtype=dtype, device=device)
self.to_out = operations.Linear(channels, channels, dtype=dtype, device=device)
def forward(self, x, context=None):
B, L, C = x.shape
if self._type == "self":
q, k, v = self.to_qkv(x).reshape(B, L, 3, self.num_heads, -1).unbind(dim=2)
else:
Lkv = context.shape[1]
q = self.to_q(x).reshape(B, L, self.num_heads, -1)
k, v = self.to_kv(context).reshape(B, Lkv, 2, self.num_heads, -1).unbind(dim=2)
if self.qk_rms_norm:
q = self.q_rms_norm(q)
k = self.k_rms_norm(k)
h = attention(q, k, v)
return self.to_out(h.reshape(B, L, -1))
# Octree probability decoder
class LevelEmbedder(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256, max_period=1024,
dtype=None, device=None, operations=None):
super().__init__()
self.mlp = nn.Sequential(
operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
)
self.frequency_embedding_size = frequency_embedding_size
self.max_period = max_period
@staticmethod
def level_embedding(t, dim, max_period=1024):
half = dim // 2
freqs = torch.exp(-np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
args = t[:, None].float() * freqs[None] * 2 * torch.pi
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
emb = self.level_embedding(t, self.frequency_embedding_size, self.max_period)
return self.mlp(emb.to(self.mlp[0].weight.dtype))
class ModulatedTransformerCrossOnlyBlock(nn.Module):
def __init__(self, channels, ctx_channels, num_heads, mlp_ratio=4.0, share_mod=False,
qk_rms_norm_cross=True, qkv_bias=True, dtype=None, device=None, operations=None):
super().__init__()
self.share_mod = share_mod
self.norm1 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.norm2 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.cross_attn = MultiHeadAttention(channels, ctx_channels=ctx_channels, num_heads=num_heads,
type="cross", qkv_bias=qkv_bias,
qk_rms_norm=qk_rms_norm_cross, dtype=dtype, device=device, operations=operations)
self.mlp = MLP(channels, int(channels * mlp_ratio), channels, dtype=dtype, device=device, operations=operations)
if not share_mod:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(), operations.Linear(channels, 6 * channels, bias=True, dtype=dtype, device=device))
def forward(self, x, mod, context):
if self.share_mod:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
else:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
h = torch.addcmul(shift_msa.unsqueeze(1), self.norm1(x), 1 + scale_msa.unsqueeze(1))
x = torch.addcmul(x, self.cross_attn(h, context), gate_msa.unsqueeze(1))
h = torch.addcmul(shift_mlp.unsqueeze(1), self.norm2(x), 1 + scale_mlp.unsqueeze(1))
x = torch.addcmul(x, self.mlp(h), gate_mlp.unsqueeze(1))
return x
class OctreeProbabilityFixedlenDecoder(nn.Module):
# Cross-attention transformer over octree coords -> per-node 8-way child occupancy logits.
def __init__(self, model_channels=1024, cond_channels=16, num_blocks=4, num_heads=16,
num_head_channels=64, mlp_ratio=4.0, share_mod=True,
qk_rms_norm_cross=True, dtype=None, device=None, operations=None):
super().__init__()
self.model_channels = model_channels
self.cond_channels = cond_channels
self.num_blocks = num_blocks
self.num_heads = num_heads or model_channels // num_head_channels
self.mlp_ratio = mlp_ratio
self.share_mod = share_mod
self.qk_rms_norm_cross = qk_rms_norm_cross
self.input_layer = operations.Linear(model_channels, model_channels, dtype=dtype, device=device)
self.l_embedder = LevelEmbedder(model_channels, dtype=dtype, device=device, operations=operations)
if share_mod:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(), operations.Linear(model_channels, 6 * model_channels, bias=True, dtype=dtype, device=device))
if cond_channels is not None:
self.blocks = nn.ModuleList([
ModulatedTransformerCrossOnlyBlock(
model_channels, ctx_channels=cond_channels, num_heads=self.num_heads,
mlp_ratio=self.mlp_ratio, qk_rms_norm_cross=self.qk_rms_norm_cross,
share_mod=self.share_mod, dtype=dtype, device=device, operations=operations)
for _ in range(num_blocks)
])
self.out_proj = operations.Linear(model_channels, 8, dtype=dtype, device=device)
self.in_proj = operations.Linear(3, model_channels, dtype=dtype, device=device)
self.pos_embedder = PcdAbsolutePositionEmbedder(channels=model_channels, in_channels=3, max_res=10, schedule="log2")
def forward(self, x, l, cond):
d = next(self.parameters()).dtype
B, L, _ = x.shape
h = self.in_proj(x.to(d)) + self.pos_embedder(x.reshape(-1, 3)).reshape(B, L, -1).to(d)
h = self.input_layer(h)
l_emb = self.l_embedder(l)
if self.share_mod:
l_emb = self.adaLN_modulation(l_emb)
cond = cond.to(d)
for block in self.blocks:
h = block(h, l_emb, cond)
h = F.layer_norm(h.float(), h.shape[-1:]).to(d)
logits = self.out_proj(h)
return {"logits": logits, "probs": torch.softmax(logits, dim=-1)}
@staticmethod
def sample(model, cond, num_points, level, temperature=1.0, generator=None):
B = cond.shape[0]
device = cond.device
child_offset = torch.tensor([[i, j, k] for k in [0, 1] for j in [0, 1] for i in [0, 1]],
dtype=torch.long, device=device)
prev_coords_int = torch.zeros(B, 1, 3, dtype=torch.long, device=device)
prev_counts = torch.full((B, 1), num_points, dtype=torch.long, device=device)
prev_log_probs = torch.zeros(B, 1, dtype=torch.float32, device=device)
batch_indices_range = torch.arange(B, device=device).unsqueeze(1)
for lv in range(1, level + 1):
res_p = 1 << (lv - 1)
res = 1 << lv
parent_coords_norm = (prev_coords_int.to(torch.float32) + 0.5) / res_p
res_tensor = torch.full((B,), res, dtype=torch.long, device=device)
pred_logits = model(parent_coords_norm, res_tensor, cond)["logits"] / temperature
pred_probs = torch.softmax(pred_logits, dim=-1)
pred_log_probs = torch.log_softmax(pred_logits, dim=-1)
sampled = sample_probs(pred_probs, prev_counts, generator=generator).flatten(1, 2)
pred_log_probs = pred_log_probs.flatten(1, 2)
prev_log_probs_expanded = prev_log_probs.repeat_interleave(8, dim=1)
child_coords_int = (prev_coords_int[:, :, None, :] * 2 + child_offset[None, None, :, :]).flatten(1, 2)
mask = sampled > 0
max_valid = mask.sum(dim=1).max().item()
scatter_indices = mask.cumsum(dim=1) - 1
valid_scatter_indices = scatter_indices[mask]
valid_batch_indices = batch_indices_range.expand_as(mask)[mask]
next_prev_coords_int = torch.zeros(B, max_valid, 3, dtype=child_coords_int.dtype, device=device)
next_prev_coords_int[valid_batch_indices, valid_scatter_indices] = child_coords_int[mask]
next_prev_counts = torch.zeros(B, max_valid, dtype=sampled.dtype, device=device)
next_prev_counts[valid_batch_indices, valid_scatter_indices] = sampled[mask]
next_prev_log_probs = torch.zeros(B, max_valid, dtype=prev_log_probs.dtype, device=device)
next_prev_log_probs[valid_batch_indices, valid_scatter_indices] = (prev_log_probs_expanded + pred_log_probs)[mask]
prev_coords_int = next_prev_coords_int
prev_counts = next_prev_counts
prev_log_probs = next_prev_log_probs
res = 1 << level
prev_log_probs = torch.repeat_interleave(prev_log_probs.flatten(0, 1), prev_counts.flatten(0, 1), dim=0).reshape(B, num_points)
coords_int = torch.repeat_interleave(prev_coords_int.flatten(0, 1), prev_counts.flatten(0, 1), dim=0).reshape(B, num_points, -1)
rand = torch.rand(coords_int.shape, dtype=torch.float32, generator=generator).to(device)
coords_norm = (coords_int.to(torch.float32) + rand) / res
return {"points": coords_norm, "log_probs": prev_log_probs}
# Elastic gaussian decoder
class TransformerCrossBlock(nn.Module):
def __init__(self, channels, ctx_channels, num_heads, mlp_ratio=4.0,
qk_rms_norm=True, qk_rms_norm_cross=True, qkv_bias=True,
dtype=None, device=None, operations=None):
super().__init__()
self.norm1 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.norm2 = operations.LayerNorm(channels, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)
self.norm3 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.self_attn = MultiHeadAttention(channels, num_heads=num_heads, type="self", qkv_bias=qkv_bias,
qk_rms_norm=qk_rms_norm, dtype=dtype, device=device, operations=operations)
self.cross_attn = MultiHeadAttention(channels, ctx_channels=ctx_channels, num_heads=num_heads, type="cross",
qkv_bias=qkv_bias, qk_rms_norm=qk_rms_norm_cross, dtype=dtype, device=device, operations=operations)
self.mlp = MLP(channels, int(channels * mlp_ratio), channels, dtype=dtype, device=device, operations=operations)
def forward(self, x, context):
x = x + self.self_attn(self.norm1(x))
x = x + self.cross_attn(self.norm2(x), context)
x = x + self.mlp(self.norm3(x))
return x
class ElasticGaussianFixedlenDecoder(nn.Module):
# Cross-attention transformer over sampled octree points -> per-point gaussian params.
def __init__(self, in_channels=3, model_channels=1024, cond_channels=16, num_blocks=16, num_heads=16,
num_head_channels=64, mlp_ratio=4.0, *, representation_config=None,
qk_rms_norm=True, qk_rms_norm_cross=True, dtype=None, device=None, operations=None):
super().__init__()
self.rep_config = representation_config or dict(
lr=dict(_xyz=1.0, _features_dc=1.0, _opacity=1.0, _scaling=1.0, _rotation=0.1),
perturb_offset=True, perturbe_size=1.5, offset_scale=0.05, num_gaussians=32,
filter_kernel_size_3d=0.0009, scaling_bias=0.004, opacity_bias=0.1,
scaling_activation="softplus",
)
self.out_channels = self._calc_layout()
self.model_channels = model_channels
self.cond_channels = cond_channels
self.num_blocks = num_blocks
self.num_heads = num_heads or model_channels // num_head_channels
self.mlp_ratio = mlp_ratio
self.input_layer = operations.Linear(model_channels, model_channels, dtype=dtype, device=device)
if cond_channels is not None:
self.blocks = nn.ModuleList([
TransformerCrossBlock(model_channels, ctx_channels=cond_channels,
num_heads=self.num_heads, mlp_ratio=self.mlp_ratio,
qk_rms_norm=qk_rms_norm, qk_rms_norm_cross=qk_rms_norm_cross,
dtype=dtype, device=device, operations=operations)
for _ in range(num_blocks)
])
self.in_proj = operations.Linear(in_channels, model_channels, dtype=dtype, device=device)
self.pos_embedder = PcdAbsolutePositionEmbedder(channels=model_channels, in_channels=3, max_res=10, schedule="log2")
self.out_proj = operations.Linear(model_channels, self.out_channels, dtype=dtype, device=device)
self._build_perturbation()
def _calc_layout(self):
ng = self.rep_config['num_gaussians']
self.layout = {
'_xyz': {'shape': (ng, 3), 'size': ng * 3},
'_features_dc': {'shape': (ng, 1, 3), 'size': ng * 3},
'_scaling': {'shape': (ng, 3), 'size': ng * 3},
'_rotation': {'shape': (ng, 4), 'size': ng * 4},
'_opacity': {'shape': (ng, 1), 'size': ng},
}
self.layout['_offset_scale'] = {'shape': (ng, 1), 'size': ng}
start = 0
for k, v in self.layout.items():
v['range'] = (start, start + v['size'])
start += v['size']
return start
def _build_perturbation(self):
ng = self.rep_config['num_gaussians']
perturbation = torch.tensor([hammersley_sequence(3, i, ng) for i in range(ng)]).float()
perturbation = torch.atanh((perturbation * 2 - 1) / self.rep_config['perturbe_size'])
self.register_buffer('points_offset_perturbation', perturbation)
base = torch.tensor(self.rep_config['offset_scale'])
self.register_buffer('base_offset_scale', torch.log(torch.exp(base) - 1.0))
def _get_offset(self, h):
B = h.shape[0]
r = self.layout['_offset_scale']['range']
_offset_scale = F.softplus(
h[:, :, r[0]:r[1]].reshape(B, -1, *self.layout['_offset_scale']['shape'])
+ comfy.model_management.cast_to(self.base_offset_scale, h.dtype, h.device))
r = self.layout['_xyz']['range']
offset = h[:, :, r[0]:r[1]].reshape(B, -1, *self.layout['_xyz']['shape'])
offset = offset * self.rep_config['lr']['_xyz']
if self.rep_config['perturb_offset']:
offset = offset + comfy.model_management.cast_to(self.points_offset_perturbation, offset.dtype, offset.device)
offset = torch.tanh(offset) * 0.5 * self.rep_config['perturbe_size']
offset = offset * _offset_scale
return offset
def forward(self, x=None, cond=None):
pcd = x["points"]
d = next(self.parameters()).dtype
B, L, _ = pcd.shape
h = self.in_proj(pcd.to(d)) + self.pos_embedder(pcd.reshape(-1, 3)).reshape(B, L, -1).to(d)
h = self.input_layer(h)
cond = cond.to(d)
for block in self.blocks:
h = block(h, cond)
h = F.layer_norm(h.float(), h.shape[-1:]).to(h.dtype)
return {"features": self.out_proj(h)}
# Combined octree gaussian decoder (comfy first-stage model)
class OctreeGaussianDecoder(nn.Module):
_MAX_VOXEL_LEVEL = 8
def __init__(self, dtype=None, device=None, operations=None):
super().__init__()
if operations is None:
operations = comfy.ops.disable_weight_init
self.octree = OctreeProbabilityFixedlenDecoder(dtype=dtype, device=device, operations=operations)
self.gs = ElasticGaussianFixedlenDecoder(dtype=dtype, device=device, operations=operations)
@property
def gaussians_per_point(self) -> int:
return self.gs.rep_config['num_gaussians']
def decode(self, latent: torch.Tensor, num_gaussians: int, level: int = None, generator=None):
# level defaults to the full octree depth, a lower level is cheaper (coarser) for live previews.
# generator (a CPU torch.Generator) makes the octree sampling reproducible without touching global RNG.
level = self._MAX_VOXEL_LEVEL if level is None else level
num_decoder_tokens = max(1, num_gaussians // self.gaussians_per_point)
points_pred = OctreeProbabilityFixedlenDecoder.sample(
self.octree, latent, num_points=num_decoder_tokens, level=level, temperature=1.0, generator=generator,
)
pred = self.gs(x=points_pred, cond=latent)
return build_gaussian_models(self.gs, points_pred, pred) # one GaussianModel per batch item

View File

@ -47,6 +47,7 @@ import comfy.ldm.wan.model_animate
import comfy.ldm.wan.ar_model import comfy.ldm.wan.ar_model
import comfy.ldm.wan.model_wandancer import comfy.ldm.wan.model_wandancer
import comfy.ldm.hunyuan3d.model import comfy.ldm.hunyuan3d.model
import comfy.ldm.triposplat.model
import comfy.ldm.hidream.model import comfy.ldm.hidream.model
import comfy.ldm.chroma.model import comfy.ldm.chroma.model
import comfy.ldm.chroma_radiance.model import comfy.ldm.chroma_radiance.model
@ -1812,6 +1813,24 @@ class Hunyuan3Dv2_1(BaseModel):
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance])) out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
return out return out
class TripoSplat(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.triposplat.model.LatentSeqMMFlowModel)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None) # DINOv3 token sequence -> cross-attention context.
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
ref_latents = kwargs.get("reference_latents", None) # Flux2 VAE image latent -> additive second conditioning.
if ref_latents is not None:
out['ref_latents'] = comfy.conds.CONDList(list(ref_latents))
latent_shapes = kwargs.get("latent_shapes", None) # {latent, camera} nested latent
if latent_shapes is not None:
out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes)
return out
class HiDream(BaseModel): class HiDream(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None): def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hidream.model.HiDreamImageTransformer2DModel) super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hidream.model.HiDreamImageTransformer2DModel)

View File

@ -355,6 +355,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["use_x0"] = True dit_config["use_x0"] = True
else: else:
dit_config["use_x0"] = False dit_config["use_x0"] = False
if "{}__sequential__".format(key_prefix) in state_dict_keys: # sequential txt_ids
dit_config["use_sequential_txt_ids"] = True
else:
dit_config["use_sequential_txt_ids"] = False
else: else:
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
@ -718,6 +722,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["guidance_cond_proj_dim"] = None#f"{key_prefix}t_embedder.cond_proj.weight" in state_dict_keys dit_config["guidance_cond_proj_dim"] = None#f"{key_prefix}t_embedder.cond_proj.weight" in state_dict_keys
return dit_config return dit_config
if '{}cam_out_layer.weight'.format(key_prefix) in state_dict_keys and '{}repo_layers.0.final_map.weight'.format(key_prefix) in state_dict_keys: # TripoSplat
return {"image_model": "triposplat"}
if '{}t_embedder1.mlp.0.weight'.format(key_prefix) in state_dict_keys and '{}x_embedder.proj1.weight'.format(key_prefix) in state_dict_keys: # HiDream-O1 if '{}t_embedder1.mlp.0.weight'.format(key_prefix) in state_dict_keys and '{}x_embedder.proj1.weight'.format(key_prefix) in state_dict_keys: # HiDream-O1
return {"image_model": "hidream_o1"} return {"image_model": "hidream_o1"}

View File

@ -16,6 +16,7 @@ import comfy.ldm.cosmos.vae
import comfy.ldm.wan.vae import comfy.ldm.wan.vae
import comfy.ldm.wan.vae2_2 import comfy.ldm.wan.vae2_2
import comfy.ldm.hunyuan3d.vae import comfy.ldm.hunyuan3d.vae
import comfy.ldm.triposplat.vae
import comfy.ldm.ace.vae.music_dcae_pipeline import comfy.ldm.ace.vae.music_dcae_pipeline
import comfy.ldm.cogvideo.vae import comfy.ldm.cogvideo.vae
import comfy.ldm.hunyuan_video.vae import comfy.ldm.hunyuan_video.vae
@ -908,6 +909,16 @@ class VAE:
#Force cast it for --disable-dynamic-vram users until there is a true core fix. #Force cast it for --disable-dynamic-vram users until there is a true core fix.
if not comfy.memory_management.aimdo_enabled: if not comfy.memory_management.aimdo_enabled:
self.disable_offload = True self.disable_offload = True
elif "gs.base_offset_scale" in sd and "octree.out_proj.weight" in sd: # TripoSplat octree gaussian decoder
self.first_stage_model = comfy.ldm.triposplat.vae.OctreeGaussianDecoder()
self.latent_channels = 16
self.latent_dim = 1
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
# The generic VAE.encode/decode path isn't used: VAEDecodeTripoSplat calls the gaussian
# decoder directly (structured GaussianSplat objects, not a tensor and reserves VRAM itself from num_gaussians.
def _no_generic_io(*args, **kwargs):
raise RuntimeError("TripoSplat gaussian decoder: use the 'TripoSplat Decode' (VAEDecodeTripoSplat)")
self.memory_used_encode = self.memory_used_decode = _no_generic_io
else: else:
logging.warning("WARNING: No VAE weights detected, VAE not initalized.") logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
self.first_stage_model = None self.first_stage_model = None

View File

@ -1547,6 +1547,30 @@ class Hunyuan3Dv2mini(Hunyuan3Dv2):
latent_format = latent_formats.Hunyuan3Dv2mini latent_format = latent_formats.Hunyuan3Dv2mini
class TripoSplat(supported_models_base.BASE):
# Image -> 3D gaussian splat flow denoiser
unet_config = {
"image_model": "triposplat",
}
unet_extra_config = {}
sampling_settings = {
"shift": 3.0,
}
memory_usage_factor = 0.6
latent_format = latent_formats.TripoSplat
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
def get_model(self, state_dict, prefix="", device=None):
return model_base.TripoSplat(self, device=device)
def clip_target(self, state_dict={}):
return None
class HiDream(supported_models_base.BASE): class HiDream(supported_models_base.BASE):
unet_config = { unet_config = {
"image_model": "hidream", "image_model": "hidream",
@ -2210,6 +2234,7 @@ models = [
Hunyuan3Dv2mini, Hunyuan3Dv2mini,
Hunyuan3Dv2, Hunyuan3Dv2,
Hunyuan3Dv2_1, Hunyuan3Dv2_1,
TripoSplat,
HiDream, HiDream,
HiDreamO1, HiDreamO1,
Chroma, Chroma,

View File

@ -5,7 +5,7 @@ from comfy_api.internal.singleton import ProxiedSingleton
from comfy_api.internal.async_to_sync import create_sync_class from comfy_api.internal.async_to_sync import create_sync_class
from ._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput from ._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
from ._input_impl import VideoFromFile, VideoFromComponents from ._input_impl import VideoFromFile, VideoFromComponents
from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL, File3D from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL, SPLAT, File3D
from . import _io_public as io from . import _io_public as io
from . import _ui_public as ui from . import _ui_public as ui
from comfy_execution.utils import get_executing_context from comfy_execution.utils import get_executing_context
@ -143,6 +143,7 @@ class Types:
VideoComponents = VideoComponents VideoComponents = VideoComponents
MESH = MESH MESH = MESH
VOXEL = VOXEL VOXEL = VOXEL
SPLAT = SPLAT
File3D = File3D File3D = File3D

View File

@ -65,6 +65,12 @@ class VideoInput(ABC):
buffer.seek(0) buffer.seek(0)
return buffer return buffer
def get_active_trim_window(self) -> tuple[float, float]:
"""Return the active trim as ``(start_time, duration)`` in seconds (start_time normalized
to ``>= 0``; ``duration == 0`` means "until the end"). Default: no trim; trimmable subclasses override.
"""
return 0.0, 0.0
# Provide a default implementation, but subclasses can provide optimized versions # Provide a default implementation, but subclasses can provide optimized versions
# if possible. # if possible.
def get_dimensions(self) -> tuple[int, int]: def get_dimensions(self) -> tuple[int, int]:

View File

@ -75,6 +75,12 @@ class VideoFromFile(VideoInput):
self.__file.seek(0) self.__file.seek(0)
return self.__file return self.__file
def get_active_trim_window(self) -> tuple[float, float]:
start_time = self.__start_time
if start_time < 0:
start_time = max(self._get_raw_duration() + start_time, 0.0)
return float(start_time), float(self.__duration)
def get_dimensions(self) -> tuple[int, int]: def get_dimensions(self) -> tuple[int, int]:
""" """
Returns the dimensions of the video input. Returns the dimensions of the video input.

View File

@ -28,7 +28,7 @@ if TYPE_CHECKING:
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class, from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
prune_dict, shallow_clone_class) prune_dict, shallow_clone_class)
from comfy_execution.graph_utils import ExecutionBlocker from comfy_execution.graph_utils import ExecutionBlocker
from ._util import MESH, VOXEL, SVG as _SVG, File3D from ._util import MESH, VOXEL, SPLAT, SVG as _SVG, File3D
class FolderType(str, Enum): class FolderType(str, Enum):
@ -684,6 +684,10 @@ class Voxel(ComfyTypeIO):
class Mesh(ComfyTypeIO): class Mesh(ComfyTypeIO):
Type = MESH Type = MESH
@comfytype(io_type="SPLAT")
class Splat(ComfyTypeIO):
Type = SPLAT
@comfytype(io_type="FILE_3D") @comfytype(io_type="FILE_3D")
class File3DAny(ComfyTypeIO): class File3DAny(ComfyTypeIO):
@ -2320,6 +2324,7 @@ __all__ = [
"LossMap", "LossMap",
"Voxel", "Voxel",
"Mesh", "Mesh",
"Splat",
"File3DAny", "File3DAny",
"File3DGLB", "File3DGLB",
"File3DGLTF", "File3DGLTF",

View File

@ -1,5 +1,5 @@
from .video_types import VideoContainer, VideoCodec, VideoComponents from .video_types import VideoContainer, VideoCodec, VideoComponents
from .geometry_types import VOXEL, MESH, File3D from .geometry_types import VOXEL, MESH, SPLAT, File3D
from .image_types import SVG from .image_types import SVG
__all__ = [ __all__ = [
@ -9,6 +9,7 @@ __all__ = [
"VideoComponents", "VideoComponents",
"VOXEL", "VOXEL",
"MESH", "MESH",
"SPLAT",
"File3D", "File3D",
"SVG", "SVG",
] ]

View File

@ -11,13 +11,32 @@ class VOXEL:
self.data = data self.data = data
class SPLAT:
"""A batch of 3D Gaussian splats in render-ready (activated, world-space) form.
Tensors are (B, N, ...) and zero-padded to a common N across the batch; `counts` (B,) holds the
real per-item lengths (None when rows are uniform and no slicing is needed). SH coefficients are
stored as (B, N, K, 3) with K = (sh_degree + 1)**2; the DC (diffuse) term is sh[..., 0, :].
"""
def __init__(self, positions: torch.Tensor, scales: torch.Tensor, rotations: torch.Tensor,
opacities: torch.Tensor, sh: torch.Tensor, counts: torch.Tensor | None = None):
self.positions = positions # (B, N, 3) world-space centers
self.scales = scales # (B, N, 3) linear (positive) per-axis std
self.rotations = rotations # (B, N, 4) quaternion wxyz (normalized)
self.opacities = opacities # (B, N, 1) in [0, 1]
self.sh = sh # (B, N, K, 3) spherical-harmonic color coefficients
self.counts = counts # (B,) real lengths, or None
class MESH: class MESH:
def __init__(self, vertices: torch.Tensor, faces: torch.Tensor, def __init__(self, vertices: torch.Tensor, faces: torch.Tensor,
uvs: torch.Tensor | None = None, uvs: torch.Tensor | None = None,
vertex_colors: torch.Tensor | None = None, vertex_colors: torch.Tensor | None = None,
texture: torch.Tensor | None = None, texture: torch.Tensor | None = None,
vertex_counts: torch.Tensor | None = None, vertex_counts: torch.Tensor | None = None,
face_counts: torch.Tensor | None = None): face_counts: torch.Tensor | None = None,
unlit: bool = False):
assert (vertex_counts is None) == (face_counts is None), \ assert (vertex_counts is None) == (face_counts is None), \
"vertex_counts and face_counts must be provided together (both or neither)" "vertex_counts and face_counts must be provided together (both or neither)"
@ -30,6 +49,8 @@ class MESH:
# these hold the real per-item lengths (B,). None means rows are uniform and no slicing is needed. # these hold the real per-item lengths (B,). None means rows are uniform and no slicing is needed.
self.vertex_counts = vertex_counts self.vertex_counts = vertex_counts
self.face_counts = face_counts self.face_counts = face_counts
# Render flat / emissive (no scene lighting) when saved, e.g. for gaussian-splat-derived meshes.
self.unlit = unlit
class File3D: class File3D:

View File

@ -1,71 +1,71 @@
from enum import Enum from enum import Enum
from typing import Any, Dict, Optional from typing import Any
from pydantic import BaseModel, Field, confloat, conint from pydantic import BaseModel, Field
class BFLOutputFormat(str, Enum):
png = 'png'
jpeg = 'jpeg'
class BFLFluxExpandImageRequest(BaseModel): class BFLFluxExpandImageRequest(BaseModel):
prompt: str = Field(..., description='The description of the changes you want to make. This text guides the expansion process, allowing you to specify features, styles, or modifications for the expanded areas.') prompt: str = Field(...)
prompt_upsampling: Optional[bool] = Field( prompt_upsampling: bool | None = Field(None)
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.' seed: int | None = Field(None)
) top: int = Field(...)
seed: Optional[int] = Field(None, description='The seed value for reproducibility.') bottom: int = Field(...)
top: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the top of the image') left: int = Field(...)
bottom: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the bottom of the image') right: int = Field(...)
left: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the left side of the image') steps: int = Field(...)
right: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the right side of the image') guidance: float = Field(...)
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process') safety_tolerance: int = Field(6)
guidance: confloat(ge=1.5, le=100) = Field(..., description='Guidance strength for the image generation process') output_format: str = Field("png")
safety_tolerance: Optional[conint(ge=0, le=6)] = Field( image: str = Field(None, description="A Base64-encoded string representing the image you wish to expand")
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
)
output_format: Optional[BFLOutputFormat] = Field(
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
)
image: str = Field(None, description='A Base64-encoded string representing the image you wish to expand')
class BFLFluxFillImageRequest(BaseModel): class BFLFluxFillImageRequest(BaseModel):
prompt: str = Field(..., description='The description of the changes you want to make. This text guides the expansion process, allowing you to specify features, styles, or modifications for the expanded areas.') prompt: str = Field(...)
prompt_upsampling: Optional[bool] = Field( prompt_upsampling: bool | None = Field(None)
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.' seed: int | None = Field(None)
steps: int = Field(...)
guidance: float = Field(...)
safety_tolerance: int = Field(6)
output_format: str = Field("png")
image: str = Field(
None, description="Base64-encoded string representing the image to modify. Can contain alpha mask if desired.",
) )
seed: Optional[int] = Field(None, description='The seed value for reproducibility.') mask: str = Field(
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process') None, description="Base64-encoded string representing the mask of the areas you wish to modify."
guidance: confloat(ge=1.5, le=100) = Field(..., description='Guidance strength for the image generation process')
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
) )
output_format: Optional[BFLOutputFormat] = Field(
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
class BFLFluxEraseRequest(BaseModel):
image: str = Field(..., description="A Base64-encoded string representing the image to erase from.")
mask: str = Field(
...,
description="A Base64-encoded black/white mask matching the input dimensions; "
"white (255) marks areas to remove, black (0) marks areas to preserve.",
) )
image: str = Field(None, description='A Base64-encoded string representing the image you wish to modify. Can contain alpha mask if desired.') dilate_pixels: int = Field(10)
mask: str = Field(None, description='A Base64-encoded string representing the mask of the areas you with to modify.') output_format: str = Field("png")
class BFLFluxVTORequest(BaseModel):
prompt: str = Field(
..., description="Natural-language styling instruction. Required field, but may be an empty string."
)
person: str = Field(..., description="A Base64-encoded string representing the person image.")
garment: str = Field(..., description="A Base64-encoded string representing the garment reference image.")
seed: int | None = Field(None)
safety_tolerance: int = Field(5)
output_format: str = Field("png")
class BFLFluxProGenerateRequest(BaseModel): class BFLFluxProGenerateRequest(BaseModel):
prompt: str = Field(..., description='The text prompt for image generation.') prompt: str = Field(...)
prompt_upsampling: Optional[bool] = Field( prompt_upsampling: bool | None = Field(None)
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.' seed: int | None = Field(None)
) width: int = Field(1024, description="Must be a multiple of 32.")
seed: Optional[int] = Field(None, description='The seed value for reproducibility.') height: int = Field(768, description="Must be a multiple of 32.")
width: conint(ge=256, le=1440) = Field(1024, description='Width of the generated image in pixels. Must be a multiple of 32.') safety_tolerance: int = Field(6)
height: conint(ge=256, le=1440) = Field(768, description='Height of the generated image in pixels. Must be a multiple of 32.') output_format: str = Field("png")
safety_tolerance: Optional[conint(ge=0, le=6)] = Field( image_prompt: str | None = Field(None, description="Optional image to remix in base64 format")
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
)
output_format: Optional[BFLOutputFormat] = Field(
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
)
image_prompt: Optional[str] = Field(None, description='Optional image to remix in base64 format')
# image_prompt_strength: Optional[confloat(ge=0.0, le=1.0)] = Field(
# None, description='Blend between the prompt and the image prompt.'
# )
class Flux2ProGenerateRequest(BaseModel): class Flux2ProGenerateRequest(BaseModel):
@ -83,55 +83,37 @@ class Flux2ProGenerateRequest(BaseModel):
input_image_7: str | None = Field(None, description="Base64 encoded image for image-to-image generation") input_image_7: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
input_image_8: str | None = Field(None, description="Base64 encoded image for image-to-image generation") input_image_8: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
input_image_9: str | None = Field(None, description="Base64 encoded image for image-to-image generation") input_image_9: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
safety_tolerance: int | None = Field( safety_tolerance: int = Field(5)
5, description="Tolerance level for input and output moderation. Value 0 being most strict.", ge=0, le=5 output_format: str = Field("png")
)
output_format: str | None = Field(
"png", description="Output format for the generated image. Can be 'jpeg' or 'png'."
)
class BFLFluxKontextProGenerateRequest(BaseModel): class BFLFluxKontextProGenerateRequest(BaseModel):
prompt: str = Field(..., description='The text prompt for what you wannt to edit.') prompt: str = Field(...)
input_image: Optional[str] = Field(None, description='Image to edit in base64 format') input_image: str | None = Field(None, description="Image to edit in base64 format")
seed: Optional[int] = Field(None, description='The seed value for reproducibility.') seed: int | None = Field(None)
guidance: confloat(ge=0.1, le=99.0) = Field(..., description='Guidance strength for the image generation process') guidance: float = Field(...)
steps: conint(ge=1, le=150) = Field(..., description='Number of steps for the image generation process') steps: int = Field(...)
safety_tolerance: Optional[conint(ge=0, le=2)] = Field( safety_tolerance: int = Field(2)
2, description='Tolerance level for input and output moderation. Between 0 and 2, 0 being most strict, 6 being least strict. Defaults to 2.' output_format: str = Field("png")
) aspect_ratio: str | None = Field(None)
output_format: Optional[BFLOutputFormat] = Field( prompt_upsampling: bool | None = Field(None)
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
)
aspect_ratio: Optional[str] = Field(None, description='Aspect ratio of the image between 21:9 and 9:21.')
prompt_upsampling: Optional[bool] = Field(
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
)
class BFLFluxProUltraGenerateRequest(BaseModel): class BFLFluxProUltraGenerateRequest(BaseModel):
prompt: str = Field(..., description='The text prompt for image generation.') prompt: str = Field(...)
prompt_upsampling: Optional[bool] = Field( prompt_upsampling: bool | None = Field(None)
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.' seed: int | None = Field(None)
) aspect_ratio: str | None = Field(None)
seed: Optional[int] = Field(None, description='The seed value for reproducibility.') safety_tolerance: int = Field(6)
aspect_ratio: Optional[str] = Field(None, description='Aspect ratio of the image between 21:9 and 9:21.') output_format: str = Field("png")
safety_tolerance: Optional[conint(ge=0, le=6)] = Field( raw: bool | None = Field(None)
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.' image_prompt: str | None = Field(None, description="Optional image to remix in base64 format")
) image_prompt_strength: float | None = Field(None)
output_format: Optional[BFLOutputFormat] = Field(
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
)
raw: Optional[bool] = Field(None, description='Generate less processed, more natural-looking images.')
image_prompt: Optional[str] = Field(None, description='Optional image to remix in base64 format')
image_prompt_strength: Optional[confloat(ge=0.0, le=1.0)] = Field(
None, description='Blend between the prompt and the image prompt.'
)
class BFLFluxProGenerateResponse(BaseModel): class BFLFluxProGenerateResponse(BaseModel):
id: str = Field(..., description="The unique identifier for the generation task.") id: str = Field(...)
polling_url: str = Field(..., description="URL to poll for the generation result.") polling_url: str = Field(...)
cost: float | None = Field(None, description="Price in cents") cost: float | None = Field(None, description="Price in cents")
@ -145,7 +127,7 @@ class BFLStatus(str, Enum):
class BFLFluxStatusResponse(BaseModel): class BFLFluxStatusResponse(BaseModel):
id: str = Field(..., description="The unique identifier for the generation task.") id: str = Field(...)
status: BFLStatus = Field(..., description="The status of the task.") status: BFLStatus = Field(...)
result: Optional[Dict[str, Any]] = Field(None, description="The result of the task (null if not completed).") result: dict[str, Any] | None = Field(None)
progress: Optional[float] = Field(None, description="The progress of the task (0.0 to 1.0).", ge=0.0, le=1.0) progress: float | None = Field(None, ge=0.0, le=1.0)

View File

@ -155,7 +155,7 @@ class ClaudeNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="ClaudeNode", node_id="ClaudeNode",
display_name="Anthropic Claude", display_name="Anthropic Claude",
category="text/partner/Anthropic", category="partner/text/Anthropic",
essentials_category="Text Generation", essentials_category="Text Generation",
description="Generate text responses with Anthropic's Claude models. " description="Generate text responses with Anthropic's Claude models. "
"Provide a text prompt and optionally one or more images for multimodal context.", "Provide a text prompt and optionally one or more images for multimodal context.",

View File

@ -206,7 +206,7 @@ class BeebleSwitchXVideoEdit(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="BeebleSwitchXVideoEdit", node_id="BeebleSwitchXVideoEdit",
display_name="Beeble SwitchX Video Edit", display_name="Beeble SwitchX Video Edit",
category="video/partner/Beeble", category="partner/video/Beeble",
description=( description=(
"Edit a video with Beeble SwitchX. Switches anything in the scene (background, " "Edit a video with Beeble SwitchX. Switches anything in the scene (background, "
"lighting, costume) while preserving the original subject's pixels and motion. " "lighting, costume) while preserving the original subject's pixels and motion. "
@ -302,7 +302,7 @@ class BeebleSwitchXImageEdit(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="BeebleSwitchXImageEdit", node_id="BeebleSwitchXImageEdit",
display_name="Beeble SwitchX Image Edit", display_name="Beeble SwitchX Image Edit",
category="image/partner/Beeble", category="partner/image/Beeble",
description=( description=(
"Edit a single image with Beeble SwitchX. Switches anything in the scene " "Edit a single image with Beeble SwitchX. Switches anything in the scene "
"(background, lighting, costume) while preserving the original subject's pixels. " "(background, lighting, costume) while preserving the original subject's pixels. "

View File

@ -4,17 +4,20 @@ from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.bfl import ( from comfy_api_nodes.apis.bfl import (
BFLFluxEraseRequest,
BFLFluxExpandImageRequest, BFLFluxExpandImageRequest,
BFLFluxFillImageRequest, BFLFluxFillImageRequest,
BFLFluxKontextProGenerateRequest, BFLFluxKontextProGenerateRequest,
BFLFluxProGenerateResponse, BFLFluxProGenerateResponse,
BFLFluxProUltraGenerateRequest, BFLFluxProUltraGenerateRequest,
BFLFluxStatusResponse, BFLFluxStatusResponse,
BFLFluxVTORequest,
BFLStatus, BFLStatus,
Flux2ProGenerateRequest, Flux2ProGenerateRequest,
) )
from comfy_api_nodes.util import ( from comfy_api_nodes.util import (
ApiEndpoint, ApiEndpoint,
convert_mask_to_image,
download_url_to_image_tensor, download_url_to_image_tensor,
get_number_of_images, get_number_of_images,
poll_op, poll_op,
@ -22,19 +25,11 @@ from comfy_api_nodes.util import (
sync_op, sync_op,
tensor_to_base64_string, tensor_to_base64_string,
validate_aspect_ratio_string, validate_aspect_ratio_string,
validate_image_dimensions,
validate_string, validate_string,
) )
def convert_mask_to_image(mask: Input.Image):
"""
Make mask have the expected amount of dims (4) and channels (3) to be recognized as an image.
"""
mask = mask.unsqueeze(-1)
mask = torch.cat([mask] * 3, dim=-1)
return mask
class FluxProUltraImageNode(IO.ComfyNode): class FluxProUltraImageNode(IO.ComfyNode):
@classmethod @classmethod
@ -42,7 +37,7 @@ class FluxProUltraImageNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="FluxProUltraImageNode", node_id="FluxProUltraImageNode",
display_name="Flux 1.1 [pro] Ultra Image", display_name="Flux 1.1 [pro] Ultra Image",
category="image/partner/BFL", category="partner/image/BFL",
description="Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.", description="Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.",
inputs=[ inputs=[
IO.String.Input( IO.String.Input(
@ -160,7 +155,7 @@ class FluxKontextProImageNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id=cls.NODE_ID, node_id=cls.NODE_ID,
display_name=cls.DISPLAY_NAME, display_name=cls.DISPLAY_NAME,
category="image/partner/BFL", category="partner/image/BFL",
description="Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.", description="Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.",
inputs=[ inputs=[
IO.String.Input( IO.String.Input(
@ -282,7 +277,7 @@ class FluxProExpandNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="FluxProExpandNode", node_id="FluxProExpandNode",
display_name="Flux.1 Expand Image", display_name="Flux.1 Expand Image",
category="image/partner/BFL", category="partner/image/BFL",
description="Outpaints image based on prompt.", description="Outpaints image based on prompt.",
inputs=[ inputs=[
IO.Image.Input("image"), IO.Image.Input("image"),
@ -419,7 +414,7 @@ class FluxProFillNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="FluxProFillNode", node_id="FluxProFillNode",
display_name="Flux.1 Fill Image", display_name="Flux.1 Fill Image",
category="image/partner/BFL", category="partner/image/BFL",
description="Inpaints image based on mask and prompt.", description="Inpaints image based on mask and prompt.",
inputs=[ inputs=[
IO.Image.Input("image"), IO.Image.Input("image"),
@ -519,6 +514,163 @@ class FluxProFillNode(IO.ComfyNode):
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
class FluxEraseNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="FluxEraseNode",
display_name="Flux Erase Image",
category="partner/image/BFL",
description="Removes the masked object from an image and reconstructs the background. "
"Paint the mask over what you want to erase.",
inputs=[
IO.Image.Input("image"),
IO.Mask.Input("mask", tooltip="White areas are removed; black areas are preserved."),
IO.Int.Input(
"dilate_pixels",
default=10,
min=0,
max=25,
tooltip="Expands the mask boundaries to ensure clean coverage of the object's edges.",
),
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"range_usd","min_usd":0.03,"max_usd":0.06,"format":{"approximate":true}}""",
),
)
@classmethod
async def execute(
cls,
image: Input.Image,
mask: Input.Image,
dilate_pixels: int = 10,
) -> IO.NodeOutput:
validate_image_dimensions(image, min_width=256, min_height=256)
mask = resize_mask_to_image(mask, image)
mask = tensor_to_base64_string(convert_mask_to_image(mask))
initial_response = await sync_op(
cls,
ApiEndpoint(path="/proxy/bfl/v1/flux-tools/erase-v1", method="POST"),
response_model=BFLFluxProGenerateResponse,
data=BFLFluxEraseRequest(
image=tensor_to_base64_string(image[:, :, :, :3]), # make sure image will have alpha channel removed
mask=mask,
dilate_pixels=dilate_pixels,
),
)
def price_extractor(_r: BaseModel) -> float | None:
return None if initial_response.cost is None else initial_response.cost / 100
response = await poll_op(
cls,
ApiEndpoint(initial_response.polling_url),
response_model=BFLFluxStatusResponse,
status_extractor=lambda r: r.status,
progress_extractor=lambda r: r.progress,
price_extractor=price_extractor,
completed_statuses=[BFLStatus.ready],
failed_statuses=[
BFLStatus.request_moderated,
BFLStatus.content_moderated,
BFLStatus.error,
BFLStatus.task_not_found,
],
queued_statuses=[],
)
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
class FluxVTONode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="FluxVTONode",
display_name="Flux Virtual Try-On",
category="partner/image/BFL",
description="Virtual try-on: dresses the person in the provided garment.",
inputs=[
IO.Image.Input("person", tooltip="Image of the person to dress."),
IO.Image.Input("garment", tooltip="Image of the garment to apply."),
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Optional natural-language styling instruction (e.g. how the garment should fit).",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="The random seed used for creating the noise.",
),
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"range_usd","min_usd":0.0375,"max_usd":0.075,"format":{"approximate":true}}""",
),
)
@classmethod
async def execute(
cls,
person: Input.Image,
garment: Input.Image,
prompt: str = "",
seed: int = 0,
) -> IO.NodeOutput:
initial_response = await sync_op(
cls,
ApiEndpoint(path="/proxy/bfl/v1/flux-tools/vto-v1", method="POST"),
response_model=BFLFluxProGenerateResponse,
data=BFLFluxVTORequest(
prompt=prompt,
person=tensor_to_base64_string(person[:, :, :, :3]),
garment=tensor_to_base64_string(garment[:, :, :, :3]),
seed=seed,
),
)
def price_extractor(_r: BaseModel) -> float | None:
return None if initial_response.cost is None else initial_response.cost / 100
response = await poll_op(
cls,
ApiEndpoint(initial_response.polling_url),
response_model=BFLFluxStatusResponse,
status_extractor=lambda r: r.status,
progress_extractor=lambda r: r.progress,
price_extractor=price_extractor,
completed_statuses=[BFLStatus.ready],
failed_statuses=[
BFLStatus.request_moderated,
BFLStatus.content_moderated,
BFLStatus.error,
BFLStatus.task_not_found,
],
queued_statuses=[],
)
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
class Flux2ProImageNode(IO.ComfyNode): class Flux2ProImageNode(IO.ComfyNode):
NODE_ID = "Flux2ProImageNode" NODE_ID = "Flux2ProImageNode"
@ -545,7 +697,7 @@ class Flux2ProImageNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id=cls.NODE_ID, node_id=cls.NODE_ID,
display_name=cls.DISPLAY_NAME, display_name=cls.DISPLAY_NAME,
category="image/partner/BFL", category="partner/image/BFL",
description="Generates images synchronously based on prompt and resolution.", description="Generates images synchronously based on prompt and resolution.",
inputs=[ inputs=[
IO.String.Input( IO.String.Input(
@ -716,7 +868,7 @@ class Flux2ImageNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="Flux2ImageNode", node_id="Flux2ImageNode",
display_name="Flux.2 Image", display_name="Flux.2 Image",
category="image/partner/BFL", category="partner/image/BFL",
description="Generate images via Flux.2 [pro] or Flux.2 [max] from a prompt and optional reference images.", description="Generate images via Flux.2 [pro] or Flux.2 [max] from a prompt and optional reference images.",
inputs=[ inputs=[
IO.String.Input( IO.String.Input(
@ -853,6 +1005,8 @@ class BFLExtension(ComfyExtension):
FluxKontextMaxImageNode, FluxKontextMaxImageNode,
FluxProExpandNode, FluxProExpandNode,
FluxProFillNode, FluxProFillNode,
FluxEraseNode,
FluxVTONode,
Flux2ProImageNode, Flux2ProImageNode,
Flux2MaxImageNode, Flux2MaxImageNode,
Flux2ImageNode, Flux2ImageNode,

View File

@ -31,7 +31,7 @@ class BriaImageEditNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="BriaImageEditNode", node_id="BriaImageEditNode",
display_name="Bria FIBO Image Edit", display_name="Bria FIBO Image Edit",
category="image/partner/Bria", category="partner/image/Bria",
description="Edit images using Bria latest model", description="Edit images using Bria latest model",
inputs=[ inputs=[
IO.Combo.Input("model", options=["FIBO"]), IO.Combo.Input("model", options=["FIBO"]),
@ -169,7 +169,7 @@ class BriaRemoveImageBackground(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="BriaRemoveImageBackground", node_id="BriaRemoveImageBackground",
display_name="Bria Remove Image Background", display_name="Bria Remove Image Background",
category="image/partner/Bria", category="partner/image/Bria",
description="Remove the background from an image using Bria RMBG 2.0.", description="Remove the background from an image using Bria RMBG 2.0.",
inputs=[ inputs=[
IO.Image.Input("image"), IO.Image.Input("image"),
@ -245,7 +245,7 @@ class BriaRemoveVideoBackground(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="BriaRemoveVideoBackground", node_id="BriaRemoveVideoBackground",
display_name="Bria Remove Video Background", display_name="Bria Remove Video Background",
category="video/partner/Bria", category="partner/video/Bria",
description="Remove the background from a video using Bria. ", description="Remove the background from a video using Bria. ",
inputs=[ inputs=[
IO.Video.Input("video"), IO.Video.Input("video"),

View File

@ -368,7 +368,7 @@ class ByteDanceImageNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="ByteDanceImageNode", node_id="ByteDanceImageNode",
display_name="ByteDance Image", display_name="ByteDance Image",
category="image/partner/ByteDance", category="partner/image/ByteDance",
description="Generate images using ByteDance models via api based on prompt", description="Generate images using ByteDance models via api based on prompt",
inputs=[ inputs=[
IO.Combo.Input("model", options=["seedream-3-0-t2i-250415"]), IO.Combo.Input("model", options=["seedream-3-0-t2i-250415"]),
@ -492,7 +492,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="ByteDanceSeedreamNode", node_id="ByteDanceSeedreamNode",
display_name="ByteDance Seedream 4.5 & 5.0", display_name="ByteDance Seedream 4.5 & 5.0",
category="image/partner/ByteDance", category="partner/image/ByteDance",
description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.", description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.",
inputs=[ inputs=[
IO.Combo.Input( IO.Combo.Input(
@ -754,7 +754,7 @@ class ByteDanceSeedreamNodeV2(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="ByteDanceSeedreamNodeV2", node_id="ByteDanceSeedreamNodeV2",
display_name="ByteDance Seedream 4.5 & 5.0", display_name="ByteDance Seedream 4.5 & 5.0",
category="image/partner/ByteDance", category="partner/image/ByteDance",
description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.", description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.",
inputs=[ inputs=[
IO.String.Input( IO.String.Input(
@ -920,7 +920,7 @@ class ByteDanceTextToVideoNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="ByteDanceTextToVideoNode", node_id="ByteDanceTextToVideoNode",
display_name="ByteDance Text to Video", display_name="ByteDance Text to Video",
category="video/partner/ByteDance", category="partner/video/ByteDance",
description="Generate video using ByteDance models via api based on prompt", description="Generate video using ByteDance models via api based on prompt",
inputs=[ inputs=[
IO.Combo.Input( IO.Combo.Input(
@ -1048,7 +1048,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="ByteDanceImageToVideoNode", node_id="ByteDanceImageToVideoNode",
display_name="ByteDance Image to Video", display_name="ByteDance Image to Video",
category="video/partner/ByteDance", category="partner/video/ByteDance",
description="Generate video using ByteDance models via api based on image and prompt", description="Generate video using ByteDance models via api based on image and prompt",
inputs=[ inputs=[
IO.Combo.Input( IO.Combo.Input(
@ -1185,7 +1185,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="ByteDanceFirstLastFrameNode", node_id="ByteDanceFirstLastFrameNode",
display_name="ByteDance First-Last-Frame to Video", display_name="ByteDance First-Last-Frame to Video",
category="video/partner/ByteDance", category="partner/video/ByteDance",
description="Generate video using prompt and first and last frames.", description="Generate video using prompt and first and last frames.",
inputs=[ inputs=[
IO.Combo.Input( IO.Combo.Input(
@ -1333,7 +1333,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="ByteDanceImageReferenceNode", node_id="ByteDanceImageReferenceNode",
display_name="ByteDance Reference Images to Video", display_name="ByteDance Reference Images to Video",
category="video/partner/ByteDance", category="partner/video/ByteDance",
description="Generate video using prompt and reference images.", description="Generate video using prompt and reference images.",
inputs=[ inputs=[
IO.Combo.Input( IO.Combo.Input(
@ -1576,7 +1576,7 @@ class ByteDance2TextToVideoNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="ByteDance2TextToVideoNode", node_id="ByteDance2TextToVideoNode",
display_name="ByteDance Seedance 2.0 Text to Video", display_name="ByteDance Seedance 2.0 Text to Video",
category="video/partner/ByteDance", category="partner/video/ByteDance",
description="Generate video using Seedance 2.0 models based on a text prompt.", description="Generate video using Seedance 2.0 models based on a text prompt.",
inputs=[ inputs=[
IO.DynamicCombo.Input( IO.DynamicCombo.Input(
@ -1677,7 +1677,7 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="ByteDance2FirstLastFrameNode", node_id="ByteDance2FirstLastFrameNode",
display_name="ByteDance Seedance 2.0 First-Last-Frame to Video", display_name="ByteDance Seedance 2.0 First-Last-Frame to Video",
category="video/partner/ByteDance", category="partner/video/ByteDance",
description="Generate video using Seedance 2.0 from a first frame image and optional last frame image.", description="Generate video using Seedance 2.0 from a first frame image and optional last frame image.",
inputs=[ inputs=[
IO.DynamicCombo.Input( IO.DynamicCombo.Input(
@ -1944,7 +1944,7 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="ByteDance2ReferenceNode", node_id="ByteDance2ReferenceNode",
display_name="ByteDance Seedance 2.0 Reference to Video", display_name="ByteDance Seedance 2.0 Reference to Video",
category="video/partner/ByteDance", category="partner/video/ByteDance",
description="Generate, edit, or extend video using Seedance 2.0 with reference images, " description="Generate, edit, or extend video using Seedance 2.0 with reference images, "
"videos, and audio. Supports multimodal reference, video editing, and video extension.", "videos, and audio. Supports multimodal reference, video editing, and video extension.",
inputs=[ inputs=[
@ -2241,7 +2241,7 @@ class ByteDanceCreateImageAsset(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="ByteDanceCreateImageAsset", node_id="ByteDanceCreateImageAsset",
display_name="ByteDance Create Image Asset", display_name="ByteDance Create Image Asset",
category="image/partner/ByteDance", category="partner/image/ByteDance",
description=( description=(
"Create a Seedance 2.0 personal image asset. Uploads the input image and " "Create a Seedance 2.0 personal image asset. Uploads the input image and "
"registers it in the given asset group. If group_id is empty, runs a real-person " "registers it in the given asset group. If group_id is empty, runs a real-person "
@ -2308,7 +2308,7 @@ class ByteDanceCreateVideoAsset(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="ByteDanceCreateVideoAsset", node_id="ByteDanceCreateVideoAsset",
display_name="ByteDance Create Video Asset", display_name="ByteDance Create Video Asset",
category="video/partner/ByteDance", category="partner/video/ByteDance",
description=( description=(
"Create a Seedance 2.0 personal video asset. Uploads the input video and " "Create a Seedance 2.0 personal video asset. Uploads the input video and "
"registers it in the given asset group. If group_id is empty, runs a real-person " "registers it in the given asset group. If group_id is empty, runs a real-person "

View File

@ -144,7 +144,7 @@ class ByteDanceSeedNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="ByteDanceSeedNode", node_id="ByteDanceSeedNode",
display_name="ByteDance Seed", display_name="ByteDance Seed",
category="text/partner/ByteDance", category="partner/text/ByteDance",
essentials_category="Text Generation", essentials_category="Text Generation",
description="Generate text responses with ByteDance's Seed 2.0 models. " description="Generate text responses with ByteDance's Seed 2.0 models. "
"Provide a text prompt and optionally one or more images or videos for multimodal context.", "Provide a text prompt and optionally one or more images or videos for multimodal context.",

View File

@ -69,7 +69,7 @@ class ElevenLabsSpeechToText(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="ElevenLabsSpeechToText", node_id="ElevenLabsSpeechToText",
display_name="ElevenLabs Speech to Text", display_name="ElevenLabs Speech to Text",
category="audio/partner/ElevenLabs", category="partner/audio/ElevenLabs",
description="Transcribe audio to text. " description="Transcribe audio to text. "
"Supports automatic language detection, speaker diarization, and audio event tagging.", "Supports automatic language detection, speaker diarization, and audio event tagging.",
inputs=[ inputs=[
@ -210,7 +210,7 @@ class ElevenLabsVoiceSelector(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="ElevenLabsVoiceSelector", node_id="ElevenLabsVoiceSelector",
display_name="ElevenLabs Voice Selector", display_name="ElevenLabs Voice Selector",
category="audio/partner/ElevenLabs", category="partner/audio/ElevenLabs",
description="Select a predefined ElevenLabs voice for text-to-speech generation.", description="Select a predefined ElevenLabs voice for text-to-speech generation.",
inputs=[ inputs=[
IO.Combo.Input( IO.Combo.Input(
@ -239,7 +239,7 @@ class ElevenLabsTextToSpeech(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="ElevenLabsTextToSpeech", node_id="ElevenLabsTextToSpeech",
display_name="ElevenLabs Text to Speech", display_name="ElevenLabs Text to Speech",
category="audio/partner/ElevenLabs", category="partner/audio/ElevenLabs",
description="Convert text to speech.", description="Convert text to speech.",
inputs=[ inputs=[
IO.Custom(ELEVENLABS_VOICE).Input( IO.Custom(ELEVENLABS_VOICE).Input(
@ -414,7 +414,7 @@ class ElevenLabsAudioIsolation(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="ElevenLabsAudioIsolation", node_id="ElevenLabsAudioIsolation",
display_name="ElevenLabs Voice Isolation", display_name="ElevenLabs Voice Isolation",
category="audio/partner/ElevenLabs", category="partner/audio/ElevenLabs",
description="Remove background noise from audio, isolating vocals or speech.", description="Remove background noise from audio, isolating vocals or speech.",
inputs=[ inputs=[
IO.Audio.Input( IO.Audio.Input(
@ -459,7 +459,7 @@ class ElevenLabsTextToSoundEffects(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="ElevenLabsTextToSoundEffects", node_id="ElevenLabsTextToSoundEffects",
display_name="ElevenLabs Text to Sound Effects", display_name="ElevenLabs Text to Sound Effects",
category="audio/partner/ElevenLabs", category="partner/audio/ElevenLabs",
description="Generate sound effects from text descriptions.", description="Generate sound effects from text descriptions.",
inputs=[ inputs=[
IO.String.Input( IO.String.Input(
@ -555,7 +555,7 @@ class ElevenLabsInstantVoiceClone(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="ElevenLabsInstantVoiceClone", node_id="ElevenLabsInstantVoiceClone",
display_name="ElevenLabs Instant Voice Clone", display_name="ElevenLabs Instant Voice Clone",
category="audio/partner/ElevenLabs", category="partner/audio/ElevenLabs",
description="Create a cloned voice from audio samples. " description="Create a cloned voice from audio samples. "
"Provide 1-8 audio recordings of the voice to clone.", "Provide 1-8 audio recordings of the voice to clone.",
inputs=[ inputs=[
@ -658,7 +658,7 @@ class ElevenLabsSpeechToSpeech(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="ElevenLabsSpeechToSpeech", node_id="ElevenLabsSpeechToSpeech",
display_name="ElevenLabs Speech to Speech", display_name="ElevenLabs Speech to Speech",
category="audio/partner/ElevenLabs", category="partner/audio/ElevenLabs",
description="Transform speech from one voice to another while preserving the original content and emotion.", description="Transform speech from one voice to another while preserving the original content and emotion.",
inputs=[ inputs=[
IO.Custom(ELEVENLABS_VOICE).Input( IO.Custom(ELEVENLABS_VOICE).Input(
@ -793,7 +793,7 @@ class ElevenLabsTextToDialogue(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="ElevenLabsTextToDialogue", node_id="ElevenLabsTextToDialogue",
display_name="ElevenLabs Text to Dialogue", display_name="ElevenLabs Text to Dialogue",
category="audio/partner/ElevenLabs", category="partner/audio/ElevenLabs",
description="Generate multi-speaker dialogue from text. Each dialogue entry has its own text and voice.", description="Generate multi-speaker dialogue from text. Each dialogue entry has its own text and voice.",
inputs=[ inputs=[
IO.Float.Input( IO.Float.Input(

View File

@ -300,7 +300,7 @@ class GeminiNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="GeminiNode", node_id="GeminiNode",
display_name="Google Gemini", display_name="Google Gemini",
category="text/partner/Gemini", category="partner/text/Gemini",
description="Generate text responses with Google's Gemini AI model. " description="Generate text responses with Google's Gemini AI model. "
"You can provide multiple types of inputs (text, images, audio, video) " "You can provide multiple types of inputs (text, images, audio, video) "
"as context for generating more relevant and meaningful responses.", "as context for generating more relevant and meaningful responses.",
@ -541,7 +541,7 @@ class GeminiInputFiles(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="GeminiInputFiles", node_id="GeminiInputFiles",
display_name="Gemini Input Files", display_name="Gemini Input Files",
category="text/partner/Gemini", category="partner/text/Gemini",
description="Loads and prepares input files to include as inputs for Gemini LLM nodes. " description="Loads and prepares input files to include as inputs for Gemini LLM nodes. "
"The files will be read by the Gemini model when generating a response. " "The files will be read by the Gemini model when generating a response. "
"The contents of the text file count toward the token limit. " "The contents of the text file count toward the token limit. "
@ -598,7 +598,7 @@ class GeminiImage(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="GeminiImageNode", node_id="GeminiImageNode",
display_name="Nano Banana (Google Gemini Image)", display_name="Nano Banana (Google Gemini Image)",
category="image/partner/Gemini", category="partner/image/Gemini",
description="Edit images synchronously via Google API.", description="Edit images synchronously via Google API.",
inputs=[ inputs=[
IO.String.Input( IO.String.Input(
@ -731,7 +731,7 @@ class GeminiImage2(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="GeminiImage2Node", node_id="GeminiImage2Node",
display_name="Nano Banana Pro (Google Gemini Image)", display_name="Nano Banana Pro (Google Gemini Image)",
category="image/partner/Gemini", category="partner/image/Gemini",
description="Generate or edit images synchronously via Google Vertex API.", description="Generate or edit images synchronously via Google Vertex API.",
inputs=[ inputs=[
IO.String.Input( IO.String.Input(
@ -869,7 +869,7 @@ class GeminiNanoBanana2(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="GeminiNanoBanana2", node_id="GeminiNanoBanana2",
display_name="Nano Banana 2", display_name="Nano Banana 2",
category="image/partner/Gemini", category="partner/image/Gemini",
description="Generate or edit images synchronously via Google Vertex API.", description="Generate or edit images synchronously via Google Vertex API.",
inputs=[ inputs=[
IO.String.Input( IO.String.Input(
@ -1085,7 +1085,7 @@ class GeminiNanoBanana2V2(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="GeminiNanoBanana2V2", node_id="GeminiNanoBanana2V2",
display_name="Nano Banana 2", display_name="Nano Banana 2",
category="image/partner/Gemini", category="partner/image/Gemini",
description="Generate or edit images synchronously via Google Vertex API.", description="Generate or edit images synchronously via Google Vertex API.",
inputs=[ inputs=[
IO.String.Input( IO.String.Input(

View File

@ -29,6 +29,11 @@ from comfy_api_nodes.util import (
) )
_GROK_VIDEO_MODEL_API_IDS = {
"grok-imagine-video-1.5": "grok-imagine-video-1.5-preview",
}
def _extract_grok_price(response) -> float | None: def _extract_grok_price(response) -> float | None:
if response.usage and response.usage.cost_in_usd_ticks is not None: if response.usage and response.usage.cost_in_usd_ticks is not None:
return response.usage.cost_in_usd_ticks / 10_000_000_000 return response.usage.cost_in_usd_ticks / 10_000_000_000
@ -49,7 +54,7 @@ class GrokImageNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="GrokImageNode", node_id="GrokImageNode",
display_name="Grok Image", display_name="Grok Image",
category="image/partner/Grok", category="partner/image/Grok",
description="Generate images using Grok based on a text prompt", description="Generate images using Grok based on a text prompt",
inputs=[ inputs=[
IO.Combo.Input( IO.Combo.Input(
@ -223,7 +228,7 @@ class GrokImageEditNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="GrokImageEditNode", node_id="GrokImageEditNode",
display_name="Grok Image Edit", display_name="Grok Image Edit",
category="image/partner/Grok", category="partner/image/Grok",
description="Modify an existing image based on a text prompt", description="Modify an existing image based on a text prompt",
inputs=[ inputs=[
IO.Combo.Input( IO.Combo.Input(
@ -364,7 +369,7 @@ class GrokImageEditNodeV2(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="GrokImageEditNodeV2", node_id="GrokImageEditNodeV2",
display_name="Grok Image Edit", display_name="Grok Image Edit",
category="image/partner/Grok", category="partner/image/Grok",
description="Modify an existing image based on a text prompt", description="Modify an existing image based on a text prompt",
inputs=[ inputs=[
IO.String.Input( IO.String.Input(
@ -501,10 +506,14 @@ class GrokVideoNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="GrokVideoNode", node_id="GrokVideoNode",
display_name="Grok Video", display_name="Grok Video",
category="video/partner/Grok", category="partner/video/Grok",
description="Generate video from a prompt or an image", description="Generate video from a prompt or an image",
inputs=[ inputs=[
IO.Combo.Input("model", options=["grok-imagine-video"]), IO.Combo.Input(
"model",
options=["grok-imagine-video", "grok-imagine-video-1.5"],
tooltip="grok-imagine-video-1.5 currently always requires an input image.",
),
IO.String.Input( IO.String.Input(
"prompt", "prompt",
multiline=True, multiline=True,
@ -540,7 +549,11 @@ class GrokVideoNode(IO.ComfyNode):
tooltip="Seed to determine if node should re-run; " tooltip="Seed to determine if node should re-run; "
"actual results are nondeterministic regardless of seed.", "actual results are nondeterministic regardless of seed.",
), ),
IO.Image.Input("image", optional=True), IO.Image.Input(
"image",
optional=True,
tooltip="Optional starting image for grok-imagine-video. Required for grok-imagine-video-1.5.",
),
], ],
outputs=[ outputs=[
IO.Video.Output(), IO.Video.Output(),
@ -552,12 +565,16 @@ class GrokVideoNode(IO.ComfyNode):
], ],
is_api_node=True, is_api_node=True,
price_badge=IO.PriceBadge( price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"], inputs=["image"]), depends_on=IO.PriceBadgeDepends(widgets=["model", "duration", "resolution"], inputs=["image"]),
expr=""" expr="""
( (
$rate := widgets.resolution = "720p" ? 0.07 : 0.05; $is15 := $contains(widgets.model, "1.5");
$rate := $is15
? (widgets.resolution = "720p" ? 0.2002 : 0.1144)
: (widgets.resolution = "720p" ? 0.07 : 0.05);
$imgCost := $is15 ? 0.0143 : 0.002;
$base := $rate * widgets.duration; $base := $rate * widgets.duration;
{"type":"usd","usd": inputs.image.connected ? $base + 0.002 : $base} {"type":"usd","usd": inputs.image.connected ? $base + $imgCost : $base}
) )
""", """,
), ),
@ -574,6 +591,8 @@ class GrokVideoNode(IO.ComfyNode):
seed: int, seed: int,
image: Input.Image | None = None, image: Input.Image | None = None,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
if image is None and model == "grok-imagine-video-1.5":
raise ValueError(f"The '{model}' model requires an input image; connect one to the 'image' input.")
image_url = None image_url = None
if image is not None: if image is not None:
if get_number_of_images(image) != 1: if get_number_of_images(image) != 1:
@ -584,7 +603,7 @@ class GrokVideoNode(IO.ComfyNode):
cls, cls,
ApiEndpoint(path="/proxy/xai/v1/videos/generations", method="POST"), ApiEndpoint(path="/proxy/xai/v1/videos/generations", method="POST"),
data=VideoGenerationRequest( data=VideoGenerationRequest(
model=model, model=_GROK_VIDEO_MODEL_API_IDS.get(model, model),
image=image_url, image=image_url,
prompt=prompt, prompt=prompt,
resolution=resolution, resolution=resolution,
@ -599,7 +618,7 @@ class GrokVideoNode(IO.ComfyNode):
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"), ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
status_extractor=lambda r: r.status if r.status is not None else "complete", status_extractor=lambda r: r.status if r.status is not None else "complete",
response_model=VideoStatusResponse, response_model=VideoStatusResponse,
price_extractor=_extract_grok_price, price_extractor=_extract_grok_video_price if model == "grok-imagine-video-1.5" else _extract_grok_price,
) )
return IO.NodeOutput(await download_url_to_video_output(response.video.url)) return IO.NodeOutput(await download_url_to_video_output(response.video.url))
@ -611,7 +630,7 @@ class GrokVideoEditNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="GrokVideoEditNode", node_id="GrokVideoEditNode",
display_name="Grok Video Edit", display_name="Grok Video Edit",
category="video/partner/Grok", category="partner/video/Grok",
description="Edit an existing video based on a text prompt.", description="Edit an existing video based on a text prompt.",
inputs=[ inputs=[
IO.Combo.Input("model", options=["grok-imagine-video"]), IO.Combo.Input("model", options=["grok-imagine-video"]),
@ -689,7 +708,7 @@ class GrokVideoReferenceNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="GrokVideoReferenceNode", node_id="GrokVideoReferenceNode",
display_name="Grok Reference-to-Video", display_name="Grok Reference-to-Video",
category="video/partner/Grok", category="partner/video/Grok",
description="Generate video guided by reference images as style and content references.", description="Generate video guided by reference images as style and content references.",
inputs=[ inputs=[
IO.String.Input( IO.String.Input(
@ -822,7 +841,7 @@ class GrokVideoExtendNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="GrokVideoExtendNode", node_id="GrokVideoExtendNode",
display_name="Grok Video Extend", display_name="Grok Video Extend",
category="video/partner/Grok", category="partner/video/Grok",
description="Extend an existing video with a seamless continuation based on a text prompt.", description="Extend an existing video with a seamless continuation based on a text prompt.",
inputs=[ inputs=[
IO.String.Input( IO.String.Input(

View File

@ -71,7 +71,7 @@ class HitPawGeneralImageEnhance(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="HitPawGeneralImageEnhance", node_id="HitPawGeneralImageEnhance",
display_name="HitPaw General Image Enhance", display_name="HitPaw General Image Enhance",
category="image/partner/HitPaw", category="partner/image/HitPaw",
description="Upscale low-resolution images to super-resolution, eliminate artifacts and noise. " description="Upscale low-resolution images to super-resolution, eliminate artifacts and noise. "
f"Maximum output: {MAX_MP_GENERATIVE} megapixels.", f"Maximum output: {MAX_MP_GENERATIVE} megapixels.",
inputs=[ inputs=[
@ -201,7 +201,7 @@ class HitPawVideoEnhance(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="HitPawVideoEnhance", node_id="HitPawVideoEnhance",
display_name="HitPaw Video Enhance", display_name="HitPaw Video Enhance",
category="video/partner/HitPaw", category="partner/video/HitPaw",
description="Upscale low-resolution videos to high resolution, eliminate artifacts and noise. " description="Upscale low-resolution videos to high resolution, eliminate artifacts and noise. "
"Prices shown are per second of video.", "Prices shown are per second of video.",
inputs=[ inputs=[

View File

@ -123,7 +123,7 @@ class TencentTextToModelNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="TencentTextToModelNode", node_id="TencentTextToModelNode",
display_name="Hunyuan3D: Text to Model", display_name="Hunyuan3D: Text to Model",
category="3d/partner/Tencent", category="partner/3d/Tencent",
essentials_category="3D", essentials_category="3D",
inputs=[ inputs=[
IO.Combo.Input( IO.Combo.Input(
@ -242,7 +242,7 @@ class TencentImageToModelNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="TencentImageToModelNode", node_id="TencentImageToModelNode",
display_name="Hunyuan3D: Image(s) to Model", display_name="Hunyuan3D: Image(s) to Model",
category="3d/partner/Tencent", category="partner/3d/Tencent",
essentials_category="3D", essentials_category="3D",
inputs=[ inputs=[
IO.Combo.Input( IO.Combo.Input(
@ -415,7 +415,7 @@ class TencentModelTo3DUVNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="TencentModelTo3DUVNode", node_id="TencentModelTo3DUVNode",
display_name="Hunyuan3D: Model to UV", display_name="Hunyuan3D: Model to UV",
category="3d/partner/Tencent", category="partner/3d/Tencent",
description="Perform UV unfolding on a 3D model to generate UV texture. " description="Perform UV unfolding on a 3D model to generate UV texture. "
"Input model must have less than 30000 faces.", "Input model must have less than 30000 faces.",
inputs=[ inputs=[
@ -505,7 +505,7 @@ class Tencent3DTextureEditNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="Tencent3DTextureEditNode", node_id="Tencent3DTextureEditNode",
display_name="Hunyuan3D: 3D Texture Edit", display_name="Hunyuan3D: 3D Texture Edit",
category="3d/partner/Tencent", category="partner/3d/Tencent",
description="After inputting the 3D model, perform 3D model texture redrawing.", description="After inputting the 3D model, perform 3D model texture redrawing.",
inputs=[ inputs=[
IO.MultiType.Input( IO.MultiType.Input(
@ -594,7 +594,7 @@ class Tencent3DPartNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="Tencent3DPartNode", node_id="Tencent3DPartNode",
display_name="Hunyuan3D: 3D Part", display_name="Hunyuan3D: 3D Part",
category="3d/partner/Tencent", category="partner/3d/Tencent",
description="Automatically perform component identification and generation based on the model structure.", description="Automatically perform component identification and generation based on the model structure.",
inputs=[ inputs=[
IO.MultiType.Input( IO.MultiType.Input(
@ -666,7 +666,7 @@ class TencentSmartTopologyNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="TencentSmartTopologyNode", node_id="TencentSmartTopologyNode",
display_name="Hunyuan3D: Smart Topology", display_name="Hunyuan3D: Smart Topology",
category="3d/partner/Tencent", category="partner/3d/Tencent",
description="Perform smart retopology on a 3D model. " description="Perform smart retopology on a 3D model. "
"Supports GLB/OBJ formats; max 200MB; recommended for high-poly models.", "Supports GLB/OBJ formats; max 200MB; recommended for high-poly models.",
inputs=[ inputs=[

View File

@ -234,7 +234,7 @@ class IdeogramV1(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="IdeogramV1", node_id="IdeogramV1",
display_name="Ideogram V1", display_name="Ideogram V1",
category="image/partner/Ideogram", category="partner/image/Ideogram",
description="Generates images using the Ideogram V1 model.", description="Generates images using the Ideogram V1 model.",
inputs=[ inputs=[
IO.String.Input( IO.String.Input(
@ -360,7 +360,7 @@ class IdeogramV2(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="IdeogramV2", node_id="IdeogramV2",
display_name="Ideogram V2", display_name="Ideogram V2",
category="image/partner/Ideogram", category="partner/image/Ideogram",
description="Generates images using the Ideogram V2 model.", description="Generates images using the Ideogram V2 model.",
inputs=[ inputs=[
IO.String.Input( IO.String.Input(
@ -526,7 +526,7 @@ class IdeogramV3(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="IdeogramV3", node_id="IdeogramV3",
display_name="Ideogram V3", display_name="Ideogram V3",
category="image/partner/Ideogram", category="partner/image/Ideogram",
description="Generates images using the Ideogram V3 model. " description="Generates images using the Ideogram V3 model. "
"Supports both regular image generation from text prompts and image editing with mask.", "Supports both regular image generation from text prompts and image editing with mask.",
inputs=[ inputs=[

View File

@ -642,7 +642,7 @@ class KlingCameraControls(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="KlingCameraControls", node_id="KlingCameraControls",
display_name="Kling Camera Controls", display_name="Kling Camera Controls",
category="video/partner/Kling", category="partner/video/Kling",
description="Allows specifying configuration options for Kling Camera Controls and motion control effects.", description="Allows specifying configuration options for Kling Camera Controls and motion control effects.",
inputs=[ inputs=[
IO.Combo.Input("camera_control_type", options=KlingCameraControlType), IO.Combo.Input("camera_control_type", options=KlingCameraControlType),
@ -762,7 +762,7 @@ class KlingTextToVideoNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="KlingTextToVideoNode", node_id="KlingTextToVideoNode",
display_name="Kling Text to Video", display_name="Kling Text to Video",
category="video/partner/Kling", category="partner/video/Kling",
description="Kling Text to Video Node", description="Kling Text to Video Node",
inputs=[ inputs=[
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"), IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"),
@ -849,7 +849,7 @@ class OmniProTextToVideoNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="KlingOmniProTextToVideoNode", node_id="KlingOmniProTextToVideoNode",
display_name="Kling 3.0 Omni Text to Video", display_name="Kling 3.0 Omni Text to Video",
category="video/partner/Kling", category="partner/video/Kling",
description="Use text prompts to generate videos with the latest Kling model.", description="Use text prompts to generate videos with the latest Kling model.",
inputs=[ inputs=[
IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]), IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),
@ -998,7 +998,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="KlingOmniProFirstLastFrameNode", node_id="KlingOmniProFirstLastFrameNode",
display_name="Kling 3.0 Omni First-Last-Frame to Video", display_name="Kling 3.0 Omni First-Last-Frame to Video",
category="video/partner/Kling", category="partner/video/Kling",
description="Use a start frame, an optional end frame, or reference images with the latest Kling model.", description="Use a start frame, an optional end frame, or reference images with the latest Kling model.",
inputs=[ inputs=[
IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]), IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),
@ -1205,7 +1205,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="KlingOmniProImageToVideoNode", node_id="KlingOmniProImageToVideoNode",
display_name="Kling 3.0 Omni Image to Video", display_name="Kling 3.0 Omni Image to Video",
category="video/partner/Kling", category="partner/video/Kling",
description="Use up to 7 reference images to generate a video with the latest Kling model.", description="Use up to 7 reference images to generate a video with the latest Kling model.",
inputs=[ inputs=[
IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]), IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),
@ -1374,7 +1374,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="KlingOmniProVideoToVideoNode", node_id="KlingOmniProVideoToVideoNode",
display_name="Kling 3.0 Omni Video to Video", display_name="Kling 3.0 Omni Video to Video",
category="video/partner/Kling", category="partner/video/Kling",
description="Use a video and up to 4 reference images to generate a video with the latest Kling model.", description="Use a video and up to 4 reference images to generate a video with the latest Kling model.",
inputs=[ inputs=[
IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]), IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),
@ -1485,7 +1485,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="KlingOmniProEditVideoNode", node_id="KlingOmniProEditVideoNode",
display_name="Kling 3.0 Omni Edit Video", display_name="Kling 3.0 Omni Edit Video",
category="video/partner/Kling", category="partner/video/Kling",
essentials_category="Video Generation", essentials_category="Video Generation",
description="Edit an existing video with the latest model from Kling.", description="Edit an existing video with the latest model from Kling.",
inputs=[ inputs=[
@ -1593,7 +1593,7 @@ class OmniProImageNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="KlingOmniProImageNode", node_id="KlingOmniProImageNode",
display_name="Kling 3.0 Omni Image", display_name="Kling 3.0 Omni Image",
category="image/partner/Kling", category="partner/image/Kling",
description="Create or edit images with the latest model from Kling.", description="Create or edit images with the latest model from Kling.",
inputs=[ inputs=[
IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-image-o1"]), IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-image-o1"]),
@ -1721,7 +1721,7 @@ class KlingCameraControlT2VNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="KlingCameraControlT2VNode", node_id="KlingCameraControlT2VNode",
display_name="Kling Text to Video (Camera Control)", display_name="Kling Text to Video (Camera Control)",
category="video/partner/Kling", category="partner/video/Kling",
description="Transform text into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original text.", description="Transform text into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original text.",
inputs=[ inputs=[
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"), IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"),
@ -1783,7 +1783,7 @@ class KlingImage2VideoNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="KlingImage2VideoNode", node_id="KlingImage2VideoNode",
display_name="Kling Image(First Frame) to Video", display_name="Kling Image(First Frame) to Video",
category="video/partner/Kling", category="partner/video/Kling",
inputs=[ inputs=[
IO.Image.Input("start_frame", tooltip="The reference image used to generate the video."), IO.Image.Input("start_frame", tooltip="The reference image used to generate the video."),
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"), IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"),
@ -1882,7 +1882,7 @@ class KlingCameraControlI2VNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="KlingCameraControlI2VNode", node_id="KlingCameraControlI2VNode",
display_name="Kling Image to Video (Camera Control)", display_name="Kling Image to Video (Camera Control)",
category="video/partner/Kling", category="partner/video/Kling",
description="Transform still images into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original image.", description="Transform still images into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original image.",
inputs=[ inputs=[
IO.Image.Input( IO.Image.Input(
@ -1953,7 +1953,7 @@ class KlingStartEndFrameNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="KlingStartEndFrameNode", node_id="KlingStartEndFrameNode",
display_name="Kling Start-End Frame to Video", display_name="Kling Start-End Frame to Video",
category="video/partner/Kling", category="partner/video/Kling",
description="Generate a video sequence that transitions between your provided start and end images. The node creates all frames in between, producing a smooth transformation from the first frame to the last.", description="Generate a video sequence that transitions between your provided start and end images. The node creates all frames in between, producing a smooth transformation from the first frame to the last.",
inputs=[ inputs=[
IO.Image.Input( IO.Image.Input(
@ -2047,7 +2047,7 @@ class KlingVideoExtendNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="KlingVideoExtendNode", node_id="KlingVideoExtendNode",
display_name="Kling Video Extend", display_name="Kling Video Extend",
category="video/partner/Kling", category="partner/video/Kling",
description="Kling Video Extend Node. Extend videos made by other Kling nodes. The video_id is created by using other Kling Nodes.", description="Kling Video Extend Node. Extend videos made by other Kling nodes. The video_id is created by using other Kling Nodes.",
inputs=[ inputs=[
IO.String.Input( IO.String.Input(
@ -2128,7 +2128,7 @@ class KlingDualCharacterVideoEffectNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="KlingDualCharacterVideoEffectNode", node_id="KlingDualCharacterVideoEffectNode",
display_name="Kling Dual Character Video Effects", display_name="Kling Dual Character Video Effects",
category="video/partner/Kling", category="partner/video/Kling",
description="Achieve different special effects when generating a video based on the effect_scene. First image will be positioned on left side, second on right side of the composite.", description="Achieve different special effects when generating a video based on the effect_scene. First image will be positioned on left side, second on right side of the composite.",
inputs=[ inputs=[
IO.Image.Input("image_left", tooltip="Left side image"), IO.Image.Input("image_left", tooltip="Left side image"),
@ -2218,7 +2218,7 @@ class KlingSingleImageVideoEffectNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="KlingSingleImageVideoEffectNode", node_id="KlingSingleImageVideoEffectNode",
display_name="Kling Video Effects", display_name="Kling Video Effects",
category="video/partner/Kling", category="partner/video/Kling",
description="Achieve different special effects when generating a video based on the effect_scene.", description="Achieve different special effects when generating a video based on the effect_scene.",
inputs=[ inputs=[
IO.Image.Input( IO.Image.Input(
@ -2291,7 +2291,7 @@ class KlingLipSyncAudioToVideoNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="KlingLipSyncAudioToVideoNode", node_id="KlingLipSyncAudioToVideoNode",
display_name="Kling Lip Sync Video with Audio", display_name="Kling Lip Sync Video with Audio",
category="video/partner/Kling", category="partner/video/Kling",
essentials_category="Video Generation", essentials_category="Video Generation",
description="Kling Lip Sync Audio to Video Node. Syncs mouth movements in a video file to the audio content of an audio file. When using, ensure that the audio contains clearly distinguishable vocals and that the video contains a distinct face. The audio file should not be larger than 5MB. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length.", description="Kling Lip Sync Audio to Video Node. Syncs mouth movements in a video file to the audio content of an audio file. When using, ensure that the audio contains clearly distinguishable vocals and that the video contains a distinct face. The audio file should not be larger than 5MB. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length.",
inputs=[ inputs=[
@ -2343,7 +2343,7 @@ class KlingLipSyncTextToVideoNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="KlingLipSyncTextToVideoNode", node_id="KlingLipSyncTextToVideoNode",
display_name="Kling Lip Sync Video with Text", display_name="Kling Lip Sync Video with Text",
category="video/partner/Kling", category="partner/video/Kling",
description="Kling Lip Sync Text to Video Node. Syncs mouth movements in a video file to a text prompt. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length.", description="Kling Lip Sync Text to Video Node. Syncs mouth movements in a video file to a text prompt. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length.",
inputs=[ inputs=[
IO.Video.Input("video"), IO.Video.Input("video"),
@ -2411,7 +2411,7 @@ class KlingVirtualTryOnNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="KlingVirtualTryOnNode", node_id="KlingVirtualTryOnNode",
display_name="Kling Virtual Try On", display_name="Kling Virtual Try On",
category="image/partner/Kling", category="partner/image/Kling",
description="Kling Virtual Try On Node. Input a human image and a cloth image to try on the cloth on the human. You can merge multiple clothing item pictures into one image with a white background.", description="Kling Virtual Try On Node. Input a human image and a cloth image to try on the cloth on the human. You can merge multiple clothing item pictures into one image with a white background.",
inputs=[ inputs=[
IO.Image.Input("human_image"), IO.Image.Input("human_image"),
@ -2478,7 +2478,7 @@ class KlingImageGenerationNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="KlingImageGenerationNode", node_id="KlingImageGenerationNode",
display_name="Kling 3.0 Image", display_name="Kling 3.0 Image",
category="image/partner/Kling", category="partner/image/Kling",
description="Kling Image Generation Node. Generate an image from a text prompt with an optional reference image.", description="Kling Image Generation Node. Generate an image from a text prompt with an optional reference image.",
inputs=[ inputs=[
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"), IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"),
@ -2615,7 +2615,7 @@ class TextToVideoWithAudio(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="KlingTextToVideoWithAudio", node_id="KlingTextToVideoWithAudio",
display_name="Kling 2.6 Text to Video with Audio", display_name="Kling 2.6 Text to Video with Audio",
category="video/partner/Kling", category="partner/video/Kling",
inputs=[ inputs=[
IO.Combo.Input("model_name", options=["kling-v2-6"]), IO.Combo.Input("model_name", options=["kling-v2-6"]),
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt."), IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt."),
@ -2683,7 +2683,7 @@ class ImageToVideoWithAudio(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="KlingImageToVideoWithAudio", node_id="KlingImageToVideoWithAudio",
display_name="Kling 2.6 Image(First Frame) to Video with Audio", display_name="Kling 2.6 Image(First Frame) to Video with Audio",
category="video/partner/Kling", category="partner/video/Kling",
inputs=[ inputs=[
IO.Combo.Input("model_name", options=["kling-v2-6"]), IO.Combo.Input("model_name", options=["kling-v2-6"]),
IO.Image.Input("start_frame"), IO.Image.Input("start_frame"),
@ -2753,7 +2753,7 @@ class MotionControl(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="KlingMotionControl", node_id="KlingMotionControl",
display_name="Kling Motion Control", display_name="Kling Motion Control",
category="video/partner/Kling", category="partner/video/Kling",
inputs=[ inputs=[
IO.String.Input("prompt", multiline=True), IO.String.Input("prompt", multiline=True),
IO.Image.Input("reference_image"), IO.Image.Input("reference_image"),
@ -2854,7 +2854,7 @@ class KlingVideoNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="KlingVideoNode", node_id="KlingVideoNode",
display_name="Kling 3.0 Video", display_name="Kling 3.0 Video",
category="video/partner/Kling", category="partner/video/Kling",
description="Generate videos with Kling V3. " description="Generate videos with Kling V3. "
"Supports text-to-video and image-to-video with optional storyboard multi-prompt and audio generation.", "Supports text-to-video and image-to-video with optional storyboard multi-prompt and audio generation.",
inputs=[ inputs=[
@ -3077,7 +3077,7 @@ class KlingFirstLastFrameNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="KlingFirstLastFrameNode", node_id="KlingFirstLastFrameNode",
display_name="Kling 3.0 First-Last-Frame to Video", display_name="Kling 3.0 First-Last-Frame to Video",
category="video/partner/Kling", category="partner/video/Kling",
description="Generate videos with Kling V3 using first and last frames.", description="Generate videos with Kling V3 using first and last frames.",
inputs=[ inputs=[
IO.String.Input("prompt", multiline=True, default=""), IO.String.Input("prompt", multiline=True, default=""),
@ -3202,7 +3202,7 @@ class KlingAvatarNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="KlingAvatarNode", node_id="KlingAvatarNode",
display_name="Kling Avatar 2.0", display_name="Kling Avatar 2.0",
category="video/partner/Kling", category="partner/video/Kling",
description="Generate broadcast-style digital human videos from a single photo and an audio file.", description="Generate broadcast-style digital human videos from a single photo and an audio file.",
inputs=[ inputs=[
IO.Image.Input( IO.Image.Input(

View File

@ -106,7 +106,7 @@ class Krea2ImageNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="Krea2ImageNode", node_id="Krea2ImageNode",
display_name="Krea 2 Image", display_name="Krea 2 Image",
category="image/partner/Krea", category="partner/image/Krea",
description=( description=(
"Generate images via Krea 2 — pick Medium (expressive illustrations) or " "Generate images via Krea 2 — pick Medium (expressive illustrations) or "
"Large (expressive photorealism). Supports an optional moodboard and up " "Large (expressive photorealism). Supports an optional moodboard and up "
@ -229,7 +229,7 @@ class Krea2StyleReferenceNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="Krea2StyleReferenceNode", node_id="Krea2StyleReferenceNode",
display_name="Krea 2 Style Reference", display_name="Krea 2 Style Reference",
category="image/partner/Krea", category="partner/image/Krea",
description=( description=(
"Add an image style reference to a Krea 2 generation. Chain multiple Krea 2 " "Add an image style reference to a Krea 2 generation. Chain multiple Krea 2 "
"Style Reference nodes (max 10) and feed the final `style_reference` output " "Style Reference nodes (max 10) and feed the final `style_reference` output "

View File

@ -50,7 +50,7 @@ class TextToVideoNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="LtxvApiTextToVideo", node_id="LtxvApiTextToVideo",
display_name="LTXV Text To Video", display_name="LTXV Text To Video",
category="video/partner/LTXV", category="partner/video/LTXV",
description="Professional-quality videos with customizable duration and resolution.", description="Professional-quality videos with customizable duration and resolution.",
inputs=[ inputs=[
IO.Combo.Input("model", options=list(MODELS_MAP.keys())), IO.Combo.Input("model", options=list(MODELS_MAP.keys())),
@ -127,7 +127,7 @@ class ImageToVideoNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="LtxvApiImageToVideo", node_id="LtxvApiImageToVideo",
display_name="LTXV Image To Video", display_name="LTXV Image To Video",
category="video/partner/LTXV", category="partner/video/LTXV",
description="Professional-quality videos with customizable duration and resolution based on start image.", description="Professional-quality videos with customizable duration and resolution based on start image.",
inputs=[ inputs=[
IO.Image.Input("image", tooltip="First frame to be used for the video."), IO.Image.Input("image", tooltip="First frame to be used for the video."),

View File

@ -46,7 +46,7 @@ class LumaReferenceNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="LumaReferenceNode", node_id="LumaReferenceNode",
display_name="Luma Reference", display_name="Luma Reference",
category="image/partner/Luma", category="partner/image/Luma",
description="Holds an image and weight for use with Luma Generate Image node.", description="Holds an image and weight for use with Luma Generate Image node.",
inputs=[ inputs=[
IO.Image.Input( IO.Image.Input(
@ -85,7 +85,7 @@ class LumaConceptsNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="LumaConceptsNode", node_id="LumaConceptsNode",
display_name="Luma Concepts", display_name="Luma Concepts",
category="video/partner/Luma", category="partner/video/Luma",
description="Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes.", description="Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes.",
inputs=[ inputs=[
IO.Combo.Input( IO.Combo.Input(
@ -134,7 +134,7 @@ class LumaImageGenerationNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="LumaImageNode", node_id="LumaImageNode",
display_name="Luma Text to Image", display_name="Luma Text to Image",
category="image/partner/Luma", category="partner/image/Luma",
description="Generates images synchronously based on prompt and aspect ratio.", description="Generates images synchronously based on prompt and aspect ratio.",
inputs=[ inputs=[
IO.String.Input( IO.String.Input(
@ -278,7 +278,7 @@ class LumaImageModifyNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="LumaImageModifyNode", node_id="LumaImageModifyNode",
display_name="Luma Image to Image", display_name="Luma Image to Image",
category="image/partner/Luma", category="partner/image/Luma",
description="Modifies images synchronously based on prompt and aspect ratio.", description="Modifies images synchronously based on prompt and aspect ratio.",
inputs=[ inputs=[
IO.Image.Input( IO.Image.Input(
@ -371,7 +371,7 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="LumaVideoNode", node_id="LumaVideoNode",
display_name="Luma Text to Video", display_name="Luma Text to Video",
category="video/partner/Luma", category="partner/video/Luma",
description="Generates videos synchronously based on prompt and output_size.", description="Generates videos synchronously based on prompt and output_size.",
inputs=[ inputs=[
IO.String.Input( IO.String.Input(
@ -472,7 +472,7 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="LumaImageToVideoNode", node_id="LumaImageToVideoNode",
display_name="Luma Image to Video", display_name="Luma Image to Video",
category="video/partner/Luma", category="partner/video/Luma",
description="Generates videos synchronously based on prompt, input images, and output_size.", description="Generates videos synchronously based on prompt, input images, and output_size.",
inputs=[ inputs=[
IO.String.Input( IO.String.Input(
@ -724,7 +724,7 @@ class LumaImageNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="LumaImageNode2", node_id="LumaImageNode2",
display_name="Luma UNI-1 Image", display_name="Luma UNI-1 Image",
category="image/partner/Luma", category="partner/image/Luma",
description="Generate images from text using the Luma UNI-1 model.", description="Generate images from text using the Luma UNI-1 model.",
inputs=[ inputs=[
IO.String.Input( IO.String.Input(
@ -853,7 +853,7 @@ class LumaImageEditNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="LumaImageEditNode2", node_id="LumaImageEditNode2",
display_name="Luma UNI-1 Image Edit", display_name="Luma UNI-1 Image Edit",
category="image/partner/Luma", category="partner/image/Luma",
description="Edit an existing image with a text prompt using the Luma UNI-1 model.", description="Edit an existing image with a text prompt using the Luma UNI-1 model.",
inputs=[ inputs=[
IO.Image.Input( IO.Image.Input(

View File

@ -61,7 +61,7 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="MagnificImageUpscalerCreativeNode", node_id="MagnificImageUpscalerCreativeNode",
display_name="Magnific Image Upscale (Creative)", display_name="Magnific Image Upscale (Creative)",
category="image/partner/Magnific", category="partner/image/Magnific",
description="Promptguided enhancement, stylization, and 2x/4x/8x/16x upscaling. " description="Promptguided enhancement, stylization, and 2x/4x/8x/16x upscaling. "
"Maximum output: 25.3 megapixels.", "Maximum output: 25.3 megapixels.",
inputs=[ inputs=[
@ -240,7 +240,7 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="MagnificImageUpscalerPreciseV2Node", node_id="MagnificImageUpscalerPreciseV2Node",
display_name="Magnific Image Upscale (Precise V2)", display_name="Magnific Image Upscale (Precise V2)",
category="image/partner/Magnific", category="partner/image/Magnific",
description="High-fidelity upscaling with fine control over sharpness, grain, and detail. " description="High-fidelity upscaling with fine control over sharpness, grain, and detail. "
"Maximum output: 10060×10060 pixels.", "Maximum output: 10060×10060 pixels.",
inputs=[ inputs=[
@ -400,7 +400,7 @@ class MagnificImageStyleTransferNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="MagnificImageStyleTransferNode", node_id="MagnificImageStyleTransferNode",
display_name="Magnific Image Style Transfer", display_name="Magnific Image Style Transfer",
category="image/partner/Magnific", category="partner/image/Magnific",
description="Transfer the style from a reference image to your input image.", description="Transfer the style from a reference image to your input image.",
inputs=[ inputs=[
IO.Image.Input("image", tooltip="The image to apply style transfer to."), IO.Image.Input("image", tooltip="The image to apply style transfer to."),
@ -549,7 +549,7 @@ class MagnificImageRelightNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="MagnificImageRelightNode", node_id="MagnificImageRelightNode",
display_name="Magnific Image Relight", display_name="Magnific Image Relight",
category="image/partner/Magnific", category="partner/image/Magnific",
description="Relight an image with lighting adjustments and optional reference-based light transfer.", description="Relight an image with lighting adjustments and optional reference-based light transfer.",
inputs=[ inputs=[
IO.Image.Input("image", tooltip="The image to relight."), IO.Image.Input("image", tooltip="The image to relight."),
@ -789,7 +789,7 @@ class MagnificImageSkinEnhancerNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="MagnificImageSkinEnhancerNode", node_id="MagnificImageSkinEnhancerNode",
display_name="Magnific Image Skin Enhancer", display_name="Magnific Image Skin Enhancer",
category="image/partner/Magnific", category="partner/image/Magnific",
description="Skin enhancement for portraits with multiple processing modes.", description="Skin enhancement for portraits with multiple processing modes.",
inputs=[ inputs=[
IO.Image.Input("image", tooltip="The portrait image to enhance."), IO.Image.Input("image", tooltip="The portrait image to enhance."),

View File

@ -33,7 +33,7 @@ class MeshyTextToModelNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="MeshyTextToModelNode", node_id="MeshyTextToModelNode",
display_name="Meshy: Text to Model", display_name="Meshy: Text to Model",
category="3d/partner/Meshy", category="partner/3d/Meshy",
inputs=[ inputs=[
IO.Combo.Input("model", options=["latest"]), IO.Combo.Input("model", options=["latest"]),
IO.String.Input("prompt", multiline=True, default=""), IO.String.Input("prompt", multiline=True, default=""),
@ -145,7 +145,7 @@ class MeshyRefineNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="MeshyRefineNode", node_id="MeshyRefineNode",
display_name="Meshy: Refine Draft Model", display_name="Meshy: Refine Draft Model",
category="3d/partner/Meshy", category="partner/3d/Meshy",
description="Refine a previously created draft model.", description="Refine a previously created draft model.",
inputs=[ inputs=[
IO.Combo.Input("model", options=["latest"]), IO.Combo.Input("model", options=["latest"]),
@ -240,7 +240,7 @@ class MeshyImageToModelNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="MeshyImageToModelNode", node_id="MeshyImageToModelNode",
display_name="Meshy: Image to Model", display_name="Meshy: Image to Model",
category="3d/partner/Meshy", category="partner/3d/Meshy",
inputs=[ inputs=[
IO.Combo.Input("model", options=["latest"]), IO.Combo.Input("model", options=["latest"]),
IO.Image.Input("image"), IO.Image.Input("image"),
@ -405,7 +405,7 @@ class MeshyMultiImageToModelNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="MeshyMultiImageToModelNode", node_id="MeshyMultiImageToModelNode",
display_name="Meshy: Multi-Image to Model", display_name="Meshy: Multi-Image to Model",
category="3d/partner/Meshy", category="partner/3d/Meshy",
inputs=[ inputs=[
IO.Combo.Input("model", options=["latest"]), IO.Combo.Input("model", options=["latest"]),
IO.Autogrow.Input( IO.Autogrow.Input(
@ -575,7 +575,7 @@ class MeshyRigModelNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="MeshyRigModelNode", node_id="MeshyRigModelNode",
display_name="Meshy: Rig Model", display_name="Meshy: Rig Model",
category="3d/partner/Meshy", category="partner/3d/Meshy",
description="Provides a rigged character in standard formats. " description="Provides a rigged character in standard formats. "
"Auto-rigging is currently not suitable for untextured meshes, non-humanoid assets, " "Auto-rigging is currently not suitable for untextured meshes, non-humanoid assets, "
"or humanoid assets with unclear limb and body structure.", "or humanoid assets with unclear limb and body structure.",
@ -656,7 +656,7 @@ class MeshyAnimateModelNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="MeshyAnimateModelNode", node_id="MeshyAnimateModelNode",
display_name="Meshy: Animate Model", display_name="Meshy: Animate Model",
category="3d/partner/Meshy", category="partner/3d/Meshy",
description="Apply a specific animation action to a previously rigged character.", description="Apply a specific animation action to a previously rigged character.",
inputs=[ inputs=[
IO.Custom("MESHY_RIGGED_TASK_ID").Input("rig_task_id"), IO.Custom("MESHY_RIGGED_TASK_ID").Input("rig_task_id"),
@ -722,7 +722,7 @@ class MeshyTextureNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="MeshyTextureNode", node_id="MeshyTextureNode",
display_name="Meshy: Texture Model", display_name="Meshy: Texture Model",
category="3d/partner/Meshy", category="partner/3d/Meshy",
inputs=[ inputs=[
IO.Combo.Input("model", options=["latest"]), IO.Combo.Input("model", options=["latest"]),
IO.Custom("MESHY_TASK_ID").Input("meshy_task_id"), IO.Custom("MESHY_TASK_ID").Input("meshy_task_id"),

View File

@ -101,7 +101,7 @@ class MinimaxTextToVideoNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="MinimaxTextToVideoNode", node_id="MinimaxTextToVideoNode",
display_name="MiniMax Text to Video", display_name="MiniMax Text to Video",
category="video/partner/MiniMax", category="partner/video/MiniMax",
description="Generates videos synchronously based on a prompt, and optional parameters.", description="Generates videos synchronously based on a prompt, and optional parameters.",
inputs=[ inputs=[
IO.String.Input( IO.String.Input(
@ -163,7 +163,7 @@ class MinimaxImageToVideoNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="MinimaxImageToVideoNode", node_id="MinimaxImageToVideoNode",
display_name="MiniMax Image to Video", display_name="MiniMax Image to Video",
category="video/partner/MiniMax", category="partner/video/MiniMax",
description="Generates videos synchronously based on an image and prompt, and optional parameters.", description="Generates videos synchronously based on an image and prompt, and optional parameters.",
inputs=[ inputs=[
IO.Image.Input( IO.Image.Input(
@ -230,7 +230,7 @@ class MinimaxSubjectToVideoNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="MinimaxSubjectToVideoNode", node_id="MinimaxSubjectToVideoNode",
display_name="MiniMax Subject to Video", display_name="MiniMax Subject to Video",
category="video/partner/MiniMax", category="partner/video/MiniMax",
description="Generates videos synchronously based on an image and prompt, and optional parameters.", description="Generates videos synchronously based on an image and prompt, and optional parameters.",
inputs=[ inputs=[
IO.Image.Input( IO.Image.Input(
@ -294,7 +294,7 @@ class MinimaxHailuoVideoNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="MinimaxHailuoVideoNode", node_id="MinimaxHailuoVideoNode",
display_name="MiniMax Hailuo Video", display_name="MiniMax Hailuo Video",
category="video/partner/MiniMax", category="partner/video/MiniMax",
description="Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model.", description="Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model.",
inputs=[ inputs=[
IO.String.Input( IO.String.Input(

View File

@ -99,7 +99,7 @@ class OpenAIDalle2(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="OpenAIDalle2", node_id="OpenAIDalle2",
display_name="OpenAI DALL·E 2", display_name="OpenAI DALL·E 2",
category="image/partner/OpenAI", category="partner/image/OpenAI",
description="Generates images synchronously via OpenAI's DALL·E 2 endpoint.", description="Generates images synchronously via OpenAI's DALL·E 2 endpoint.",
inputs=[ inputs=[
IO.String.Input( IO.String.Input(
@ -249,7 +249,7 @@ class OpenAIDalle3(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="OpenAIDalle3", node_id="OpenAIDalle3",
display_name="OpenAI DALL·E 3", display_name="OpenAI DALL·E 3",
category="image/partner/OpenAI", category="partner/image/OpenAI",
description="Generates images synchronously via OpenAI's DALL·E 3 endpoint.", description="Generates images synchronously via OpenAI's DALL·E 3 endpoint.",
inputs=[ inputs=[
IO.String.Input( IO.String.Input(
@ -371,7 +371,7 @@ class OpenAIGPTImage1(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="OpenAIGPTImage1", node_id="OpenAIGPTImage1",
display_name="OpenAI GPT Image 2", display_name="OpenAI GPT Image 2",
category="image/partner/OpenAI", category="partner/image/OpenAI",
description="Generates images synchronously via OpenAI's GPT Image endpoint.", description="Generates images synchronously via OpenAI's GPT Image endpoint.",
is_deprecated=True, is_deprecated=True,
inputs=[ inputs=[
@ -695,7 +695,7 @@ class OpenAIGPTImageNodeV2(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="OpenAIGPTImageNodeV2", node_id="OpenAIGPTImageNodeV2",
display_name="OpenAI GPT Image 2", display_name="OpenAI GPT Image 2",
category="image/partner/OpenAI", category="partner/image/OpenAI",
description="Generates images via OpenAI's GPT Image endpoint.", description="Generates images via OpenAI's GPT Image endpoint.",
inputs=[ inputs=[
IO.String.Input( IO.String.Input(
@ -962,7 +962,7 @@ class OpenAIChatNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="OpenAIChatNode", node_id="OpenAIChatNode",
display_name="OpenAI ChatGPT", display_name="OpenAI ChatGPT",
category="text/partner/OpenAI", category="partner/text/OpenAI",
essentials_category="Text Generation", essentials_category="Text Generation",
description="Generate text responses from an OpenAI model.", description="Generate text responses from an OpenAI model.",
inputs=[ inputs=[
@ -1201,7 +1201,7 @@ class OpenAIInputFiles(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="OpenAIInputFiles", node_id="OpenAIInputFiles",
display_name="OpenAI ChatGPT Input Files", display_name="OpenAI ChatGPT Input Files",
category="text/partner/OpenAI", category="partner/text/OpenAI",
description="Loads and prepares input files (text, pdf, etc.) to include as inputs for the OpenAI Chat Node. The files will be read by the OpenAI model when generating a response. 🛈 TIP: Can be chained together with other OpenAI Input File nodes.", description="Loads and prepares input files (text, pdf, etc.) to include as inputs for the OpenAI Chat Node. The files will be read by the OpenAI model when generating a response. 🛈 TIP: Can be chained together with other OpenAI Input File nodes.",
inputs=[ inputs=[
IO.Combo.Input( IO.Combo.Input(
@ -1248,7 +1248,7 @@ class OpenAIChatConfig(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="OpenAIChatConfig", node_id="OpenAIChatConfig",
display_name="OpenAI ChatGPT Advanced Options", display_name="OpenAI ChatGPT Advanced Options",
category="text/partner/OpenAI", category="partner/text/OpenAI",
description="Allows specifying advanced configuration options for the OpenAI Chat Nodes.", description="Allows specifying advanced configuration options for the OpenAI Chat Nodes.",
inputs=[ inputs=[
IO.Combo.Input( IO.Combo.Input(

View File

@ -265,7 +265,7 @@ class OpenRouterLLMNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="OpenRouterLLMNode", node_id="OpenRouterLLMNode",
display_name="OpenRouter LLM", display_name="OpenRouter LLM",
category="text/partner/OpenRouter", category="partner/text/OpenRouter",
essentials_category="Text Generation", essentials_category="Text Generation",
description=( description=(
"Generate text responses through OpenRouter. Routes to a curated set of popular " "Generate text responses through OpenRouter. Routes to a curated set of popular "

View File

@ -53,7 +53,7 @@ class PixverseTemplateNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="PixverseTemplateNode", node_id="PixverseTemplateNode",
display_name="PixVerse Template", display_name="PixVerse Template",
category="video/partner/PixVerse", category="partner/video/PixVerse",
inputs=[ inputs=[
IO.Combo.Input("template", options=list(pixverse_templates.keys())), IO.Combo.Input("template", options=list(pixverse_templates.keys())),
], ],
@ -74,7 +74,7 @@ class PixverseTextToVideoNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="PixverseTextToVideoNode", node_id="PixverseTextToVideoNode",
display_name="PixVerse Text to Video", display_name="PixVerse Text to Video",
category="video/partner/PixVerse", category="partner/video/PixVerse",
description="Generates videos based on prompt and output_size.", description="Generates videos based on prompt and output_size.",
inputs=[ inputs=[
IO.String.Input( IO.String.Input(
@ -192,7 +192,7 @@ class PixverseImageToVideoNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="PixverseImageToVideoNode", node_id="PixverseImageToVideoNode",
display_name="PixVerse Image to Video", display_name="PixVerse Image to Video",
category="video/partner/PixVerse", category="partner/video/PixVerse",
description="Generates videos based on prompt and output_size.", description="Generates videos based on prompt and output_size.",
inputs=[ inputs=[
IO.Image.Input("image"), IO.Image.Input("image"),
@ -310,7 +310,7 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="PixverseTransitionVideoNode", node_id="PixverseTransitionVideoNode",
display_name="PixVerse Transition Video", display_name="PixVerse Transition Video",
category="video/partner/PixVerse", category="partner/video/PixVerse",
description="Generates videos based on prompt and output_size.", description="Generates videos based on prompt and output_size.",
inputs=[ inputs=[
IO.Image.Input("first_frame"), IO.Image.Input("first_frame"),

View File

@ -62,7 +62,7 @@ class QuiverTextToSVGNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="QuiverTextToSVGNode", node_id="QuiverTextToSVGNode",
display_name="Quiver Text to SVG", display_name="Quiver Text to SVG",
category="image/partner/Quiver", category="partner/image/Quiver",
description="Generate an SVG from a text prompt using Quiver AI.", description="Generate an SVG from a text prompt using Quiver AI.",
inputs=[ inputs=[
IO.String.Input( IO.String.Input(
@ -177,7 +177,7 @@ class QuiverImageToSVGNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="QuiverImageToSVGNode", node_id="QuiverImageToSVGNode",
display_name="Quiver Image to SVG", display_name="Quiver Image to SVG",
category="image/partner/Quiver", category="partner/image/Quiver",
description="Vectorize a raster image into SVG using Quiver AI.", description="Vectorize a raster image into SVG using Quiver AI.",
inputs=[ inputs=[
IO.Image.Input( IO.Image.Input(

View File

@ -178,7 +178,7 @@ class RecraftColorRGBNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="RecraftColorRGB", node_id="RecraftColorRGB",
display_name="Recraft Color RGB", display_name="Recraft Color RGB",
category="image/partner/Recraft", category="partner/image/Recraft",
description="Create Recraft Color by choosing specific RGB values.", description="Create Recraft Color by choosing specific RGB values.",
inputs=[ inputs=[
IO.Int.Input("r", default=0, min=0, max=255, tooltip="Red value of color."), IO.Int.Input("r", default=0, min=0, max=255, tooltip="Red value of color."),
@ -204,7 +204,7 @@ class RecraftControlsNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="RecraftControls", node_id="RecraftControls",
display_name="Recraft Controls", display_name="Recraft Controls",
category="image/partner/Recraft", category="partner/image/Recraft",
description="Create Recraft Controls for customizing Recraft generation.", description="Create Recraft Controls for customizing Recraft generation.",
inputs=[ inputs=[
IO.Custom(RecraftIO.COLOR).Input("colors", optional=True), IO.Custom(RecraftIO.COLOR).Input("colors", optional=True),
@ -228,7 +228,7 @@ class RecraftStyleV3RealisticImageNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="RecraftStyleV3RealisticImage", node_id="RecraftStyleV3RealisticImage",
display_name="Recraft Style - Realistic Image", display_name="Recraft Style - Realistic Image",
category="image/partner/Recraft", category="partner/image/Recraft",
description="Select realistic_image style and optional substyle.", description="Select realistic_image style and optional substyle.",
inputs=[ inputs=[
IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE)), IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE)),
@ -253,7 +253,7 @@ class RecraftStyleV3DigitalIllustrationNode(RecraftStyleV3RealisticImageNode):
return IO.Schema( return IO.Schema(
node_id="RecraftStyleV3DigitalIllustration", node_id="RecraftStyleV3DigitalIllustration",
display_name="Recraft Style - Digital Illustration", display_name="Recraft Style - Digital Illustration",
category="image/partner/Recraft", category="partner/image/Recraft",
description="Select realistic_image style and optional substyle.", description="Select realistic_image style and optional substyle.",
inputs=[ inputs=[
IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE)), IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE)),
@ -272,7 +272,7 @@ class RecraftStyleV3VectorIllustrationNode(RecraftStyleV3RealisticImageNode):
return IO.Schema( return IO.Schema(
node_id="RecraftStyleV3VectorIllustrationNode", node_id="RecraftStyleV3VectorIllustrationNode",
display_name="Recraft Style - Realistic Image", display_name="Recraft Style - Realistic Image",
category="image/partner/Recraft", category="partner/image/Recraft",
description="Select realistic_image style and optional substyle.", description="Select realistic_image style and optional substyle.",
inputs=[ inputs=[
IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE)), IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE)),
@ -291,7 +291,7 @@ class RecraftStyleV3LogoRasterNode(RecraftStyleV3RealisticImageNode):
return IO.Schema( return IO.Schema(
node_id="RecraftStyleV3LogoRaster", node_id="RecraftStyleV3LogoRaster",
display_name="Recraft Style - Logo Raster", display_name="Recraft Style - Logo Raster",
category="image/partner/Recraft", category="partner/image/Recraft",
description="Select realistic_image style and optional substyle.", description="Select realistic_image style and optional substyle.",
inputs=[ inputs=[
IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE, include_none=False)), IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE, include_none=False)),
@ -308,7 +308,7 @@ class RecraftStyleInfiniteStyleLibrary(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="RecraftStyleV3InfiniteStyleLibrary", node_id="RecraftStyleV3InfiniteStyleLibrary",
display_name="Recraft Style - Infinite Style Library", display_name="Recraft Style - Infinite Style Library",
category="image/partner/Recraft", category="partner/image/Recraft",
description="Choose style based on preexisting UUID from Recraft's Infinite Style Library.", description="Choose style based on preexisting UUID from Recraft's Infinite Style Library.",
inputs=[ inputs=[
IO.String.Input("style_id", default="", tooltip="UUID of style from Infinite Style Library."), IO.String.Input("style_id", default="", tooltip="UUID of style from Infinite Style Library."),
@ -331,7 +331,7 @@ class RecraftCreateStyleNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="RecraftCreateStyleNode", node_id="RecraftCreateStyleNode",
display_name="Recraft Create Style", display_name="Recraft Create Style",
category="image/partner/Recraft", category="partner/image/Recraft",
description="Create a custom style from reference images. " description="Create a custom style from reference images. "
"Upload 1-5 images to use as style references. " "Upload 1-5 images to use as style references. "
"Total size of all images is limited to 5 MB.", "Total size of all images is limited to 5 MB.",
@ -400,7 +400,7 @@ class RecraftTextToImageNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="RecraftTextToImageNode", node_id="RecraftTextToImageNode",
display_name="Recraft Text to Image", display_name="Recraft Text to Image",
category="image/partner/Recraft", category="partner/image/Recraft",
description="Generates images synchronously based on prompt and resolution.", description="Generates images synchronously based on prompt and resolution.",
inputs=[ inputs=[
IO.String.Input("prompt", multiline=True, default="", tooltip="Prompt for the image generation."), IO.String.Input("prompt", multiline=True, default="", tooltip="Prompt for the image generation."),
@ -512,7 +512,7 @@ class RecraftImageToImageNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="RecraftImageToImageNode", node_id="RecraftImageToImageNode",
display_name="Recraft Image to Image", display_name="Recraft Image to Image",
category="image/partner/Recraft", category="partner/image/Recraft",
description="Modify image based on prompt and strength.", description="Modify image based on prompt and strength.",
inputs=[ inputs=[
IO.Image.Input("image"), IO.Image.Input("image"),
@ -630,7 +630,7 @@ class RecraftImageInpaintingNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="RecraftImageInpaintingNode", node_id="RecraftImageInpaintingNode",
display_name="Recraft Image Inpainting", display_name="Recraft Image Inpainting",
category="image/partner/Recraft", category="partner/image/Recraft",
description="Modify image based on prompt and mask.", description="Modify image based on prompt and mask.",
inputs=[ inputs=[
IO.Image.Input("image"), IO.Image.Input("image"),
@ -732,7 +732,7 @@ class RecraftTextToVectorNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="RecraftTextToVectorNode", node_id="RecraftTextToVectorNode",
display_name="Recraft Text to Vector", display_name="Recraft Text to Vector",
category="image/partner/Recraft", category="partner/image/Recraft",
description="Generates SVG synchronously based on prompt and resolution.", description="Generates SVG synchronously based on prompt and resolution.",
inputs=[ inputs=[
IO.String.Input("prompt", default="", tooltip="Prompt for the image generation.", multiline=True), IO.String.Input("prompt", default="", tooltip="Prompt for the image generation.", multiline=True),
@ -832,7 +832,7 @@ class RecraftVectorizeImageNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="RecraftVectorizeImageNode", node_id="RecraftVectorizeImageNode",
display_name="Recraft Vectorize Image", display_name="Recraft Vectorize Image",
category="image/partner/Recraft", category="partner/image/Recraft",
essentials_category="Image Tools", essentials_category="Image Tools",
description="Generates SVG synchronously from an input image.", description="Generates SVG synchronously from an input image.",
inputs=[ inputs=[
@ -876,7 +876,7 @@ class RecraftReplaceBackgroundNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="RecraftReplaceBackgroundNode", node_id="RecraftReplaceBackgroundNode",
display_name="Recraft Replace Background", display_name="Recraft Replace Background",
category="image/partner/Recraft", category="partner/image/Recraft",
description="Replace background on image, based on provided prompt.", description="Replace background on image, based on provided prompt.",
inputs=[ inputs=[
IO.Image.Input("image"), IO.Image.Input("image"),
@ -963,7 +963,7 @@ class RecraftRemoveBackgroundNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="RecraftRemoveBackgroundNode", node_id="RecraftRemoveBackgroundNode",
display_name="Recraft Remove Background", display_name="Recraft Remove Background",
category="image/partner/Recraft", category="partner/image/Recraft",
essentials_category="Image Tools", essentials_category="Image Tools",
description="Remove background from image, and return processed image and mask.", description="Remove background from image, and return processed image and mask.",
inputs=[ inputs=[
@ -1012,7 +1012,7 @@ class RecraftCrispUpscaleNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="RecraftCrispUpscaleNode", node_id="RecraftCrispUpscaleNode",
display_name="Recraft Crisp Upscale Image", display_name="Recraft Crisp Upscale Image",
category="image/partner/Recraft", category="partner/image/Recraft",
description="Upscale image synchronously.\n" description="Upscale image synchronously.\n"
"Enhances a given raster image using crisp upscale tool, " "Enhances a given raster image using crisp upscale tool, "
"increasing image resolution, making the image sharper and cleaner.", "increasing image resolution, making the image sharper and cleaner.",
@ -1058,7 +1058,7 @@ class RecraftCreativeUpscaleNode(RecraftCrispUpscaleNode):
return IO.Schema( return IO.Schema(
node_id="RecraftCreativeUpscaleNode", node_id="RecraftCreativeUpscaleNode",
display_name="Recraft Creative Upscale Image", display_name="Recraft Creative Upscale Image",
category="image/partner/Recraft", category="partner/image/Recraft",
description="Upscale image synchronously.\n" description="Upscale image synchronously.\n"
"Enhances a given raster image using creative upscale tool, " "Enhances a given raster image using creative upscale tool, "
"boosting resolution with a focus on refining small details and faces.", "boosting resolution with a focus on refining small details and faces.",
@ -1086,7 +1086,7 @@ class RecraftV4TextToImageNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="RecraftV4TextToImageNode", node_id="RecraftV4TextToImageNode",
display_name="Recraft V4 Text to Image", display_name="Recraft V4 Text to Image",
category="image/partner/Recraft", category="partner/image/Recraft",
description="Generates images using Recraft V4 or V4 Pro models.", description="Generates images using Recraft V4 or V4 Pro models.",
inputs=[ inputs=[
IO.String.Input( IO.String.Input(
@ -1210,7 +1210,7 @@ class RecraftV4TextToVectorNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="RecraftV4TextToVectorNode", node_id="RecraftV4TextToVectorNode",
display_name="Recraft V4 Text to Vector", display_name="Recraft V4 Text to Vector",
category="image/partner/Recraft", category="partner/image/Recraft",
description="Generates SVG using Recraft V4 or V4 Pro models.", description="Generates SVG using Recraft V4 or V4 Pro models.",
inputs=[ inputs=[
IO.String.Input( IO.String.Input(

View File

@ -109,7 +109,7 @@ class ReveImageCreateNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="ReveImageCreateNode", node_id="ReveImageCreateNode",
display_name="Reve Image Create", display_name="Reve Image Create",
category="image/partner/Reve", category="partner/image/Reve",
description="Generate images from text descriptions using Reve.", description="Generate images from text descriptions using Reve.",
inputs=[ inputs=[
IO.String.Input( IO.String.Input(
@ -200,7 +200,7 @@ class ReveImageEditNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="ReveImageEditNode", node_id="ReveImageEditNode",
display_name="Reve Image Edit", display_name="Reve Image Edit",
category="image/partner/Reve", category="partner/image/Reve",
description="Edit images using natural language instructions with Reve.", description="Edit images using natural language instructions with Reve.",
inputs=[ inputs=[
IO.Image.Input("image", tooltip="The image to edit."), IO.Image.Input("image", tooltip="The image to edit."),
@ -300,7 +300,7 @@ class ReveImageRemixNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="ReveImageRemixNode", node_id="ReveImageRemixNode",
display_name="Reve Image Remix", display_name="Reve Image Remix",
category="image/partner/Reve", category="partner/image/Reve",
description="Combine reference images with text prompts to create new images using Reve.", description="Combine reference images with text prompts to create new images using Reve.",
inputs=[ inputs=[
IO.Autogrow.Input( IO.Autogrow.Input(

View File

@ -230,7 +230,7 @@ class Rodin3D_Regular(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="Rodin3D_Regular", node_id="Rodin3D_Regular",
display_name="Rodin 3D Generate - Regular Generate", display_name="Rodin 3D Generate - Regular Generate",
category="3d/partner/Rodin", category="partner/3d/Rodin",
description=cleandoc(cls.__doc__ or ""), description=cleandoc(cls.__doc__ or ""),
inputs=[ inputs=[
IO.Image.Input("Images"), IO.Image.Input("Images"),
@ -289,7 +289,7 @@ class Rodin3D_Detail(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="Rodin3D_Detail", node_id="Rodin3D_Detail",
display_name="Rodin 3D Generate - Detail Generate", display_name="Rodin 3D Generate - Detail Generate",
category="3d/partner/Rodin", category="partner/3d/Rodin",
description=cleandoc(cls.__doc__ or ""), description=cleandoc(cls.__doc__ or ""),
inputs=[ inputs=[
IO.Image.Input("Images"), IO.Image.Input("Images"),
@ -348,7 +348,7 @@ class Rodin3D_Smooth(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="Rodin3D_Smooth", node_id="Rodin3D_Smooth",
display_name="Rodin 3D Generate - Smooth Generate", display_name="Rodin 3D Generate - Smooth Generate",
category="3d/partner/Rodin", category="partner/3d/Rodin",
description=cleandoc(cls.__doc__ or ""), description=cleandoc(cls.__doc__ or ""),
inputs=[ inputs=[
IO.Image.Input("Images"), IO.Image.Input("Images"),
@ -406,7 +406,7 @@ class Rodin3D_Sketch(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="Rodin3D_Sketch", node_id="Rodin3D_Sketch",
display_name="Rodin 3D Generate - Sketch Generate", display_name="Rodin 3D Generate - Sketch Generate",
category="3d/partner/Rodin", category="partner/3d/Rodin",
description=cleandoc(cls.__doc__ or ""), description=cleandoc(cls.__doc__ or ""),
inputs=[ inputs=[
IO.Image.Input("Images"), IO.Image.Input("Images"),
@ -468,7 +468,7 @@ class Rodin3D_Gen2(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="Rodin3D_Gen2", node_id="Rodin3D_Gen2",
display_name="Rodin 3D Generate - Gen-2 Generate", display_name="Rodin 3D Generate - Gen-2 Generate",
category="3d/partner/Rodin", category="partner/3d/Rodin",
description=cleandoc(cls.__doc__ or ""), description=cleandoc(cls.__doc__ or ""),
inputs=[ inputs=[
IO.Image.Input("Images"), IO.Image.Input("Images"),
@ -941,7 +941,7 @@ class Rodin3D_Gen25_Image(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="Rodin3D_Gen25_Image", node_id="Rodin3D_Gen25_Image",
display_name="Rodin 3D Gen-2.5 - Image to 3D", display_name="Rodin 3D Gen-2.5 - Image to 3D",
category="3d/partner/Rodin", category="partner/3d/Rodin",
description=( description=(
"Generate a 3D model from 1-5 reference images via Rodin Gen-2.5. " "Generate a 3D model from 1-5 reference images via Rodin Gen-2.5. "
"Pick a mode (Fast / Regular / Extreme-High) to tune quality vs. cost." "Pick a mode (Fast / Regular / Extreme-High) to tune quality vs. cost."
@ -1035,7 +1035,7 @@ class Rodin3D_Gen25_Text(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="Rodin3D_Gen25_Text", node_id="Rodin3D_Gen25_Text",
display_name="Rodin 3D Gen-2.5 - Text to 3D", display_name="Rodin 3D Gen-2.5 - Text to 3D",
category="3d/partner/Rodin", category="partner/3d/Rodin",
description=( description=(
"Generate a 3D model from a text prompt via Rodin Gen-2.5. " "Generate a 3D model from a text prompt via Rodin Gen-2.5. "
"Pick a mode (Fast / Regular / Extreme-High) to tune quality vs. cost." "Pick a mode (Fast / Regular / Extreme-High) to tune quality vs. cost."

View File

@ -140,7 +140,7 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="RunwayImageToVideoNodeGen3a", node_id="RunwayImageToVideoNodeGen3a",
display_name="Runway Image to Video (Gen3a Turbo)", display_name="Runway Image to Video (Gen3a Turbo)",
category="video/partner/Runway", category="partner/video/Runway",
description="Generate a video from a single starting frame using Gen3a Turbo model. " description="Generate a video from a single starting frame using Gen3a Turbo model. "
"Before diving in, review these best practices to ensure that " "Before diving in, review these best practices to ensure that "
"your input selections will set your generation up for success: " "your input selections will set your generation up for success: "
@ -234,7 +234,7 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="RunwayImageToVideoNodeGen4", node_id="RunwayImageToVideoNodeGen4",
display_name="Runway Image to Video (Gen4 Turbo)", display_name="Runway Image to Video (Gen4 Turbo)",
category="video/partner/Runway", category="partner/video/Runway",
description="Generate a video from a single starting frame using Gen4 Turbo model. " description="Generate a video from a single starting frame using Gen4 Turbo model. "
"Before diving in, review these best practices to ensure that " "Before diving in, review these best practices to ensure that "
"your input selections will set your generation up for success: " "your input selections will set your generation up for success: "
@ -329,7 +329,7 @@ class RunwayFirstLastFrameNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="RunwayFirstLastFrameNode", node_id="RunwayFirstLastFrameNode",
display_name="Runway First-Last-Frame to Video", display_name="Runway First-Last-Frame to Video",
category="video/partner/Runway", category="partner/video/Runway",
description="Upload first and last keyframes, draft a prompt, and generate a video. " description="Upload first and last keyframes, draft a prompt, and generate a video. "
"More complex transitions, such as cases where the Last frame is completely different " "More complex transitions, such as cases where the Last frame is completely different "
"from the First frame, may benefit from the longer 10s duration. " "from the First frame, may benefit from the longer 10s duration. "
@ -440,7 +440,7 @@ class RunwayTextToImageNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="RunwayTextToImageNode", node_id="RunwayTextToImageNode",
display_name="Runway Text to Image", display_name="Runway Text to Image",
category="image/partner/Runway", category="partner/image/Runway",
description="Generate an image from a text prompt using Runway's Gen 4 model. " description="Generate an image from a text prompt using Runway's Gen 4 model. "
"You can also include reference image to guide the generation.", "You can also include reference image to guide the generation.",
inputs=[ inputs=[

View File

@ -34,7 +34,7 @@ class SoniloVideoToMusic(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="SoniloVideoToMusic", node_id="SoniloVideoToMusic",
display_name="Sonilo Video to Music", display_name="Sonilo Video to Music",
category="audio/partner/Sonilo", category="partner/audio/Sonilo",
description="Generate music from video content using Sonilo's AI model. " description="Generate music from video content using Sonilo's AI model. "
"Analyzes the video and creates matching music.", "Analyzes the video and creates matching music.",
inputs=[ inputs=[
@ -99,7 +99,7 @@ class SoniloTextToMusic(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="SoniloTextToMusic", node_id="SoniloTextToMusic",
display_name="Sonilo Text to Music", display_name="Sonilo Text to Music",
category="audio/partner/Sonilo", category="partner/audio/Sonilo",
description="Generate music from a text prompt using Sonilo's AI model. " description="Generate music from a text prompt using Sonilo's AI model. "
"Leave duration at 0 to let the model infer it from the prompt.", "Leave duration at 0 to let the model infer it from the prompt.",
inputs=[ inputs=[

View File

@ -34,7 +34,7 @@ class OpenAIVideoSora2(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="OpenAIVideoSora2", node_id="OpenAIVideoSora2",
display_name="OpenAI Sora - Video (DEPRECATED)", display_name="OpenAI Sora - Video (DEPRECATED)",
category="video/partner/Sora", category="partner/video/Sora",
description=( description=(
"OpenAI video and audio generation.\n\n" "OpenAI video and audio generation.\n\n"
"DEPRECATION NOTICE: OpenAI will stop serving the Sora v2 API in September 2026. " "DEPRECATION NOTICE: OpenAI will stop serving the Sora v2 API in September 2026. "

View File

@ -62,7 +62,7 @@ class StabilityStableImageUltraNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="StabilityStableImageUltraNode", node_id="StabilityStableImageUltraNode",
display_name="Stability AI Stable Image Ultra", display_name="Stability AI Stable Image Ultra",
category="image/partner/Stability AI", category="partner/image/Stability AI",
description=cleandoc(cls.__doc__ or ""), description=cleandoc(cls.__doc__ or ""),
inputs=[ inputs=[
IO.String.Input( IO.String.Input(
@ -197,7 +197,7 @@ class StabilityStableImageSD_3_5Node(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="StabilityStableImageSD_3_5Node", node_id="StabilityStableImageSD_3_5Node",
display_name="Stability AI Stable Diffusion 3.5 Image", display_name="Stability AI Stable Diffusion 3.5 Image",
category="image/partner/Stability AI", category="partner/image/Stability AI",
description=cleandoc(cls.__doc__ or ""), description=cleandoc(cls.__doc__ or ""),
inputs=[ inputs=[
IO.String.Input( IO.String.Input(
@ -354,7 +354,7 @@ class StabilityUpscaleConservativeNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="StabilityUpscaleConservativeNode", node_id="StabilityUpscaleConservativeNode",
display_name="Stability AI Upscale Conservative", display_name="Stability AI Upscale Conservative",
category="image/partner/Stability AI", category="partner/image/Stability AI",
description=cleandoc(cls.__doc__ or ""), description=cleandoc(cls.__doc__ or ""),
inputs=[ inputs=[
IO.Image.Input("image"), IO.Image.Input("image"),
@ -457,7 +457,7 @@ class StabilityUpscaleCreativeNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="StabilityUpscaleCreativeNode", node_id="StabilityUpscaleCreativeNode",
display_name="Stability AI Upscale Creative", display_name="Stability AI Upscale Creative",
category="image/partner/Stability AI", category="partner/image/Stability AI",
description=cleandoc(cls.__doc__ or ""), description=cleandoc(cls.__doc__ or ""),
inputs=[ inputs=[
IO.Image.Input("image"), IO.Image.Input("image"),
@ -578,7 +578,7 @@ class StabilityUpscaleFastNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="StabilityUpscaleFastNode", node_id="StabilityUpscaleFastNode",
display_name="Stability AI Upscale Fast", display_name="Stability AI Upscale Fast",
category="image/partner/Stability AI", category="partner/image/Stability AI",
description=cleandoc(cls.__doc__ or ""), description=cleandoc(cls.__doc__ or ""),
inputs=[ inputs=[
IO.Image.Input("image"), IO.Image.Input("image"),
@ -630,7 +630,7 @@ class StabilityTextToAudio(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="StabilityTextToAudio", node_id="StabilityTextToAudio",
display_name="Stability AI Text To Audio", display_name="Stability AI Text To Audio",
category="audio/partner/Stability AI", category="partner/audio/Stability AI",
essentials_category="Audio", essentials_category="Audio",
description=cleandoc(cls.__doc__ or ""), description=cleandoc(cls.__doc__ or ""),
inputs=[ inputs=[
@ -708,7 +708,7 @@ class StabilityAudioToAudio(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="StabilityAudioToAudio", node_id="StabilityAudioToAudio",
display_name="Stability AI Audio To Audio", display_name="Stability AI Audio To Audio",
category="audio/partner/Stability AI", category="partner/audio/Stability AI",
description=cleandoc(cls.__doc__ or ""), description=cleandoc(cls.__doc__ or ""),
inputs=[ inputs=[
IO.Combo.Input( IO.Combo.Input(
@ -802,7 +802,7 @@ class StabilityAudioInpaint(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="StabilityAudioInpaint", node_id="StabilityAudioInpaint",
display_name="Stability AI Audio Inpaint", display_name="Stability AI Audio Inpaint",
category="audio/partner/Stability AI", category="partner/audio/Stability AI",
description=cleandoc(cls.__doc__ or ""), description=cleandoc(cls.__doc__ or ""),
inputs=[ inputs=[
IO.Combo.Input( IO.Combo.Input(

View File

@ -52,7 +52,7 @@ class TopazImageEnhance(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="TopazImageEnhance", node_id="TopazImageEnhance",
display_name="Topaz Image Enhance", display_name="Topaz Image Enhance",
category="image/partner/Topaz", category="partner/image/Topaz",
description="Industry-standard upscaling and image enhancement.", description="Industry-standard upscaling and image enhancement.",
inputs=[ inputs=[
IO.Combo.Input("model", options=["Reimagine"]), IO.Combo.Input("model", options=["Reimagine"]),
@ -235,7 +235,7 @@ class TopazVideoEnhance(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="TopazVideoEnhance", node_id="TopazVideoEnhance",
display_name="Topaz Video Enhance (Legacy)", display_name="Topaz Video Enhance (Legacy)",
category="video/partner/Topaz", category="partner/video/Topaz",
description="Breathe new life into video with powerful upscaling and recovery technology.", description="Breathe new life into video with powerful upscaling and recovery technology.",
inputs=[ inputs=[
IO.Video.Input("video"), IO.Video.Input("video"),
@ -475,7 +475,7 @@ class TopazVideoEnhanceV2(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="TopazVideoEnhanceV2", node_id="TopazVideoEnhanceV2",
display_name="Topaz Video Enhance", display_name="Topaz Video Enhance",
category="video/partner/Topaz", category="partner/video/Topaz",
description="Breathe new life into video with powerful upscaling and recovery technology.", description="Breathe new life into video with powerful upscaling and recovery technology.",
inputs=[ inputs=[
IO.Video.Input("video"), IO.Video.Input("video"),

View File

@ -83,7 +83,7 @@ class TripoTextToModelNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="TripoTextToModelNode", node_id="TripoTextToModelNode",
display_name="Tripo: Text to Model", display_name="Tripo: Text to Model",
category="3d/partner/Tripo", category="partner/3d/Tripo",
inputs=[ inputs=[
IO.String.Input("prompt", multiline=True), IO.String.Input("prompt", multiline=True),
IO.String.Input("negative_prompt", multiline=True, optional=True), IO.String.Input("negative_prompt", multiline=True, optional=True),
@ -210,7 +210,7 @@ class TripoImageToModelNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="TripoImageToModelNode", node_id="TripoImageToModelNode",
display_name="Tripo: Image to Model", display_name="Tripo: Image to Model",
category="3d/partner/Tripo", category="partner/3d/Tripo",
inputs=[ inputs=[
IO.Image.Input("image"), IO.Image.Input("image"),
IO.Combo.Input( IO.Combo.Input(
@ -358,7 +358,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="TripoMultiviewToModelNode", node_id="TripoMultiviewToModelNode",
display_name="Tripo: Multiview to Model", display_name="Tripo: Multiview to Model",
category="3d/partner/Tripo", category="partner/3d/Tripo",
inputs=[ inputs=[
IO.Image.Input("image"), IO.Image.Input("image"),
IO.Image.Input("image_left", optional=True), IO.Image.Input("image_left", optional=True),
@ -518,7 +518,7 @@ class TripoTextureNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="TripoTextureNode", node_id="TripoTextureNode",
display_name="Tripo: Texture model", display_name="Tripo: Texture model",
category="3d/partner/Tripo", category="partner/3d/Tripo",
inputs=[ inputs=[
IO.Custom("MODEL_TASK_ID").Input("model_task_id"), IO.Custom("MODEL_TASK_ID").Input("model_task_id"),
IO.Boolean.Input("texture", default=True, optional=True), IO.Boolean.Input("texture", default=True, optional=True),
@ -595,7 +595,7 @@ class TripoRefineNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="TripoRefineNode", node_id="TripoRefineNode",
display_name="Tripo: Refine Draft model", display_name="Tripo: Refine Draft model",
category="3d/partner/Tripo", category="partner/3d/Tripo",
description="Refine a draft model created by v1.4 Tripo models only.", description="Refine a draft model created by v1.4 Tripo models only.",
inputs=[ inputs=[
IO.Custom("MODEL_TASK_ID").Input("model_task_id", tooltip="Must be a v1.4 Tripo model"), IO.Custom("MODEL_TASK_ID").Input("model_task_id", tooltip="Must be a v1.4 Tripo model"),
@ -635,7 +635,7 @@ class TripoRigNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="TripoRigNode", node_id="TripoRigNode",
display_name="Tripo: Rig model", display_name="Tripo: Rig model",
category="3d/partner/Tripo", category="partner/3d/Tripo",
inputs=[IO.Custom("MODEL_TASK_ID").Input("original_model_task_id")], inputs=[IO.Custom("MODEL_TASK_ID").Input("original_model_task_id")],
outputs=[ outputs=[
IO.String.Output(display_name="model_file"), # for backward compatibility only IO.String.Output(display_name="model_file"), # for backward compatibility only
@ -672,7 +672,7 @@ class TripoRetargetNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="TripoRetargetNode", node_id="TripoRetargetNode",
display_name="Tripo: Retarget rigged model", display_name="Tripo: Retarget rigged model",
category="3d/partner/Tripo", category="partner/3d/Tripo",
inputs=[ inputs=[
IO.Custom("RIG_TASK_ID").Input("original_model_task_id"), IO.Custom("RIG_TASK_ID").Input("original_model_task_id"),
IO.Combo.Input( IO.Combo.Input(
@ -737,7 +737,7 @@ class TripoConversionNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="TripoConversionNode", node_id="TripoConversionNode",
display_name="Tripo: Convert model", display_name="Tripo: Convert model",
category="3d/partner/Tripo", category="partner/3d/Tripo",
inputs=[ inputs=[
IO.Custom("MODEL_TASK_ID,RIG_TASK_ID,RETARGET_TASK_ID").Input("original_model_task_id"), IO.Custom("MODEL_TASK_ID,RIG_TASK_ID,RETARGET_TASK_ID").Input("original_model_task_id"),
IO.Combo.Input("format", options=["GLTF", "USDZ", "FBX", "OBJ", "STL", "3MF"]), IO.Combo.Input("format", options=["GLTF", "USDZ", "FBX", "OBJ", "STL", "3MF"]),
@ -1051,7 +1051,7 @@ class TripoP1TextToModelNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="TripoP1TextToModelNode", node_id="TripoP1TextToModelNode",
display_name="Tripo P1: Text to Model", display_name="Tripo P1: Text to Model",
category="3d/partner/Tripo", category="partner/3d/Tripo",
description="Tripo P1 text-to-3D. Optimized for low-poly, game-ready meshes with stable topology.", description="Tripo P1 text-to-3D. Optimized for low-poly, game-ready meshes with stable topology.",
inputs=[ inputs=[
IO.String.Input("prompt", multiline=True, tooltip="Up to 1024 characters."), IO.String.Input("prompt", multiline=True, tooltip="Up to 1024 characters."),
@ -1122,7 +1122,7 @@ class TripoP1ImageToModelNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="TripoP1ImageToModelNode", node_id="TripoP1ImageToModelNode",
display_name="Tripo P1: Image to Model", display_name="Tripo P1: Image to Model",
category="3d/partner/Tripo", category="partner/3d/Tripo",
description="Tripo P1 image-to-3D. Optimized for low-poly, game-ready meshes.", description="Tripo P1 image-to-3D. Optimized for low-poly, game-ready meshes.",
inputs=[ inputs=[
IO.Image.Input("image"), IO.Image.Input("image"),
@ -1202,7 +1202,7 @@ class TripoP1MultiviewToModelNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="TripoP1MultiviewToModelNode", node_id="TripoP1MultiviewToModelNode",
display_name="Tripo P1: Multiview to Model", display_name="Tripo P1: Multiview to Model",
category="3d/partner/Tripo", category="partner/3d/Tripo",
description="Tripo P1 multiview-to-3D from 2-4 reference images in [front, left, back, right] order. " description="Tripo P1 multiview-to-3D from 2-4 reference images in [front, left, back, right] order. "
"Front is required; any combination of the other three may be omitted.", "Front is required; any combination of the other three may be omitted.",
inputs=[ inputs=[

View File

@ -45,7 +45,7 @@ class VeoVideoGenerationNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="VeoVideoGenerationNode", node_id="VeoVideoGenerationNode",
display_name="Google Veo 2 Video Generation", display_name="Google Veo 2 Video Generation",
category="video/partner/Veo", category="partner/video/Veo",
description="Generates videos from text prompts using Google's Veo 2 API", description="Generates videos from text prompts using Google's Veo 2 API",
inputs=[ inputs=[
IO.String.Input( IO.String.Input(
@ -256,7 +256,7 @@ class Veo3VideoGenerationNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="Veo3VideoGenerationNode", node_id="Veo3VideoGenerationNode",
display_name="Google Veo 3 Video Generation", display_name="Google Veo 3 Video Generation",
category="video/partner/Veo", category="partner/video/Veo",
description="Generates videos from text prompts using Google's Veo 3 API", description="Generates videos from text prompts using Google's Veo 3 API",
inputs=[ inputs=[
IO.String.Input( IO.String.Input(
@ -468,7 +468,7 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="Veo3FirstLastFrameNode", node_id="Veo3FirstLastFrameNode",
display_name="Google Veo 3 First-Last-Frame to Video", display_name="Google Veo 3 First-Last-Frame to Video",
category="video/partner/Veo", category="partner/video/Veo",
description="Generate video using prompt and first and last frames.", description="Generate video using prompt and first and last frames.",
inputs=[ inputs=[
IO.String.Input( IO.String.Input(

View File

@ -71,7 +71,7 @@ class ViduTextToVideoNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="ViduTextToVideoNode", node_id="ViduTextToVideoNode",
display_name="Vidu Text To Video Generation", display_name="Vidu Text To Video Generation",
category="video/partner/Vidu", category="partner/video/Vidu",
description="Generate video from a text prompt", description="Generate video from a text prompt",
inputs=[ inputs=[
IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"), IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"),
@ -169,7 +169,7 @@ class ViduImageToVideoNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="ViduImageToVideoNode", node_id="ViduImageToVideoNode",
display_name="Vidu Image To Video Generation", display_name="Vidu Image To Video Generation",
category="video/partner/Vidu", category="partner/video/Vidu",
description="Generate video from image and optional prompt", description="Generate video from image and optional prompt",
inputs=[ inputs=[
IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"), IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"),
@ -273,7 +273,7 @@ class ViduReferenceVideoNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="ViduReferenceVideoNode", node_id="ViduReferenceVideoNode",
display_name="Vidu Reference To Video Generation", display_name="Vidu Reference To Video Generation",
category="video/partner/Vidu", category="partner/video/Vidu",
description="Generate video from multiple images and a prompt", description="Generate video from multiple images and a prompt",
inputs=[ inputs=[
IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"), IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"),
@ -388,7 +388,7 @@ class ViduStartEndToVideoNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="ViduStartEndToVideoNode", node_id="ViduStartEndToVideoNode",
display_name="Vidu Start End To Video Generation", display_name="Vidu Start End To Video Generation",
category="video/partner/Vidu", category="partner/video/Vidu",
description="Generate a video from start and end frames and a prompt", description="Generate a video from start and end frames and a prompt",
inputs=[ inputs=[
IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"), IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"),
@ -492,7 +492,7 @@ class Vidu2TextToVideoNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="Vidu2TextToVideoNode", node_id="Vidu2TextToVideoNode",
display_name="Vidu2 Text-to-Video Generation", display_name="Vidu2 Text-to-Video Generation",
category="video/partner/Vidu", category="partner/video/Vidu",
description="Generate video from a text prompt", description="Generate video from a text prompt",
inputs=[ inputs=[
IO.Combo.Input("model", options=["viduq2"]), IO.Combo.Input("model", options=["viduq2"]),
@ -584,7 +584,7 @@ class Vidu2ImageToVideoNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="Vidu2ImageToVideoNode", node_id="Vidu2ImageToVideoNode",
display_name="Vidu2 Image-to-Video Generation", display_name="Vidu2 Image-to-Video Generation",
category="video/partner/Vidu", category="partner/video/Vidu",
description="Generate a video from an image and an optional prompt.", description="Generate a video from an image and an optional prompt.",
inputs=[ inputs=[
IO.Combo.Input("model", options=["viduq2-pro-fast", "viduq2-pro", "viduq2-turbo"]), IO.Combo.Input("model", options=["viduq2-pro-fast", "viduq2-pro", "viduq2-turbo"]),
@ -714,7 +714,7 @@ class Vidu2ReferenceVideoNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="Vidu2ReferenceVideoNode", node_id="Vidu2ReferenceVideoNode",
display_name="Vidu2 Reference-to-Video Generation", display_name="Vidu2 Reference-to-Video Generation",
category="video/partner/Vidu", category="partner/video/Vidu",
description="Generate a video from multiple reference images and a prompt.", description="Generate a video from multiple reference images and a prompt.",
inputs=[ inputs=[
IO.Combo.Input("model", options=["viduq2"]), IO.Combo.Input("model", options=["viduq2"]),
@ -849,7 +849,7 @@ class Vidu2StartEndToVideoNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="Vidu2StartEndToVideoNode", node_id="Vidu2StartEndToVideoNode",
display_name="Vidu2 Start/End Frame-to-Video Generation", display_name="Vidu2 Start/End Frame-to-Video Generation",
category="video/partner/Vidu", category="partner/video/Vidu",
description="Generate a video from a start frame, an end frame, and a prompt.", description="Generate a video from a start frame, an end frame, and a prompt.",
inputs=[ inputs=[
IO.Combo.Input("model", options=["viduq2-pro-fast", "viduq2-pro", "viduq2-turbo"]), IO.Combo.Input("model", options=["viduq2-pro-fast", "viduq2-pro", "viduq2-turbo"]),
@ -969,7 +969,7 @@ class ViduExtendVideoNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="ViduExtendVideoNode", node_id="ViduExtendVideoNode",
display_name="Vidu Video Extension", display_name="Vidu Video Extension",
category="video/partner/Vidu", category="partner/video/Vidu",
description="Extend an existing video by generating additional frames.", description="Extend an existing video by generating additional frames.",
inputs=[ inputs=[
IO.DynamicCombo.Input( IO.DynamicCombo.Input(
@ -1138,7 +1138,7 @@ class ViduMultiFrameVideoNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="ViduMultiFrameVideoNode", node_id="ViduMultiFrameVideoNode",
display_name="Vidu Multi-Frame Video Generation", display_name="Vidu Multi-Frame Video Generation",
category="video/partner/Vidu", category="partner/video/Vidu",
description="Generate a video with multiple keyframe transitions.", description="Generate a video with multiple keyframe transitions.",
inputs=[ inputs=[
IO.Combo.Input("model", options=["viduq2-pro", "viduq2-turbo"]), IO.Combo.Input("model", options=["viduq2-pro", "viduq2-turbo"]),
@ -1284,7 +1284,7 @@ class Vidu3TextToVideoNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="Vidu3TextToVideoNode", node_id="Vidu3TextToVideoNode",
display_name="Vidu Q3 Text-to-Video Generation", display_name="Vidu Q3 Text-to-Video Generation",
category="video/partner/Vidu", category="partner/video/Vidu",
description="Generate video from a text prompt.", description="Generate video from a text prompt.",
inputs=[ inputs=[
IO.DynamicCombo.Input( IO.DynamicCombo.Input(
@ -1429,7 +1429,7 @@ class Vidu3ImageToVideoNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="Vidu3ImageToVideoNode", node_id="Vidu3ImageToVideoNode",
display_name="Vidu Q3 Image-to-Video Generation", display_name="Vidu Q3 Image-to-Video Generation",
category="video/partner/Vidu", category="partner/video/Vidu",
description="Generate a video from an image and an optional prompt.", description="Generate a video from an image and an optional prompt.",
inputs=[ inputs=[
IO.DynamicCombo.Input( IO.DynamicCombo.Input(
@ -1571,7 +1571,7 @@ class Vidu3StartEndToVideoNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="Vidu3StartEndToVideoNode", node_id="Vidu3StartEndToVideoNode",
display_name="Vidu Q3 Start/End Frame-to-Video Generation", display_name="Vidu Q3 Start/End Frame-to-Video Generation",
category="video/partner/Vidu", category="partner/video/Vidu",
description="Generate a video from a start frame, an end frame, and a prompt.", description="Generate a video from a start frame, an end frame, and a prompt.",
inputs=[ inputs=[
IO.DynamicCombo.Input( IO.DynamicCombo.Input(

View File

@ -61,7 +61,7 @@ class WanTextToImageApi(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="WanTextToImageApi", node_id="WanTextToImageApi",
display_name="Wan Text to Image", display_name="Wan Text to Image",
category="image/partner/Wan", category="partner/image/Wan",
description="Generates an image based on a text prompt.", description="Generates an image based on a text prompt.",
inputs=[ inputs=[
IO.Combo.Input( IO.Combo.Input(
@ -184,7 +184,7 @@ class WanImageToImageApi(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="WanImageToImageApi", node_id="WanImageToImageApi",
display_name="Wan Image to Image", display_name="Wan Image to Image",
category="image/partner/Wan", category="partner/image/Wan",
description="Generates an image from one or two input images and a text prompt. " description="Generates an image from one or two input images and a text prompt. "
"The output image is currently fixed at 1.6 MP, and its aspect ratio matches the input image(s).", "The output image is currently fixed at 1.6 MP, and its aspect ratio matches the input image(s).",
inputs=[ inputs=[
@ -312,7 +312,7 @@ class WanTextToVideoApi(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="WanTextToVideoApi", node_id="WanTextToVideoApi",
display_name="Wan Text to Video", display_name="Wan Text to Video",
category="video/partner/Wan", category="partner/video/Wan",
description="Generates a video based on a text prompt.", description="Generates a video based on a text prompt.",
inputs=[ inputs=[
IO.Combo.Input( IO.Combo.Input(
@ -495,7 +495,7 @@ class WanImageToVideoApi(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="WanImageToVideoApi", node_id="WanImageToVideoApi",
display_name="Wan Image to Video", display_name="Wan Image to Video",
category="video/partner/Wan", category="partner/video/Wan",
description="Generates a video from the first frame and a text prompt.", description="Generates a video from the first frame and a text prompt.",
inputs=[ inputs=[
IO.Combo.Input( IO.Combo.Input(
@ -674,7 +674,7 @@ class WanReferenceVideoApi(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="WanReferenceVideoApi", node_id="WanReferenceVideoApi",
display_name="Wan Reference to Video", display_name="Wan Reference to Video",
category="video/partner/Wan", category="partner/video/Wan",
description="Use the character and voice from input videos, combined with a prompt, " description="Use the character and voice from input videos, combined with a prompt, "
"to generate a new video that maintains character consistency.", "to generate a new video that maintains character consistency.",
inputs=[ inputs=[
@ -828,7 +828,7 @@ class Wan2TextToVideoApi(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="Wan2TextToVideoApi", node_id="Wan2TextToVideoApi",
display_name="Wan 2.7 Text to Video", display_name="Wan 2.7 Text to Video",
category="video/partner/Wan", category="partner/video/Wan",
description="Generates a video based on a text prompt using the Wan 2.7 model.", description="Generates a video based on a text prompt using the Wan 2.7 model.",
inputs=[ inputs=[
IO.DynamicCombo.Input( IO.DynamicCombo.Input(
@ -981,7 +981,7 @@ class Wan2ImageToVideoApi(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="Wan2ImageToVideoApi", node_id="Wan2ImageToVideoApi",
display_name="Wan 2.7 Image to Video", display_name="Wan 2.7 Image to Video",
category="video/partner/Wan", category="partner/video/Wan",
description="Generate a video from a first-frame image, with optional last-frame image and audio.", description="Generate a video from a first-frame image, with optional last-frame image and audio.",
inputs=[ inputs=[
IO.DynamicCombo.Input( IO.DynamicCombo.Input(
@ -1152,7 +1152,7 @@ class Wan2VideoContinuationApi(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="Wan2VideoContinuationApi", node_id="Wan2VideoContinuationApi",
display_name="Wan 2.7 Video Continuation", display_name="Wan 2.7 Video Continuation",
category="video/partner/Wan", category="partner/video/Wan",
description="Continue a video from where it left off, with optional last-frame control.", description="Continue a video from where it left off, with optional last-frame control.",
inputs=[ inputs=[
IO.DynamicCombo.Input( IO.DynamicCombo.Input(
@ -1319,7 +1319,7 @@ class Wan2VideoEditApi(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="Wan2VideoEditApi", node_id="Wan2VideoEditApi",
display_name="Wan 2.7 Video Edit", display_name="Wan 2.7 Video Edit",
category="video/partner/Wan", category="partner/video/Wan",
description="Edit a video using text instructions, reference images, or style transfer.", description="Edit a video using text instructions, reference images, or style transfer.",
inputs=[ inputs=[
IO.DynamicCombo.Input( IO.DynamicCombo.Input(
@ -1477,7 +1477,7 @@ class Wan2ReferenceVideoApi(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="Wan2ReferenceVideoApi", node_id="Wan2ReferenceVideoApi",
display_name="Wan 2.7 Reference to Video", display_name="Wan 2.7 Reference to Video",
category="video/partner/Wan", category="partner/video/Wan",
description="Generate a video featuring a person or object from reference materials. " description="Generate a video featuring a person or object from reference materials. "
"Supports single-character performances and multi-character interactions.", "Supports single-character performances and multi-character interactions.",
inputs=[ inputs=[
@ -1651,7 +1651,7 @@ class HappyHorseTextToVideoApi(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="HappyHorseTextToVideoApi", node_id="HappyHorseTextToVideoApi",
display_name="HappyHorse Text to Video", display_name="HappyHorse Text to Video",
category="video/partner/Wan", category="partner/video/Wan",
description="Generates a video based on a text prompt using the HappyHorse model.", description="Generates a video based on a text prompt using the HappyHorse model.",
inputs=[ inputs=[
IO.DynamicCombo.Input( IO.DynamicCombo.Input(
@ -1775,7 +1775,7 @@ class HappyHorseImageToVideoApi(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="HappyHorseImageToVideoApi", node_id="HappyHorseImageToVideoApi",
display_name="HappyHorse Image to Video", display_name="HappyHorse Image to Video",
category="video/partner/Wan", category="partner/video/Wan",
description="Generate a video from a first-frame image using the HappyHorse model.", description="Generate a video from a first-frame image using the HappyHorse model.",
inputs=[ inputs=[
IO.DynamicCombo.Input( IO.DynamicCombo.Input(
@ -1905,7 +1905,7 @@ class HappyHorseVideoEditApi(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="HappyHorseVideoEditApi", node_id="HappyHorseVideoEditApi",
display_name="HappyHorse Video Edit", display_name="HappyHorse Video Edit",
category="video/partner/Wan", category="partner/video/Wan",
description="Edit a video using text instructions or reference images with the HappyHorse model. " description="Edit a video using text instructions or reference images with the HappyHorse model. "
"Output duration is 3-15s and matches the input video; inputs longer than 15s are truncated.", "Output duration is 3-15s and matches the input video; inputs longer than 15s are truncated.",
inputs=[ inputs=[
@ -2046,7 +2046,7 @@ class HappyHorseReferenceVideoApi(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="HappyHorseReferenceVideoApi", node_id="HappyHorseReferenceVideoApi",
display_name="HappyHorse Reference to Video", display_name="HappyHorse Reference to Video",
category="video/partner/Wan", category="partner/video/Wan",
description="Generate a video featuring a person or object from reference materials with the HappyHorse " description="Generate a video featuring a person or object from reference materials with the HappyHorse "
"model. Supports single-character performances and multi-character interactions.", "model. Supports single-character performances and multi-character interactions.",
inputs=[ inputs=[

View File

@ -27,7 +27,7 @@ class WavespeedFlashVSRNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="WavespeedFlashVSRNode", node_id="WavespeedFlashVSRNode",
display_name="FlashVSR Video Upscale", display_name="FlashVSR Video Upscale",
category="video/partner/WaveSpeed", category="partner/video/WaveSpeed",
description="Fast, high-quality video upscaler that " description="Fast, high-quality video upscaler that "
"boosts resolution and restores clarity for low-resolution or blurry footage.", "boosts resolution and restores clarity for low-resolution or blurry footage.",
inputs=[ inputs=[
@ -98,7 +98,7 @@ class WavespeedImageUpscaleNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="WavespeedImageUpscaleNode", node_id="WavespeedImageUpscaleNode",
display_name="WaveSpeed Image Upscale", display_name="WaveSpeed Image Upscale",
category="image/partner/WaveSpeed", category="partner/image/WaveSpeed",
description="Boost image resolution and quality, upscaling photos to 4K or 8K for sharp, detailed results.", description="Boost image resolution and quality, upscaling photos to 4K or 8K for sharp, detailed results.",
inputs=[ inputs=[
IO.Combo.Input("model", options=["SeedVR2", "Ultimate"]), IO.Combo.Input("model", options=["SeedVR2", "Ultimate"]),

View File

@ -469,6 +469,11 @@ def _apply_video_scale(video: Input.Video, scale_dims: tuple[int, int]) -> Input
input_container = None input_container = None
output_container = None output_container = None
# get_stream_source() is untrimmed, so apply the trim window in this same pass.
# start_time is normalized (>= 0); duration == 0 means "until the end".
start_time, duration = video.get_active_trim_window()
trimming = bool(start_time or duration)
try: try:
input_source = video.get_stream_source() input_source = video.get_stream_source()
input_container = av.open(input_source, mode="r") input_container = av.open(input_source, mode="r")
@ -487,16 +492,45 @@ def _apply_video_scale(video: Input.Video, scale_dims: tuple[int, int]) -> Input
audio_stream.layout = stream.layout audio_stream.layout = stream.layout
break break
in_video = input_container.streams.video[0]
start_pts = int(start_time / in_video.time_base) if trimming else 0
end_pts = int((start_time + duration) / in_video.time_base) if duration else None
if start_pts:
input_container.seek(start_pts, stream=in_video)
encoded = 0
for frame in input_container.decode(video=0): for frame in input_container.decode(video=0):
if trimming:
if frame.pts is None or frame.pts < start_pts:
continue
if end_pts is not None and frame.pts >= end_pts:
break
frame = frame.reformat(width=out_w, height=out_h, format="yuv420p") frame = frame.reformat(width=out_w, height=out_h, format="yuv420p")
# Re-wrap as a fresh frame: dropping irregular source timestamps (VFR/AVI/GIF/...)
# lets the encoder assign clean ones and avoids mp4 muxer errors.
frame = av.VideoFrame.from_ndarray(frame.to_ndarray(format="yuv420p"), format="yuv420p")
for packet in video_stream.encode(frame): for packet in video_stream.encode(frame):
output_container.mux(packet) output_container.mux(packet)
encoded += 1
for packet in video_stream.encode(): for packet in video_stream.encode():
output_container.mux(packet) output_container.mux(packet)
if encoded == 0:
raise ValueError(
f"resize produced no frames (start_time={start_time}, duration={duration} "
"selected nothing from the source)"
)
if audio_stream is not None: if audio_stream is not None:
input_container.seek(0) input_container.seek(0)
for audio_frame in input_container.decode(audio=0): for audio_frame in input_container.decode(audio=0):
if trimming:
if audio_frame.time is None or audio_frame.time < start_time:
continue
if duration and audio_frame.time > start_time + duration:
break
# Carry odd audio time bases the mp4 muxer rejects; reset pts, encoder assigns clean ones (MP3-in-AVI)
audio_frame.pts = None
for packet in audio_stream.encode(audio_frame): for packet in audio_stream.encode(audio_frame):
output_container.mux(packet) output_container.mux(packet)
for packet in audio_stream.encode(): for packet in audio_stream.encode():

View File

@ -65,6 +65,12 @@ class ChromaRadianceOptions(io.ComfyNode):
tooltip="Allows overriding the default NeRF tile size. -1 means use the default (32). 0 means use non-tiling mode (may require a lot of VRAM).", tooltip="Allows overriding the default NeRF tile size. -1 means use the default (32). 0 means use non-tiling mode (may require a lot of VRAM).",
advanced=True, advanced=True,
), ),
io.Boolean.Input(
id="force_sequential_txt_ids",
default=False,
tooltip="Force usage of sequential text token IDs instead of zeroes. Should be used for checkpoints from 2026-05-22 to 2026-06-01 that are trained in this way but do not contain the __sequential__ key in the state dict.",
advanced=True,
),
], ],
outputs=[io.Model.Output()], outputs=[io.Model.Output()],
) )
@ -78,11 +84,15 @@ class ChromaRadianceOptions(io.ComfyNode):
start_sigma: float, start_sigma: float,
end_sigma: float, end_sigma: float,
nerf_tile_size: int, nerf_tile_size: int,
force_sequential_txt_ids: bool,
) -> io.NodeOutput: ) -> io.NodeOutput:
radiance_options = {} radiance_options = {}
if nerf_tile_size >= 0: if nerf_tile_size >= 0:
radiance_options["nerf_tile_size"] = nerf_tile_size radiance_options["nerf_tile_size"] = nerf_tile_size
if force_sequential_txt_ids:
radiance_options["use_sequential_txt_ids"] = True
if not radiance_options: if not radiance_options:
return io.NodeOutput(model) return io.NodeOutput(model)

File diff suppressed because it is too large Load Diff

View File

@ -102,11 +102,18 @@ class MathExpressionNode(io.ComfyNode):
f"Math Expression '{expression}' must evaluate to a numeric result, " f"Math Expression '{expression}' must evaluate to a numeric result, "
f"got {type(result).__name__}: {result!r}" f"got {type(result).__name__}: {result!r}"
) )
if not math.isfinite(result): try:
float_result = float(result)
except OverflowError:
raise ValueError(
f"Math Expression '{expression}' produced a result too large to "
f"represent as a float: {result}"
) from None
if not math.isfinite(float_result):
raise ValueError( raise ValueError(
f"Math Expression '{expression}' produced a non-finite result: {result}" f"Math Expression '{expression}' produced a non-finite result: {result}"
) )
return io.NodeOutput(float(result), int(result), bool(result)) return io.NodeOutput(float_result, int(result), bool(result))
class MathExtension(ComfyExtension): class MathExtension(ComfyExtension):

View File

@ -16,7 +16,7 @@ from comfy.cli_args import args
from comfy_api.latest import ComfyExtension, IO, Types from comfy_api.latest import ComfyExtension, IO, Types
def pack_variable_mesh_batch(vertices, faces, colors=None, uvs=None, texture=None): def pack_variable_mesh_batch(vertices, faces, colors=None, uvs=None, texture=None, unlit=False):
# Pack lists of (Nᵢ, *) vertex/face/color/uv tensors into padded batched tensors, # Pack lists of (Nᵢ, *) vertex/face/color/uv tensors into padded batched tensors,
# stashing per-item lengths as runtime attrs so consumers can recover the real slice. # stashing per-item lengths as runtime attrs so consumers can recover the real slice.
# colors and uvs are 1:1 with vertices, so they're padded to max_vertices and read with vertex_counts. # colors and uvs are 1:1 with vertices, so they're padded to max_vertices and read with vertex_counts.
@ -54,7 +54,7 @@ def pack_variable_mesh_batch(vertices, faces, colors=None, uvs=None, texture=Non
return Types.MESH(packed_vertices, packed_faces, return Types.MESH(packed_vertices, packed_faces,
uvs=packed_uvs, vertex_colors=packed_colors, texture=texture, uvs=packed_uvs, vertex_colors=packed_colors, texture=texture,
vertex_counts=vertex_counts, face_counts=face_counts) vertex_counts=vertex_counts, face_counts=face_counts, unlit=unlit)
def get_mesh_batch_item(mesh, index): def get_mesh_batch_item(mesh, index):
@ -77,7 +77,7 @@ def get_mesh_batch_item(mesh, index):
def save_glb(vertices, faces, filepath, metadata=None, def save_glb(vertices, faces, filepath, metadata=None,
uvs=None, vertex_colors=None, texture_image=None): uvs=None, vertex_colors=None, texture_image=None, unlit=False):
""" """
Save PyTorch tensor vertices and faces as a GLB file without external dependencies. Save PyTorch tensor vertices and faces as a GLB file without external dependencies.
@ -234,6 +234,17 @@ def save_glb(vertices, faces, filepath, metadata=None,
textures = [] textures = []
samplers = [] samplers = []
materials = [] materials = []
extensions_used = []
if unlit and texture_png_bytes is None:
# Flat, light-independent shading (KHR_materials_unlit): COLOR_0 is shown as-is, matching how a
# gaussian splat renders (emissive). Without this the viewer lights the mesh and washes the colours.
materials.append({
"pbrMetallicRoughness": {"baseColorFactor": [1.0, 1.0, 1.0, 1.0], "metallicFactor": 0.0, "roughnessFactor": 1.0},
"extensions": {"KHR_materials_unlit": {}},
"doubleSided": True,
})
extensions_used.append("KHR_materials_unlit")
primitive["material"] = 0
if texture_png_bytes is not None and "TEXCOORD_0" in primitive_attributes: if texture_png_bytes is not None and "TEXCOORD_0" in primitive_attributes:
buffer_views.append({ buffer_views.append({
"buffer": 0, "buffer": 0,
@ -271,6 +282,8 @@ def save_glb(vertices, faces, filepath, metadata=None,
gltf["textures"] = textures gltf["textures"] = textures
if materials: if materials:
gltf["materials"] = materials gltf["materials"] = materials
if extensions_used:
gltf["extensionsUsed"] = extensions_used
if metadata: if metadata:
gltf["asset"]["extras"] = metadata gltf["asset"]["extras"] = metadata
@ -376,7 +389,8 @@ class SaveGLB(IO.ComfyNode):
save_glb(vertices_i, faces_i, os.path.join(full_output_folder, f), metadata, save_glb(vertices_i, faces_i, os.path.join(full_output_folder, f), metadata,
uvs=uvs_i, uvs=uvs_i,
vertex_colors=v_colors, vertex_colors=v_colors,
texture_image=tex_img) texture_image=tex_img,
unlit=getattr(mesh, "unlit", False))
results.append({ results.append({
"filename": f, "filename": f,
"subfolder": subfolder, "subfolder": subfolder,

View File

@ -0,0 +1,270 @@
# TripoSplat nodes: image -> 3D gaussian splat
import logging
import torch
import torch.nn.functional as F
from typing_extensions import override
import comfy.model_management
import comfy.nested_tensor
import comfy.patcher_extension
import comfy.utils
from comfy_api.latest import ComfyExtension, IO, Types
_Q_TOKEN_LENGTH = 8192
_LATENT_CHANNELS = 16
_CAM_CHANNELS = 5
_DINOV3_MEAN = [0.485, 0.456, 0.406]
_DINOV3_STD = [0.229, 0.224, 0.225]
_NUM_GAUSSIANS_MIN = 32768
_NUM_GAUSSIANS_MAX = 1048576
def _preprocess(image: torch.Tensor, mask: torch.Tensor, erode_radius: int, size: int) -> torch.Tensor:
# Match original preprocessing:
# resize min side to `size` -> erode alpha -> alpha bbox -> 1.2x square crop -> resize -> composite on black.
rgb = image[..., :3].clamp(0, 1).movedim(-1, 0) # (3, H, W)
alpha = mask.clamp(0, 1)[None] # (1, H, W)
rgba = torch.cat([rgb, alpha], 0)[None] # (1, 4, H, W)
h, w = rgba.shape[-2:]
s = size / min(w, h)
rgba = comfy.utils.common_upscale(rgba, max(1, round(w * s)), max(1, round(h * s)), "lanczos", "disabled").clamp(0, 1)
a = rgba[:, 3:4]
if erode_radius > 0:
# min filter over a (2r+1) window == morphological erosion of the alpha matte.
a = -F.max_pool2d(-a, 2 * erode_radius + 1, stride=1, padding=erode_radius)
rgba = torch.cat([rgba[:, :3], a], 1)
ys, xs = torch.nonzero(a[0, 0] > 0, as_tuple=True)
if xs.numel() == 0:
raise ValueError("TripoSplatPreprocessImage: mask is empty (no foreground pixels).")
x0, x1 = int(xs.min()), int(xs.max())
y0, y1 = int(ys.min()), int(ys.max())
cx, cy = (x0 + x1) / 2, (y0 + y1) / 2
half = max(x1 - x0, y1 - y0) / 2 * 1.2
left, upper, right, lower = int(cx - half), int(cy - half), int(cx + half), int(cy + half)
H, W = rgba.shape[-2:]
crop = rgba.new_zeros((1, 4, lower - upper, right - left)) # out-of-bounds stays 0, matching PIL.crop
sx0, sy0, sx1, sy1 = max(left, 0), max(upper, 0), min(right, W), min(lower, H)
if sx1 > sx0 and sy1 > sy0:
crop[:, :, sy0 - upper:sy1 - upper, sx0 - left:sx1 - left] = rgba[:, :, sy0:sy1, sx0:sx1]
crop = comfy.utils.common_upscale(crop, size, size, "lanczos", "disabled").clamp(0, 1)
out = (crop[:, :3] * crop[:, 3:4])[0].movedim(0, -1) # composite over black == rgb * alpha
return out.unsqueeze(0) # (1, 1024, 1024, 3)
class TripoSplatPreprocessImage(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="TripoSplatPreprocessImage",
display_name="TripoSplat Preprocess Image",
category="3d/conditioning",
description="Crop center each image to a square canvas on a black background and add padding.",
inputs=[
IO.Image.Input("image"),
IO.Mask.Input("mask"),
IO.Int.Input("erode_radius", default=1, min=0, max=16,
tooltip="Erode the alpha matte by this pixel radius before cropping (avoids border bleed)."),
IO.Int.Input("size", default=1024, min=256, max=4096, step=16,
tooltip="Square image size. The model is trained at 1024; other sizes run but are off-distribution."),
],
outputs=[IO.Image.Output(display_name="image")],
)
@classmethod
def execute(cls, image, mask, erode_radius, size) -> IO.NodeOutput:
size = max(16, (int(size) // 16) * 16) # DINOv3 patch / Flux2 VAE stride is 16
if mask.shape[0] != image.shape[0]:
mask = comfy.utils.repeat_to_batch_size(mask, image.shape[0])
if tuple(mask.shape[1:]) != tuple(image.shape[1:3]):
mask = F.interpolate(mask[:, None].float(), size=tuple(image.shape[1:3]), mode="bilinear", align_corners=False)[:, 0]
prepared = torch.cat([_preprocess(image[i], mask[i], erode_radius, size) for i in range(image.shape[0])], dim=0)
return IO.NodeOutput(prepared)
class TripoSplatConditioning(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="TripoSplatConditioning",
display_name="TripoSplat Conditioning",
category="3d/conditioning",
description="Encode the image with DINOv3 and the Flux2 VAE into TripoSplat positive/negative "
"conditioning, and create the fixed size noise target (latent + camera) for the KSampler",
inputs=[
IO.ClipVision.Input("clip_vision", tooltip="DINOv3 ViT-H/16+ image encoder"),
IO.Vae.Input("vae", tooltip="Flux2 VAE"),
IO.Image.Input("image"),
],
outputs=[
IO.Conditioning.Output(display_name="positive"),
IO.Conditioning.Output(display_name="negative"),
IO.Latent.Output(display_name="latent", tooltip="The fixed size noise target (latent +camera)."),
],
)
@classmethod
def execute(cls, clip_vision, vae, image) -> IO.NodeOutput:
# feature1: DINOv3 token sequence (cls + registers + patches), ImageNet-normalized, with a final non-affine layer norm on top
comfy.model_management.load_model_gpu(clip_vision.patcher)
device = clip_vision.load_device
img = image.movedim(-1, 1).to(device) # (B,3,H,W) in [0,1]
mean = torch.tensor(_DINOV3_MEAN, device=device).view(1, 3, 1, 1)
std = torch.tensor(_DINOV3_STD, device=device).view(1, 3, 1, 1)
img = (img - mean) / std
seq = clip_vision.model(pixel_values=img.float())[0]
feature1 = F.layer_norm(seq.float(), seq.shape[-1:]).to(comfy.model_management.intermediate_device())
# Second conditioning: the Flux2 VAE latent of the image, carried as a standard reference_latents entry
ref = vae.encode(image).to(comfy.model_management.intermediate_device()) # (B, 128, H, W)
b = ref.shape[0]
positive = [[feature1, {"reference_latents": [ref]}]]
negative = [[torch.zeros_like(feature1), {"reference_latents": [torch.zeros_like(ref)]}]]
# Fixed noise target: the latent is a constant-shape (8192, 16) shape-code + a (1, 5) camera token
dev = comfy.model_management.intermediate_device()
latent_seq = torch.zeros([b, _Q_TOKEN_LENGTH, _LATENT_CHANNELS], device=dev)
camera = torch.zeros([b, 1, _CAM_CHANNELS], device=dev)
samples = comfy.nested_tensor.NestedTensor((latent_seq, camera))
return IO.NodeOutput(positive, negative, {"samples": samples})
class VAEDecodeTripoSplat(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="VAEDecodeTripoSplat",
display_name="TripoSplat Decode",
category="3d/latent",
description="Decode the sampled TripoSplat latent into a 3D gaussian splat. "
"Modify the number of gaussians to vary the density.",
inputs=[
IO.Latent.Input("samples"),
IO.Vae.Input("vae", tooltip="TripoSplat VAE decoder"),
IO.Int.Input("num_gaussians", default=262144, min=_NUM_GAUSSIANS_MIN, max=_NUM_GAUSSIANS_MAX, step=32,
tooltip="Number of gaussians to produce (rounded to a multiple of 32). "
"262144 matches the octree's point density; higher oversamples the same points "
"(denser, but no new detail) and costs proportionally more VRAM/time."),
IO.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff,
tooltip="Seeds the octree point sampler (global RNG) for deterministic decodes."),
],
outputs=[IO.Splat.Output(display_name="splat")],
)
@classmethod
def execute(cls, samples, vae, num_gaussians, seed) -> IO.NodeOutput:
s = samples["samples"]
latent = s.unbind()[0] if getattr(s, "is_nested", False) else s # take the latent stream, drop camera
decoder = vae.first_stage_model
gpp = decoder.gaussians_per_point
n = max(_NUM_GAUSSIANS_MIN, min(_NUM_GAUSSIANS_MAX, int(num_gaussians)))
if n % gpp != 0:
n = round(n / gpp) * gpp
dtype_size = comfy.model_management.dtype_size(vae.vae_dtype)
hidden = decoder.gs.model_channels
cond_tokens = latent.shape[1]
memory_required = (cond_tokens * 4 + (n // gpp) * 10) * hidden * dtype_size
comfy.model_management.load_models_gpu([vae.patcher], memory_required=memory_required)
latent = latent.to(device=vae.device, dtype=vae.vae_dtype)
generator = torch.Generator(device="cpu").manual_seed(seed)
parts = [g.render_tensors() for g in decoder.decode(latent, num_gaussians=n, generator=generator)]
positions, scales, rotations, opacities, sh = (torch.stack(t) for t in zip(*parts))
return IO.NodeOutput(Types.SPLAT(positions, scales, rotations, opacities, sh))
class TripoSplatSamplingPreview(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="TripoSplatSamplingPreview",
display_name="TripoSplat Sampling Preview",
category="3d/latent",
description="Patch the TripoSplat model for the standard Ksampler node to show a live decoded "
"gaussian splat preview at each step.",
inputs=[
IO.Model.Input("model"),
IO.Vae.Input("vae", tooltip="TripoSplat VAE decoder"),
IO.Int.Input("octree_level", default=5, min=2, max=8, advanced=True,
tooltip="Octree depth for the preview decode (lower = cheaper/coarser)."),
IO.Int.Input("num_gaussians", default=16384, min=1024, max=262144, step=32,
tooltip="Number of gaussians to produce for the preview (rounded to a multiple of 32)."),
IO.Float.Input("yaw", default=90.0, min=-360.0, max=360.0, step=1.0, tooltip="Preview camera yaw in degrees.", advanced=True,),
IO.Float.Input("pitch", default=15.0, min=-89.0, max=89.0, step=1.0, tooltip="Preview camera pitch in degrees.", advanced=True,),
IO.Int.Input("point_size", default=3, min=1, max=16,
tooltip="Maximum splat radius in pixels. Each gaussian is sized from its scale and capped here; "
"lower = finer/pointier, higher = chunkier."),
],
outputs=[IO.Model.Output()],
)
@classmethod
def execute(cls, model, vae, octree_level, num_gaussians, yaw, pitch, point_size) -> IO.NodeOutput:
from comfy.ldm.triposplat.preview import decode_x0_to_image
cfg = {"gaussians": num_gaussians, "level": octree_level, "yaw": yaw, "pitch": pitch,
"point_size": point_size}
fsm = vae.first_stage_model
cond_tokens = model.model.diffusion_model.q_token_length
memory_required = (cond_tokens * 4 + (num_gaussians // fsm.gaussians_per_point) * 10) * fsm.gs.model_channels * comfy.model_management.dtype_size(vae.vae_dtype)
# Live preview via WrappersMP.OUTER_SAMPLE + ProgressBar
# The wrapper augments the sampler's own callback to decode x0 -> gaussian splat -> preview image each step
def outer_sample_wrapper(executor, *args, **kwargs):
args = list(args)
cb_idx = 5 # outer_sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
orig_cb = args[cb_idx] if len(args) > cb_idx else kwargs.get("callback")
state = {"ok": True, "pbar": None, "loaded": False}
def callback(step, x0, x, total_steps):
if orig_cb is not None:
orig_cb(step, x0, x, total_steps)
if not state["ok"]:
return
try:
if not state["loaded"]:
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
loaded_models.append(vae.patcher)
comfy.model_management.load_models_gpu(loaded_models, memory_required=memory_required)
state["loaded"] = True
img = decode_x0_to_image(vae, x0, cfg)
if state["pbar"] is None:
state["pbar"] = comfy.utils.ProgressBar(total_steps)
state["pbar"].update_absolute(step + 1, total_steps, ("JPEG", img, 512))
except Exception as e:
logging.warning("TripoSplatSamplingPreview: preview failed, disabling ({})".format(e))
state["ok"] = False
if len(args) > cb_idx:
args[cb_idx] = callback
else:
kwargs["callback"] = callback
return executor(*args, **kwargs)
m = model.clone()
m.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "triposplat_sampling_preview", outer_sample_wrapper)
return IO.NodeOutput(m)
class TripoSplatExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
TripoSplatPreprocessImage,
TripoSplatConditioning,
VAEDecodeTripoSplat,
TripoSplatSamplingPreview,
]
async def comfy_entrypoint() -> TripoSplatExtension:
return TripoSplatExtension()

View File

@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is # This file is automatically generated by the build process when version is
# updated in pyproject.toml. # updated in pyproject.toml.
__version__ = "0.22.0" __version__ = "0.23.0"

View File

@ -464,13 +464,6 @@ def start_comfyui(asyncio_loop=None):
folder_paths.set_temp_directory(temp_dir) folder_paths.set_temp_directory(temp_dir)
cleanup_temp() cleanup_temp()
if args.windows_standalone_build:
try:
import new_updater
new_updater.update_windows_updater()
except:
pass
if not asyncio_loop: if not asyncio_loop:
asyncio_loop = asyncio.new_event_loop() asyncio_loop = asyncio.new_event_loop()
asyncio.set_event_loop(asyncio_loop) asyncio.set_event_loop(asyncio_loop)

View File

@ -1,35 +0,0 @@
import os
import shutil
base_path = os.path.dirname(os.path.realpath(__file__))
def update_windows_updater():
top_path = os.path.dirname(base_path)
updater_path = os.path.join(base_path, ".ci/update_windows/update.py")
bat_path = os.path.join(base_path, ".ci/update_windows/update_comfyui.bat")
dest_updater_path = os.path.join(top_path, "update/update.py")
dest_bat_path = os.path.join(top_path, "update/update_comfyui.bat")
dest_bat_deps_path = os.path.join(top_path, "update/update_comfyui_and_python_dependencies.bat")
try:
with open(dest_bat_path, 'rb') as f:
contents = f.read()
except:
return
if not contents.startswith(b"..\\python_embeded\\python.exe .\\update.py"):
return
shutil.copy(updater_path, dest_updater_path)
try:
with open(dest_bat_deps_path, 'rb') as f:
contents = f.read()
contents = contents.replace(b'..\\python_embeded\\python.exe .\\update.py ..\\ComfyUI\\', b'call update_comfyui.bat nopause')
with open(dest_bat_deps_path, 'wb') as f:
f.write(contents)
except:
pass
shutil.copy(bat_path, dest_bat_path)
print("Updated the windows standalone package updater.") # noqa: T201

View File

@ -2455,6 +2455,8 @@ async def init_builtin_extra_nodes():
"nodes_save_3d.py", "nodes_save_3d.py",
"nodes_moge.py", "nodes_moge.py",
"nodes_mediapipe.py", "nodes_mediapipe.py",
"nodes_gaussian_splat.py",
"nodes_triposplat.py"
] ]
import_failed = [] import_failed = []

View File

@ -1,6 +1,6 @@
[project] [project]
name = "ComfyUI" name = "ComfyUI"
version = "0.22.0" version = "0.23.0"
readme = "README.md" readme = "README.md"
license = { file = "LICENSE" } license = { file = "LICENSE" }
requires-python = ">=3.10" requires-python = ">=3.10"

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.44.19 comfyui-frontend-package==1.44.19
comfyui-workflow-templates==0.9.91 comfyui-workflow-templates==0.9.92
comfyui-embedded-docs==0.5.2 comfyui-embedded-docs==0.5.2
torch torch
torchsde torchsde

View File

@ -197,3 +197,10 @@ class TestMathExpressionExecute:
def test_pow_huge_exponent_raises(self): def test_pow_huge_exponent_raises(self):
with pytest.raises(ValueError, match="Exponent .* exceeds maximum"): with pytest.raises(ValueError, match="Exponent .* exceeds maximum"):
self._exec("pow(a, b)", a=10, b=10000000) self._exec("pow(a, b)", a=10, b=10000000)
def test_huge_int_result_raises_value_error(self):
# Exponent is within the allowed MAX_EXPONENT range, so the result is a
# finite Python int that is nonetheless too large to convert to float.
# This must raise a clean ValueError, not an uncaught OverflowError.
with pytest.raises(ValueError, match="too large to represent as a float"):
self._exec("2 ** 3999")