Merge branch 'master' into seedvr2-native-support

This commit is contained in:
John Pollock 2026-06-02 21:58:55 -05:00
commit 6c233b1966
88 changed files with 4850 additions and 751 deletions

View File

@ -105,7 +105,7 @@ class WindowAttention(nn.Module):
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.long().view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
relative_position_bias = comfy.ops.cast_to_input(relative_position_bias.permute(2, 0, 1).contiguous(), attn) # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:

View File

@ -55,12 +55,7 @@ class BackgroundRemovalModel():
out = torch.nn.functional.interpolate(out, size=(H, W), mode="bicubic", antialias=False)
mask = out.sigmoid().to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
if mask.ndim == 3:
mask = mask.unsqueeze(0)
if mask.shape[1] != 1:
mask = mask.movedim(-1, 1)
return mask
return mask.squeeze(1) # (B, 1, H, W) -> (B, H, W)
def load_background_removal_model(sd):

View File

@ -149,6 +149,7 @@ parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=Non
parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
parser.add_argument("--disable-dynamic-vram", action="store_true", help="Disable dynamic VRAM and use estimate based model loading.")
parser.add_argument("--enable-dynamic-vram", action="store_true", help="Enable dynamic VRAM on systems where it's not enabled by default.")
parser.add_argument("--fast-disk", action="store_true", help="Prefer disk-backed dynamic loading and offload over unpinned RAM. Can be faster for users with fast NVME disks.")
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")

View File

@ -9,6 +9,7 @@ import comfy.model_management
import comfy.utils
import comfy.clip_model
import comfy.image_encoders.dino2
import comfy.image_encoders.dino3
class Output:
def __getitem__(self, key):
@ -23,12 +24,16 @@ IMAGE_ENCODERS = {
"siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
"siglip2_vision_model": comfy.clip_model.CLIPVisionModelProjection,
"dinov2": comfy.image_encoders.dino2.Dinov2Model,
"dinov3": comfy.image_encoders.dino3.DINOv3ViTModel,
}
class ClipVisionModel():
def __init__(self, json_config):
with open(json_config) as f:
config = json.load(f)
if isinstance(json_config, dict):
config = json_config
else:
with open(json_config) as f:
config = json.load(f)
self.image_size = config.get("image_size", 224)
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
@ -134,6 +139,8 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
elif 'encoder.layer.23.layer_scale2.lambda1' in sd:
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_large.json")
elif 'layer.0.mlp.gate_proj.weight' in sd and 'layer.31.norm1.weight' in sd: # Dinov3 ViT-H/16+ (SwiGLU gated MLP, 32 layers)
json_config = comfy.image_encoders.dino3.DINOV3_VITH_CONFIG
else:
return None

View File

@ -0,0 +1,259 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.ops
from comfy.ldm.modules.attention import optimized_attention_for_device
from comfy.image_encoders.dino2 import LayerScale as DINOv3ViTLayerScale
# DINOv3 ViT-H/16+ (SwiGLU)
DINOV3_VITH_CONFIG = {
"model_type": "dinov3",
"num_hidden_layers": 32,
"hidden_size": 1280,
"num_attention_heads": 20,
"num_register_tokens": 4,
"intermediate_size": 5120,
"layer_norm_eps": 1e-5,
"num_channels": 3,
"patch_size": 16,
"rope_theta": 100.0,
"use_gated_mlp": True,
"gated_mlp_act": "silu",
"image_size": 1024,
"image_mean": [0.485, 0.456, 0.406],
"image_std": [0.229, 0.224, 0.225],
}
class DINOv3ViTMLP(nn.Module):
def __init__(self, hidden_size, intermediate_size, mlp_bias, device, dtype, operations):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype)
self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=mlp_bias, device=device, dtype=dtype)
self.act_fn = torch.nn.GELU()
def forward(self, x):
return self.down_proj(self.act_fn(self.up_proj(x)))
def rotate_half(x):
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, **kwargs):
num_tokens = q.shape[-2]
num_patches = sin.shape[-2]
num_prefix_tokens = num_tokens - num_patches
q_prefix_tokens, q_patches = q.split((num_prefix_tokens, num_patches), dim=-2)
k_prefix_tokens, k_patches = k.split((num_prefix_tokens, num_patches), dim=-2)
q_patches = (q_patches * cos) + (rotate_half(q_patches) * sin)
k_patches = (k_patches * cos) + (rotate_half(k_patches) * sin)
q = torch.cat((q_prefix_tokens, q_patches), dim=-2)
k = torch.cat((k_prefix_tokens, k_patches), dim=-2)
return q, k
class DINOv3ViTAttention(nn.Module):
def __init__(self, hidden_size, num_attention_heads, device, dtype, operations):
super().__init__()
self.embed_dim = hidden_size
self.num_heads = num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
self.k_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=False, device=device, dtype=dtype) # key_bias = False
self.v_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype)
self.q_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype)
self.o_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype)
def forward(self, hidden_states, attention_mask=None, position_embeddings=None, **kwargs):
batch_size, patches, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
if position_embeddings is not None:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
attn = optimized_attention_for_device(query_states.device, mask=False)
attn_output = attn(
query_states, key_states, value_states, self.num_heads, attention_mask,
skip_reshape=True, skip_output_reshape=True, low_precision_attention=False,
)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(batch_size, patches, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output
class DINOv3ViTGatedMLP(nn.Module):
def __init__(self, hidden_size, intermediate_size, mlp_bias, device, dtype, operations, act="silu"):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.gate_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype)
self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype)
self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=mlp_bias, device=device, dtype=dtype)
self.act_fn = torch.nn.SiLU() if act == "silu" else torch.nn.GELU()
def forward(self, x):
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
def get_patches_center_coordinates(num_patches_h, num_patches_w, dtype, device):
coords_h = torch.arange(0.5, num_patches_h, dtype=dtype, device=device)
coords_w = torch.arange(0.5, num_patches_w, dtype=dtype, device=device)
coords_h = coords_h / num_patches_h
coords_w = coords_w / num_patches_w
coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1)
coords = coords.flatten(0, 1)
coords = 2.0 * coords - 1.0
return coords
class DINOv3ViTRopePositionEmbedding(nn.Module):
inv_freq: torch.Tensor
def __init__(self, rope_theta, hidden_size, num_attention_heads, patch_size, device, dtype):
super().__init__()
self.base = rope_theta
self.head_dim = hidden_size // num_attention_heads
self.patch_size = patch_size
inv_freq = 1 / self.base ** torch.arange(0, 1, 4 / self.head_dim, dtype=torch.float32, device=device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
def forward(self, pixel_values):
_, _, height, width = pixel_values.shape
num_patches_h = height // self.patch_size
num_patches_w = width // self.patch_size
patch_coords = get_patches_center_coordinates(num_patches_h, num_patches_w, dtype=torch.float32, device=pixel_values.device)
self.inv_freq = self.inv_freq.to(pixel_values.device)
angles = 2 * math.pi * patch_coords[:, :, None] * self.inv_freq[None, None, :]
angles = angles.flatten(1, 2)
angles = angles.tile(2)
cos = torch.cos(angles).to(dtype=pixel_values.dtype)
sin = torch.sin(angles).to(dtype=pixel_values.dtype)
return cos, sin
class DINOv3ViTEmbeddings(nn.Module):
def __init__(self, hidden_size, num_register_tokens, num_channels, patch_size, dtype, device, operations):
super().__init__()
self.cls_token = nn.Parameter(torch.empty(1, 1, hidden_size, device=device, dtype=dtype))
self.mask_token = nn.Parameter(torch.empty(1, 1, hidden_size, device=device, dtype=dtype))
self.register_tokens = nn.Parameter(torch.empty(1, num_register_tokens, hidden_size, device=device, dtype=dtype))
self.patch_embeddings = operations.Conv2d(
num_channels, hidden_size, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype
)
def forward(self, pixel_values, bool_masked_pos=None):
batch_size = pixel_values.shape[0]
patch_embeddings = self.patch_embeddings(pixel_values)
patch_embeddings = patch_embeddings.flatten(2).transpose(1, 2)
if bool_masked_pos is not None:
mask_token = comfy.ops.cast_to_input(self.mask_token, patch_embeddings)
patch_embeddings = torch.where(bool_masked_pos.unsqueeze(-1), mask_token, patch_embeddings)
cls_token = comfy.ops.cast_to_input(self.cls_token.expand(batch_size, -1, -1), patch_embeddings)
register_tokens = comfy.ops.cast_to_input(self.register_tokens.expand(batch_size, -1, -1), patch_embeddings)
embeddings = torch.cat([cls_token, register_tokens, patch_embeddings], dim=1)
return embeddings
class DINOv3ViTLayer(nn.Module):
def __init__(self, hidden_size, layer_norm_eps, use_gated_mlp, mlp_bias, intermediate_size,
num_attention_heads, device, dtype, operations, gated_mlp_act="silu"):
super().__init__()
self.norm1 = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype)
self.attention = DINOv3ViTAttention(hidden_size, num_attention_heads, device=device, dtype=dtype, operations=operations)
self.layer_scale1 = DINOv3ViTLayerScale(hidden_size, device=device, dtype=dtype, operations=None)
self.norm2 = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype)
if use_gated_mlp:
self.mlp = DINOv3ViTGatedMLP(hidden_size, intermediate_size, mlp_bias, device=device, dtype=dtype, operations=operations, act=gated_mlp_act)
else:
self.mlp = DINOv3ViTMLP(hidden_size, intermediate_size=intermediate_size, mlp_bias=mlp_bias, device=device, dtype=dtype, operations=operations)
self.layer_scale2 = DINOv3ViTLayerScale(hidden_size, device=device, dtype=dtype, operations=None)
def forward(self, hidden_states, attention_mask=None, position_embeddings=None):
residual = hidden_states
hidden_states = self.norm1(hidden_states)
hidden_states = self.attention(hidden_states, attention_mask=attention_mask, position_embeddings=position_embeddings)
hidden_states = self.layer_scale1(hidden_states)
hidden_states = hidden_states + residual
residual = hidden_states
hidden_states = self.norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = self.layer_scale2(hidden_states)
hidden_states = hidden_states + residual
return hidden_states
class DINOv3ViTModel(nn.Module):
def __init__(self, config, dtype, device, operations):
super().__init__()
num_hidden_layers = config["num_hidden_layers"]
hidden_size = config["hidden_size"]
num_attention_heads = config["num_attention_heads"]
num_register_tokens = config["num_register_tokens"]
intermediate_size = config["intermediate_size"]
layer_norm_eps = config["layer_norm_eps"]
num_channels = config["num_channels"]
patch_size = config["patch_size"]
rope_theta = config["rope_theta"]
use_gated_mlp = config.get("use_gated_mlp", False)
gated_mlp_act = config.get("gated_mlp_act", "silu")
self.embeddings = DINOv3ViTEmbeddings(
hidden_size, num_register_tokens, num_channels=num_channels, patch_size=patch_size,
dtype=dtype, device=device, operations=operations
)
self.rope_embeddings = DINOv3ViTRopePositionEmbedding(
rope_theta, hidden_size, num_attention_heads, patch_size=patch_size, dtype=dtype, device=device
)
self.layer = nn.ModuleList([
DINOv3ViTLayer(hidden_size, layer_norm_eps, use_gated_mlp=use_gated_mlp, mlp_bias=True,
intermediate_size=intermediate_size, num_attention_heads=num_attention_heads,
dtype=dtype, device=device, operations=operations, gated_mlp_act=gated_mlp_act)
for _ in range(num_hidden_layers)])
self.norm = operations.LayerNorm(hidden_size, eps=layer_norm_eps, dtype=dtype, device=device)
def get_input_embeddings(self):
return self.embeddings.patch_embeddings
def forward(self, pixel_values, bool_masked_pos=None, **kwargs):
hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
position_embeddings = self.rope_embeddings(pixel_values)
for layer_module in self.layer:
hidden_states = layer_module(hidden_states, position_embeddings=position_embeddings)
if kwargs.get("skip_norm_elementwise", False):
sequence_output = F.layer_norm(hidden_states, hidden_states.shape[-1:])
else:
norm = self.norm.to(hidden_states.device)
sequence_output = norm(hidden_states)
pooled_output = sequence_output[:, 0, :]
return sequence_output, None, pooled_output, None

View File

@ -240,6 +240,16 @@ class Flux2(LatentFormat):
def process_out(self, latent):
return latent
class TripoSplat(LatentFormat):
# Sequence latent (B, 8192, 16) the camera token rides alongside as a second nested latent
latent_channels = 16
def process_in(self, latent):
return latent
def process_out(self, latent):
return latent
class Mochi(LatentFormat):
latent_channels = 12
latent_dimensions = 3

View File

@ -38,6 +38,8 @@ class ChromaRadianceParams(ChromaParams):
# None means use the same dtype as the model.
nerf_embedder_dtype: Optional[torch.dtype]
use_x0: bool
# Use sequential txt_ids instead of zeros
use_sequential_txt_ids: bool
class ChromaRadiance(Chroma):
"""
@ -162,6 +164,9 @@ class ChromaRadiance(Chroma):
if params.use_x0:
self.register_buffer("__x0__", torch.tensor([]))
if params.use_sequential_txt_ids:
self.register_buffer("__sequential__", torch.tensor([]))
@property
def _nerf_final_layer(self) -> nn.Module:
if self.params.nerf_final_head_type == "linear":
@ -313,6 +318,9 @@ class ChromaRadiance(Chroma):
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
# Radiance after 2026-05-22 uses sequential txt_ids instead of zeros
if params.use_sequential_txt_ids:
txt_ids[:, :, 0] = torch.arange(context.shape[1], device=x.device, dtype=x.dtype).unsqueeze(0).expand(bs, -1)
img_out = self.forward_orig(
img,

View File

@ -14,15 +14,7 @@ from torchvision import transforms
import comfy.patcher_extension
from comfy.ldm.modules.attention import optimized_attention
import comfy.ldm.common_dit
def apply_rotary_pos_emb(
t: torch.Tensor,
freqs: torch.Tensor,
) -> torch.Tensor:
t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float()
t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1]
t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t)
return t_out
import comfy.quant_ops
# ---------------------- Feed Forward Network -----------------------
@ -173,8 +165,7 @@ class Attention(nn.Module):
k = self.k_norm(k)
v = self.v_norm(v)
if self.is_selfattn and rope_emb is not None: # only apply to self-attention!
q = apply_rotary_pos_emb(q, rope_emb)
k = apply_rotary_pos_emb(k, rope_emb)
q, k = comfy.quant_ops.ck.apply_rope_split_half(q, k, rope_emb)
return q, k, v
q, k, v = apply_norm_and_rotary_pos_emb(q, k, v, rope_emb)

View File

@ -5,6 +5,7 @@ import torch.nn.functional as F
from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
import comfy.quant_ops
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
assert dim % 2 == 0
@ -19,15 +20,6 @@ def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
out = torch.stack([torch.cos(out), torch.sin(out)], dim=0)
return out.to(dtype=torch.float32, device=pos.device)
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
rot_dim = freqs_cis.shape[-1]
x, x_pass = x_in[..., :rot_dim], x_in[..., rot_dim:]
cos_ = freqs_cis[0]
sin_ = freqs_cis[1]
x1, x2 = x.chunk(2, dim=-1)
x_rotated = torch.cat((-x2, x1), dim=-1)
return torch.cat((x * cos_ + x_rotated * sin_, x_pass), dim=-1)
class ErnieImageEmbedND3(nn.Module):
def __init__(self, dim: int, theta: int, axes_dim: tuple):
super().__init__()
@ -37,8 +29,16 @@ class ErnieImageEmbedND3(nn.Module):
def forward(self, ids: torch.Tensor) -> torch.Tensor:
emb = torch.cat([rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(3)], dim=-1)
emb = emb.unsqueeze(3) # [2, B, S, 1, head_dim//2]
return torch.stack([emb, emb], dim=-1).reshape(*emb.shape[:-1], -1) # [B, S, 1, head_dim]
cos_ = emb[0]
sin_ = emb[1]
N = cos_.shape[-1]
half = N // 2
cos_top = cos_[..., :half].repeat_interleave(2, dim=-1)
sin_top = sin_[..., :half].repeat_interleave(2, dim=-1)
cos_bot = cos_[..., half:].repeat_interleave(2, dim=-1)
sin_bot = sin_[..., half:].repeat_interleave(2, dim=-1)
rot = torch.stack([cos_top, -sin_top, sin_bot, cos_bot], dim=-1)
return rot.reshape(*rot.shape[:-1], 2, 2).unsqueeze(2)
class ErnieImagePatchEmbedDynamic(nn.Module):
def __init__(self, in_channels: int, embed_dim: int, patch_size: int, operations, device=None, dtype=None):
@ -115,8 +115,7 @@ class ErnieImageAttention(nn.Module):
key = self.norm_k(key)
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
query, key = comfy.quant_ops.ck.apply_rope_split_half(query, key, image_rotary_emb)
q_flat = query.reshape(B, S, -1)
k_flat = key.reshape(B, S, -1)
@ -274,7 +273,7 @@ class ErnieImageModel(nn.Module):
image_ids = image_ids.view(1, N_img, 3).expand(B, -1, -1)
rotary_pos_emb = self.pos_embed(torch.cat([image_ids, text_ids], dim=1)).to(x.dtype)
rotary_pos_emb = self.pos_embed(torch.cat([image_ids, text_ids], dim=1))
del image_ids, text_ids
sample = self.time_proj(timesteps).to(dtype)

View File

@ -4,7 +4,7 @@ from torch import Tensor
from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
import logging
import comfy.quant_ops
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor:
@ -44,21 +44,15 @@ def _apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
try:
import comfy.quant_ops
q_apply_rope = comfy.quant_ops.ck.apply_rope
q_apply_rope1 = comfy.quant_ops.ck.apply_rope1
def apply_rope(xq, xk, freqs_cis):
if comfy.model_management.in_training:
return _apply_rope(xq, xk, freqs_cis)
else:
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
def apply_rope1(x, freqs_cis):
if comfy.model_management.in_training:
return _apply_rope1(x, freqs_cis)
else:
return q_apply_rope1(x, freqs_cis)
except:
logging.warning("No comfy kitchen, using old apply_rope functions.")
apply_rope = _apply_rope
apply_rope1 = _apply_rope1
def apply_rope(xq, xk, freqs_cis):
if comfy.model_management.in_training:
return _apply_rope(xq, xk, freqs_cis)
else:
return comfy.quant_ops.ck.apply_rope(xq, xk, freqs_cis)
def apply_rope1(x, freqs_cis):
if comfy.model_management.in_training:
return _apply_rope1(x, freqs_cis)
else:
return comfy.quant_ops.ck.apply_rope1(x, freqs_cis)

View File

@ -51,15 +51,6 @@ class FeedForward(nn.Module):
return hidden_states
def apply_rotary_emb(x, freqs_cis):
if x.shape[1] == 0:
return x
t_ = x.reshape(*x.shape[:-1], -1, 1, 2)
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
return t_out.reshape(*x.shape)
class QwenTimestepProjEmbeddings(nn.Module):
def __init__(self, embedding_dim, pooled_projection_dim, use_additional_t_cond=False, dtype=None, device=None, operations=None):
super().__init__()

View File

@ -0,0 +1,199 @@
# TripoSplat 3D gaussian container. Operates on already-decoded
# tensors and exposes them as render-ready tensors (render_tensors) for the generic SPLAT type.
import torch
import torch.nn.functional as F
import comfy.model_management
class GaussianModel:
def __init__(self, aabb: list, sh_degree: int = 0, mininum_kernel_size: float = 0.0,
scaling_bias: float = 0.01, opacity_bias: float = 0.1,
scaling_activation: str = "exp", device=None):
self.sh_degree = sh_degree
self.mininum_kernel_size = mininum_kernel_size
self.scaling_bias = scaling_bias
self.opacity_bias = opacity_bias
self.device = device
self.aabb = torch.tensor(aabb, dtype=torch.float32, device=device)
if scaling_activation == "exp":
self._scaling_activation = torch.exp
self._inverse_scaling_activation = torch.log
elif scaling_activation == "softplus":
self._scaling_activation = F.softplus
self._inverse_scaling_activation = lambda x: x + torch.log(-torch.expm1(-x))
self._opacity_activation = torch.sigmoid
self._inverse_opacity_activation = lambda x: torch.log(x / (1 - x))
self.scale_bias = self._inverse_scaling_activation(torch.tensor(self.scaling_bias)).to(self.device)
self.rots_bias = torch.zeros(4, device=self.device)
self.rots_bias[0] = 1
self.opacity_bias_val = self._inverse_opacity_activation(torch.tensor(self.opacity_bias)).to(self.device)
self._storage = {}
def _get_store(self, name):
return self._storage.get(name)
def _set_store(self, name, value):
self._storage[name] = value
@property
def _xyz(self):
return self._get_store("_xyz")
@_xyz.setter
def _xyz(self, value):
if value is None:
self._set_store("_xyz", None)
self._set_store("xyz", None)
return
self._set_store("_xyz", value)
self._set_store("xyz", value * self.aabb[None, 3:] + self.aabb[None, :3])
@property
def get_xyz(self):
return self._get_store("xyz")
@property
def _features_dc(self):
return self._get_store("_features_dc")
@_features_dc.setter
def _features_dc(self, value):
self._set_store("_features_dc", value)
@property
def _opacity(self):
return self._get_store("_opacity")
@_opacity.setter
def _opacity(self, value):
if value is None:
self._set_store("_opacity", None)
self._set_store("opacity", None)
return
self._set_store("_opacity", value)
self._set_store("opacity", self._opacity_activation(value + self.opacity_bias_val))
@property
def get_opacity(self):
return self._get_store("opacity")
@property
def _scaling(self):
return self._get_store("_scaling")
@_scaling.setter
def _scaling(self, value):
if value is None:
self._set_store("_scaling", None)
self._set_store("scaling", None)
return
self._set_store("_scaling", value)
s = self._scaling_activation(value + self.scale_bias)
s = torch.square(s) + self.mininum_kernel_size ** 2
self._set_store("scaling", torch.sqrt(s))
@property
def get_scaling(self):
return self._get_store("scaling")
@property
def _rotation(self):
return self._get_store("_rotation")
@_rotation.setter
def _rotation(self, value):
self._set_store("_rotation", value)
_DEFAULT_TRANSFORM = [[1, 0, 0], [0, 0, -1], [0, 1, 0]]
def render_tensors(self):
# Render-ready (activated, world-space) tensors for the generic SPLAT type. The axis transform
# (a 3x3 rotation, object frame -> viewer Y-up) is baked into positions and rotations.
# Returns float tensors on the intermediate device: positions (N,3), scales (N,3) linear,
# rotations (N,4) wxyz, opacities (N,1) in [0,1], sh (N,K,3) coefficients.
xyz = self.get_xyz.float()
scaling = self.get_scaling.float()
opacity = self.get_opacity.float()
rotation = (self._rotation + self.rots_bias[None, :]).float()
sh = self._features_dc.float() # (N, K, 3)
T = torch.as_tensor(self._DEFAULT_TRANSFORM, dtype=torch.float32, device=xyz.device)
xyz = xyz @ T.T
rotation = _matrix_to_quat(torch.matmul(T, _quat_to_matrix(rotation)))
rotation = rotation / torch.linalg.norm(rotation, dim=-1, keepdim=True)
out_device = comfy.model_management.intermediate_device()
return (
xyz.to(out_device).contiguous(), scaling.to(out_device).contiguous(),
rotation.to(out_device).contiguous(), opacity.to(out_device).contiguous(),
sh.to(out_device).contiguous(),
)
def _quat_to_matrix(q):
q = q / torch.linalg.norm(q, dim=-1, keepdim=True)
w, x, y, z = q[:, 0], q[:, 1], q[:, 2], q[:, 3]
R = torch.stack([
1 - 2*(y*y + z*z), 2*(x*y - w*z), 2*(x*z + w*y),
2*(x*y + w*z), 1 - 2*(x*x + z*z), 2*(y*z - w*x),
2*(x*z - w*y), 2*(y*z + w*x), 1 - 2*(x*x + y*y),
], dim=-1).reshape(-1, 3, 3)
return R
def _matrix_to_quat(R):
trace = R[:, 0, 0] + R[:, 1, 1] + R[:, 2, 2]
q = torch.zeros((R.shape[0], 4), dtype=R.dtype, device=R.device)
s = torch.sqrt(torch.clamp(trace + 1, min=0)) * 2
q[:, 0] = 0.25 * s
denom = torch.where(s != 0, s, torch.ones_like(s))
q[:, 1] = (R[:, 2, 1] - R[:, 1, 2]) / denom
q[:, 2] = (R[:, 0, 2] - R[:, 2, 0]) / denom
q[:, 3] = (R[:, 1, 0] - R[:, 0, 1]) / denom
m01 = (R[:, 0, 0] >= R[:, 1, 1]) & (R[:, 0, 0] >= R[:, 2, 2]) & (s == 0)
s1 = torch.sqrt(torch.clamp(1 + R[:, 0, 0] - R[:, 1, 1] - R[:, 2, 2], min=0)) * 2
q[m01, 0] = (R[m01, 2, 1] - R[m01, 1, 2]) / s1[m01]
q[m01, 1] = 0.25 * s1[m01]
q[m01, 2] = (R[m01, 0, 1] + R[m01, 1, 0]) / s1[m01]
q[m01, 3] = (R[m01, 0, 2] + R[m01, 2, 0]) / s1[m01]
m11 = (R[:, 1, 1] > R[:, 0, 0]) & (R[:, 1, 1] >= R[:, 2, 2]) & (s == 0)
s2 = torch.sqrt(torch.clamp(1 + R[:, 1, 1] - R[:, 0, 0] - R[:, 2, 2], min=0)) * 2
q[m11, 0] = (R[m11, 0, 2] - R[m11, 2, 0]) / s2[m11]
q[m11, 1] = (R[m11, 0, 1] + R[m11, 1, 0]) / s2[m11]
q[m11, 2] = 0.25 * s2[m11]
q[m11, 3] = (R[m11, 1, 2] + R[m11, 2, 1]) / s2[m11]
m21 = (R[:, 2, 2] > R[:, 0, 0]) & (R[:, 2, 2] > R[:, 1, 1]) & (s == 0)
s3 = torch.sqrt(torch.clamp(1 + R[:, 2, 2] - R[:, 0, 0] - R[:, 1, 1], min=0)) * 2
q[m21, 0] = (R[m21, 1, 0] - R[m21, 0, 1]) / s3[m21]
q[m21, 1] = (R[m21, 0, 2] + R[m21, 2, 0]) / s3[m21]
q[m21, 2] = (R[m21, 1, 2] + R[m21, 2, 1]) / s3[m21]
q[m21, 3] = 0.25 * s3[m21]
return q / torch.linalg.norm(q, dim=-1, keepdim=True)
def build_gaussian_models(decoder, points_pred: dict, pred: dict):
# Assemble GaussianModels from the elastic decoder layout. decoder is the ElasticGaussianFixedlenDecoder
# (carries layout / rep_config / _get_offset)
x = points_pred
offset = decoder._get_offset(pred['features'])
h = pred["features"]
ret = []
for i in range(h.shape[0]):
g = GaussianModel(
sh_degree=0,
aabb=[-0.5, -0.5, -0.5, 1.0, 1.0, 1.0],
mininum_kernel_size=decoder.rep_config['filter_kernel_size_3d'],
scaling_bias=decoder.rep_config['scaling_bias'],
opacity_bias=decoder.rep_config['opacity_bias'],
scaling_activation=decoder.rep_config['scaling_activation'],
device=h.device,
)
_x = x["points"][i, :, None, :]
for k, v in decoder.layout.items():
if k == '_xyz':
setattr(g, k, (offset[i] + _x).flatten(0, 1))
elif k in ('_xyz_center', '_offset_scale'):
continue
else:
feats = h[i][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape']).flatten(0, 1)
setattr(g, k, feats * decoder.rep_config['lr'][k])
ret.append(g)
return ret

View File

@ -0,0 +1,326 @@
# TripoSplat flow-matching denoiser (LatentSeqMMFlowModel). Registered as a ModelType.FLOW arch and
# driven by the standard KSampler; jointly denoises the (B, 8192, 16) latent and a (B, 1, 5) camera token
# carried as a 2-element nested latent.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.model_management
import comfy.patcher_extension
import comfy.rmsnorm
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.flux.math import apply_rope
class MultiHeadRMSNorm(nn.Module):
def __init__(self, dim, heads, dtype=None, device=None):
super().__init__()
self.gamma = nn.Parameter(torch.empty(heads, dim, dtype=dtype, device=device))
def forward(self, x):
x = comfy.rmsnorm.rms_norm(x)
return x * comfy.model_management.cast_to(self.gamma, x.dtype, x.device)
# Positional embeddings
class RePo3DRotaryEmbedding(nn.Module):
def __init__(self, model_channels, num_heads, head_dim, repo_hidden_ratio=0.125, max_freq=16.0,
dtype=None, device=None, operations=None):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
repo_hidden_size = int(model_channels * repo_hidden_ratio)
self.norm = operations.LayerNorm(model_channels, dtype=dtype, device=device)
self.gate_map = operations.Linear(model_channels, repo_hidden_size, bias=False, dtype=dtype, device=device)
self.content_map = operations.Linear(model_channels, repo_hidden_size, bias=False, dtype=dtype, device=device)
self.act = nn.SiLU()
self.final_map = operations.Linear(repo_hidden_size, 3 * num_heads, bias=False, dtype=dtype, device=device)
self.dim_0 = 2 * (head_dim // 6)
self.dim_1 = 2 * (head_dim // 6)
self.dim_2 = head_dim - self.dim_0 - self.dim_1
dims = [self.dim_0, self.dim_1, self.dim_2]
freqs_list = []
for d in dims:
freq_dim = d // 2
freqs_list.append(torch.linspace(1.0, float(max_freq), steps=freq_dim, dtype=torch.float32))
self.freqs_0 = nn.Parameter(freqs_list[0])
self.freqs_1 = nn.Parameter(freqs_list[1])
self.freqs_2 = nn.Parameter(freqs_list[2])
def forward(self, hidden_states):
h = self.norm(hidden_states)
feat = self.act(self.gate_map(h)) * self.content_map(h)
out = self.final_map(feat)
B, L, _ = out.shape
delta_pos = out.reshape(B, L, self.num_heads, 3)
f0 = comfy.model_management.cast_to(self.freqs_0, torch.float32, out.device)
f1 = comfy.model_management.cast_to(self.freqs_1, torch.float32, out.device)
f2 = comfy.model_management.cast_to(self.freqs_2, torch.float32, out.device)
ang_0 = delta_pos[..., 0].unsqueeze(-1) * f0 * torch.pi
ang_1 = delta_pos[..., 1].unsqueeze(-1) * f1 * torch.pi
ang_2 = delta_pos[..., 2].unsqueeze(-1) * f2 * torch.pi
ang = torch.cat([ang_0, ang_1, ang_2], dim=-1).float() # (B, L, heads, head_dim/2)
cos, sin = ang.cos(), ang.sin()
return torch.stack([cos, -sin, sin, cos], dim=-1).reshape(*ang.shape, 2, 2)
class PcdAbsolutePositionEmbedder(nn.Module):
# Sinusoidal absolute position embedding. Two fixed schedules are used in TripoSplat:
# "pow2" (flow-model latent anchors) and "log2" (octree / gaussian decoders).
def __init__(self, channels: int, in_channels: int = 3, max_res: int = 16, schedule: str = "pow2"):
super().__init__()
self.channels = channels
self.in_channels = in_channels
self.max_res = max_res
self.schedule = schedule
self.freq_dim = channels // in_channels // 2
def _freqs(self, device):
if self.schedule == "pow2":
freqs_2exp = torch.arange(self.max_res, dtype=torch.float32, device=device)
res_dim = max(0, self.freq_dim - self.max_res)
freqs_res = (torch.arange(res_dim, dtype=torch.float32, device=device) / max(res_dim, 1) * self.max_res
if res_dim > 0 else torch.empty(0, device=device))
freqs = torch.cat([freqs_2exp, freqs_res], dim=0)[:self.freq_dim]
return torch.pow(2.0, freqs) * 2.0 # *2 folds this schedule's 2*pi into the shared *pi below
logs = torch.linspace(0.0, float(self.max_res), steps=self.freq_dim, dtype=torch.float32, device=device)
return torch.pow(2.0, logs)
def forward(self, x: torch.Tensor) -> torch.Tensor:
orig_dtype = x.dtype
x = x.float()
*dims, D = x.shape
out = torch.outer(x.reshape(-1), self._freqs(x.device)) * torch.pi
out = torch.cat([out.sin(), out.cos()], dim=-1).reshape(*dims, -1)
if out.shape[-1] < self.channels:
out = torch.cat([out, torch.zeros(*dims, self.channels - out.shape[-1],
device=out.device, dtype=out.dtype)], dim=-1)
return out.to(orig_dtype)
def attention(q, k, v, transformer_options=None):
# q, k, v: (B, L, heads, dim) -> (B, L, heads, dim). Shared optimized_attention call convention.
out = optimized_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), heads=q.shape[2],
skip_reshape=True, skip_output_reshape=True, low_precision_attention=False,
transformer_options=transformer_options)
return out.transpose(1, 2)
# Transformer building blocks
class MLP(nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, dtype=None, device=None, operations=None):
super().__init__()
self.mlp = nn.Sequential(
operations.Linear(in_channels, hidden_channels, dtype=dtype, device=device),
nn.GELU(approximate="tanh"),
operations.Linear(hidden_channels, out_channels, dtype=dtype, device=device),
)
def forward(self, x):
return self.mlp(x)
class RopeMultiHeadAttention(nn.Module):
def __init__(self, channels, num_heads, qkv_bias=True, qk_rms_norm=False, use_rope=False,
dtype=None, device=None, operations=None):
super().__init__()
self.channels = channels
self.num_heads = num_heads
self.head_dim = channels // num_heads
self.qk_rms_norm = qk_rms_norm
self.use_rope = use_rope
self.qkv = operations.Linear(channels, channels * 3, bias=qkv_bias, dtype=dtype, device=device)
if self.qk_rms_norm:
self.q_norm = MultiHeadRMSNorm(self.head_dim, num_heads, dtype=dtype, device=device)
self.k_norm = MultiHeadRMSNorm(self.head_dim, num_heads, dtype=dtype, device=device)
self.out = operations.Linear(channels, channels, dtype=dtype, device=device)
def forward(self, x, rope_emb=None, transformer_options=None):
B, L, C = x.shape
qkv = self.qkv(x).reshape(B, L, 3, self.num_heads, self.head_dim)
q, k, v = qkv.unbind(2)
if self.use_rope:
q, k = apply_rope(q, k, rope_emb)
if self.qk_rms_norm:
q = self.q_norm(q)
k = self.k_norm(k)
h = attention(q, k, v, transformer_options) # (B, L, heads, dim)
return self.out(h.reshape(B, L, C))
class UnifiedTransformerBlock(nn.Module):
def __init__(self, channels, num_heads, mlp_ratio=4.0,
use_rope=False, qk_rms_norm=False, qkv_bias=True,
modulation=True, share_mod=False,
dtype=None, device=None, operations=None):
super().__init__()
self.modulation = modulation
self.share_mod = share_mod
self.norm1 = operations.LayerNorm(channels, elementwise_affine=not modulation, eps=1e-6, dtype=dtype, device=device)
self.norm2 = operations.LayerNorm(channels, elementwise_affine=not modulation, eps=1e-6, dtype=dtype, device=device)
self.attn = RopeMultiHeadAttention(channels, num_heads=num_heads,
qkv_bias=qkv_bias, use_rope=use_rope, qk_rms_norm=qk_rms_norm,
dtype=dtype, device=device, operations=operations)
self.mlp = MLP(channels, int(channels * mlp_ratio), channels, dtype=dtype, device=device, operations=operations)
if modulation:
if not share_mod:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(), operations.Linear(channels, 6 * channels, bias=True, dtype=dtype, device=device))
self.shift_table = nn.Parameter(torch.empty(1, 6 * channels, dtype=dtype, device=device))
def forward(self, x, mod=None, rotary_emb=None, transformer_options=None):
if self.modulation:
if not self.share_mod:
mod = self.adaLN_modulation(mod)
mod = mod + comfy.model_management.cast_to(self.shift_table, mod.dtype, mod.device)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
h = torch.addcmul(shift_msa.unsqueeze(1), self.norm1(x), 1 + scale_msa.unsqueeze(1))
x = torch.addcmul(x, self.attn(h, rope_emb=rotary_emb, transformer_options=transformer_options), gate_msa.unsqueeze(1))
h = torch.addcmul(shift_mlp.unsqueeze(1), self.norm2(x), 1 + scale_mlp.unsqueeze(1))
x = torch.addcmul(x, self.mlp(h), gate_mlp.unsqueeze(1))
else:
x = x + self.attn(self.norm1(x), rope_emb=rotary_emb, transformer_options=transformer_options)
x = x + self.mlp(self.norm2(x))
return x
class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
super().__init__()
self.mlp = nn.Sequential(
operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
half = dim // 2
freqs = torch.exp(-np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
emb = self.timestep_embedding(t, self.frequency_embedding_size)
return self.mlp(emb.to(self.mlp[0].weight.dtype))
class LatentSeqMMFlowModel(nn.Module):
def __init__(self, image_model=None, q_token_length=8192, in_channels=16, model_channels=1024,
cond_channels=1280, out_channels=16, num_blocks=24, num_refiner_blocks=2,
num_heads=None, num_head_channels=64, cam_channels=5, cond2_channels=128,
mlp_ratio=4, share_mod=True, qk_rms_norm=True,
dtype=None, device=None, operations=None, **kwargs):
super().__init__()
self.dtype = dtype
self.q_token_length = q_token_length
self.in_channels = in_channels
self.cam_channels = cam_channels
self.model_channels = model_channels
self.cond_channels = cond_channels
self.cond2_channels = cond2_channels
self.out_channels = out_channels
self.num_blocks = num_blocks
self.num_refiner_blocks = num_refiner_blocks
self.num_heads = num_heads or model_channels // num_head_channels
self.mlp_ratio = mlp_ratio
self.share_mod = share_mod
self.qk_rms_norm = qk_rms_norm
factory_kwargs = dict(dtype=dtype, device=device)
op_kwargs = dict(operations=operations, **factory_kwargs)
self.t_embedder = TimestepEmbedder(model_channels, **op_kwargs)
if share_mod:
self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(model_channels, 6 * model_channels, bias=True, **factory_kwargs))
self.input_layer = operations.Linear(in_channels, model_channels, **factory_kwargs)
self.cond_embedder = operations.Linear(cond_channels, model_channels, **factory_kwargs)
self.cond_embedder2 = operations.Linear(cond2_channels, model_channels, **factory_kwargs) if cond2_channels is not None else None
# Fixed Sobol (low-discrepancy) 3D anchor positions for the latent tokens, used as positional encoding.
# The embedder is parameter-free and the anchors are fixed, precompute once.
sobol_seq = torch.quasirandom.SobolEngine(dimension=3, scramble=True, seed=123).draw(q_token_length)
pos_emb = PcdAbsolutePositionEmbedder(model_channels)(sobol_seq.unsqueeze(0))
self.register_buffer("pos_emb", pos_emb, persistent=False)
# RePo3DRotaryEmbedding layers for the refiner and main blocks
repo_kwargs = dict(num_heads=self.num_heads, head_dim=num_head_channels, **op_kwargs)
self.noise_repo_layers = nn.ModuleList(
[RePo3DRotaryEmbedding(model_channels, **repo_kwargs) for _ in range(num_refiner_blocks)])
self.context_repo_layers = nn.ModuleList(
[RePo3DRotaryEmbedding(model_channels, **repo_kwargs) for _ in range(num_refiner_blocks)])
self.repo_layers = nn.ModuleList(
[RePo3DRotaryEmbedding(model_channels, **repo_kwargs) for _ in range(num_blocks)])
# Refiner blocks
block_kwargs = dict(num_heads=self.num_heads, mlp_ratio=self.mlp_ratio, use_rope=True, qk_rms_norm=self.qk_rms_norm, **op_kwargs)
self.noise_refiner = nn.ModuleList(
[UnifiedTransformerBlock(model_channels, modulation=True, share_mod=self.share_mod, **block_kwargs) for _ in range(num_refiner_blocks)])
self.context_refiner = nn.ModuleList(
[UnifiedTransformerBlock(model_channels, modulation=False, **block_kwargs) for _ in range(num_refiner_blocks)])
self.cam_refiner = MLP(self.cam_channels, model_channels, model_channels, **op_kwargs)
self.blocks = nn.ModuleList(
[UnifiedTransformerBlock(model_channels, modulation=True, share_mod=self.share_mod, **block_kwargs) for _ in range(num_blocks)])
self.shift_table = nn.Parameter(torch.empty(1, 2, model_channels, **factory_kwargs))
self.out_layer = operations.Linear(model_channels, out_channels, **factory_kwargs)
self.cam_out_layer = operations.Linear(model_channels, cam_channels, **factory_kwargs)
def forward(self, x, t, context=None, ref_latents=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, t, context, ref_latents, transformer_options, **kwargs)
def _forward(self, x, t, context=None, ref_latents=None, transformer_options={}, **kwargs):
# x is the unpacked nested latent: [latent (B,8192,in_channels), camera (B,1,cam_channels)].
# context == feature1.
z, camera = x[0], x[1]
feat1 = context
h_x = self.input_layer(z)
h_cond = self.cond_embedder(feat1)
if ref_latents is not None and self.cond_embedder2 is not None:
# Flatten the Flux2 VAE latent (B,128,h,w) to a token sequence and front-pad to feat1's length
# (the pad count = feat1's prefix tokens: DINOv3 cls + registers), then add to the context.
feat2 = ref_latents[0].flatten(2).transpose(1, 2)
feat2 = F.pad(feat2, (0, 0, feat1.shape[1] - feat2.shape[1], 0))
h_cond = h_cond + self.cond_embedder2(feat2.to(h_cond.dtype))
t_emb = self.t_embedder(t)
t_mod = self.adaLN_modulation(t_emb) if self.share_mod else t_emb
h_x = h_x + self.pos_emb.to(z)
for i, block in enumerate(self.noise_refiner):
h_x = block(h_x, mod=t_mod, rotary_emb=self.noise_repo_layers[i](h_x), transformer_options=transformer_options)
for i, block in enumerate(self.context_refiner):
h_cond = block(h_cond, mod=None, rotary_emb=self.context_repo_layers[i](h_cond), transformer_options=transformer_options)
cam = camera.to(z)
h_cam = self.cam_refiner(cam)
h = torch.cat([h_x, h_cond, h_cam], dim=1)
for i, block in enumerate(self.blocks):
h = block(h, mod=t_mod, rotary_emb=self.repo_layers[i](h), transformer_options=transformer_options)
h_x = F.layer_norm(h[:, :z.shape[1]].float(), h.shape[-1:]).to(z)
h_cam = F.layer_norm(h[:, -cam.shape[1]:].float(), h.shape[-1:]).to(z)
shift, scale = (comfy.model_management.cast_to(self.shift_table, t_emb.dtype, t_emb.device) + t_emb.unsqueeze(1)).chunk(2, dim=1)
scale = 1 + scale
h_x = torch.addcmul(shift, h_x, scale)
h_cam = torch.addcmul(shift, h_cam, scale)
return self.out_layer(h_x), self.cam_out_layer(h_cam)

View File

@ -0,0 +1,91 @@
# Live preview for TripoSplat: decode an x0 estimate into a coarse gaussian splat and render it with a perspective orbit camera.
import numpy as np
from PIL import Image
_C0 = 0.28209479177387814
_LATENT_TOKENS = 8192 # q_token_length
_LATENT_CH = 16 # in_channels
_OBJECT_TO_VIEWER = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]], np.float32) # object frame -> viewer Y-up frame
def _view_matrix(yaw_deg, pitch_deg):
y, p = np.radians(yaw_deg), np.radians(pitch_deg)
Ry = np.array([[np.cos(y), 0, np.sin(y)], [0, 1, 0], [-np.sin(y), 0, np.cos(y)]], np.float32)
Rx = np.array([[1, 0, 0], [0, np.cos(p), -np.sin(p)], [0, np.sin(p), np.cos(p)]], np.float32)
return Rx @ Ry
def render_splat(xyz, rgb, scale, opacity=None, yaw=35.0, pitch=30.0, size=320, min_px=2, gain=1.0,
max_px=9, min_opacity=0.0, fov=35.0, dist=2.2):
# Project gaussian centers with a perspective camera and paint each as a filled disk whose screen
# radius follows the gaussian's world-space scale, composited with a nearest-wins z-buffer.
# gain scales the footprint (≈ std spanned), `min_px`/`max_px` clamp the on-screen radius.
pts = xyz.astype(np.float32) @ _OBJECT_TO_VIEWER.T
v = pts @ _view_matrix(yaw, pitch).T
zc = v[:, 2] + dist
keep = zc > 1e-2
if opacity is not None and min_opacity > 0.0: # culls gaussians with very low opacity
keep = keep & (opacity > min_opacity)
v, zc, scale = v[keep], zc[keep], scale[keep]
col = (np.clip(rgb, 0, 1)[:, :3] * 255).astype(np.uint8)[keep]
if v.shape[0] == 0:
return Image.fromarray(np.zeros((size, size, 3), np.uint8))
f = (size / 2) / np.tan(np.radians(fov) / 2)
cx = size / 2 + f * v[:, 0] / zc
cy = size / 2 + f * v[:, 1] / zc
radius = np.clip(np.round(f * scale / zc * gain), min_px, max_px).astype(np.int32)
# Expand each splat to its disk pixels, bucketed by integer radius so it stays vectorized.
px, py, pz, pc = [], [], [], []
for r in range(int(radius.min()), int(radius.max()) + 1):
m = radius == r
if not m.any():
continue
dy, dx = np.mgrid[-r:r + 1, -r:r + 1]
disk = (dx * dx + dy * dy) <= r * r
ox, oy = dx[disk], dy[disk]
px.append((cx[m, None] + ox).ravel())
py.append((cy[m, None] + oy).ravel())
pz.append(np.repeat(zc[m], ox.size))
pc.append(np.repeat(col[m], ox.size, axis=0))
px, py = np.concatenate(px), np.concatenate(py)
pz, pc = np.concatenate(pz), np.concatenate(pc)
xi = np.clip(px, 0, size - 1).astype(np.int64)
yi = np.clip(py, 0, size - 1).astype(np.int64)
# Nearest-wins z-buffer: pack (quantized depth, source index), per-pixel min picks the closest
# splat, then decode the winning index back to its color.
pid = yi * size + xi
q = np.clip((pz * 1024.0).astype(np.int64), 0, (1 << 20) - 1) # near = small
key = (q << 32) | np.arange(pid.size, dtype=np.int64)
buf = np.full(size * size, 1 << 62, np.int64)
np.minimum.at(buf, pid, key)
img = np.zeros((size * size, 3), np.uint8)
hit = buf < (1 << 62)
img[hit] = pc[buf[hit] & 0xFFFFFFFF]
return Image.fromarray(img.reshape(size, size, 3))
def _extract_latent(x0):
# x0 from the sampler callback is the nested latent packed to (B, 1, TOKENS*CH + 1*5);
# the plain single-latent case is (B, TOKENS, CH). Return the (B, TOKENS, CH) latent stream.
if x0.ndim == 3 and x0.shape[1] == _LATENT_TOKENS and x0.shape[2] == _LATENT_CH:
return x0
flat = x0.reshape(x0.shape[0], -1)
return flat[:, :_LATENT_TOKENS * _LATENT_CH].reshape(x0.shape[0], _LATENT_TOKENS, _LATENT_CH)
def decode_x0_to_image(decoder, x0, cfg):
# Decode x0 at a coarse octree level / few gaussians and render a preview image.
latent = _extract_latent(x0)
fsm = decoder.first_stage_model
gaussian = fsm.decode(latent.to(decoder.device, decoder.vae_dtype),
num_gaussians=cfg.get("gaussians", 16384), level=cfg.get("level", 5))[0]
xyz = gaussian.get_xyz.float().cpu().numpy()
rgb = gaussian._features_dc.float().cpu().numpy()[:, 0, :] * _C0 + 0.5
scale = gaussian.get_scaling.float().cpu().numpy().max(axis=1) # per-splat world radius (largest axis)
opacity = gaussian.get_opacity.float().cpu().numpy()[:, 0]
return render_splat(xyz, rgb, scale, opacity=opacity, yaw=cfg.get("yaw", 35.0), pitch=cfg.get("pitch", 30.0),
size=cfg.get("size", 320), min_px=1, gain=1.0, max_px=cfg.get("point_size", 3),
min_opacity=0.01)

382
comfy/ldm/triposplat/vae.py Normal file
View File

@ -0,0 +1,382 @@
# TripoSplat gaussian decoder ("VAE"): an octree probability decoder picks point coords, then an
# elastic-gaussian decoder predicts per-point gaussian params. OctreeGaussianDecoder.decode() returns
# a Gaussian. The octree sampler uses the global torch RNG (no generator) like upstream, so seed it for repeatable decodes.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.model_management
import comfy.ops
from .gaussian import build_gaussian_models
from .model import MultiHeadRMSNorm, MLP, PcdAbsolutePositionEmbedder, attention
# Quasi-random sampling utilities (pure functions, dtype/device-agnostic)
PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53]
def radical_inverse(base, n):
val = 0
inv_base = 1.0 / base
inv_base_n = inv_base
while n > 0:
digit = n % base
val += digit * inv_base_n
n //= base
inv_base_n *= inv_base
return val
def halton_sequence(dim, n):
return [radical_inverse(PRIMES[i], n) for i in range(dim)]
def hammersley_sequence(dim, n, num_samples):
return [n / num_samples] + halton_sequence(dim - 1, n)
def sample_probs(probs, counts, generator=None):
# Systematic resampling: distribute counts[r] draws across the P bins of row r
batch_shape = counts.shape
R = counts.numel()
P = probs.size(-1)
device = probs.device
probs = probs.reshape(R, P).to(torch.float32).clamp_min(0)
counts = counts.reshape(R).to(device=device, dtype=torch.long)
row_sums = probs.sum(1, keepdim=True)
probs = torch.where(row_sums == 0, probs.new_tensor(1.0 / P), probs / row_sums.clamp_min(1))
cdf = probs.cumsum(dim=1).clamp(max=1.0 - 1e-12)
Nmax = int(counts.max())
if Nmax == 0:
return counts.new_zeros(*batch_shape, P)
cnt = counts.clamp_min(1).float().unsqueeze(1) # (R, 1)
grid = torch.arange(Nmax, device=device, dtype=torch.float32).unsqueeze(0) # (1, Nmax)
u = (torch.rand(R, 1, generator=generator).to(device) + grid) / cnt # (R, Nmax) systematic samples (CPU-seeded)
idx = torch.searchsorted(cdf, u.clamp(max=1.0 - 1e-12)).clamp_max(P - 1)
weight = (grid < counts.unsqueeze(1)).to(cdf.dtype) # mask out j >= counts[r]
out = torch.zeros(R, P, dtype=torch.float32, device=device)
out.scatter_add_(1, idx, weight)
return out.to(torch.long).view(*batch_shape, P)
class MultiHeadAttention(nn.Module):
def __init__(self, channels, num_heads, ctx_channels=None, type="self", qkv_bias=True, qk_rms_norm=False,
dtype=None, device=None, operations=None):
super().__init__()
assert channels % num_heads == 0
self.channels = channels
self.head_dim = channels // num_heads
self.ctx_channels = ctx_channels if ctx_channels is not None else channels
self.num_heads = num_heads
self._type = type
self.qk_rms_norm = qk_rms_norm
if self._type == "self":
self.to_qkv = operations.Linear(channels, channels * 3, bias=qkv_bias, dtype=dtype, device=device)
else:
self.to_q = operations.Linear(channels, channels, bias=qkv_bias, dtype=dtype, device=device)
self.to_kv = operations.Linear(self.ctx_channels, channels * 2, bias=qkv_bias, dtype=dtype, device=device)
if self.qk_rms_norm:
self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, dtype=dtype, device=device)
self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, dtype=dtype, device=device)
self.to_out = operations.Linear(channels, channels, dtype=dtype, device=device)
def forward(self, x, context=None):
B, L, C = x.shape
if self._type == "self":
q, k, v = self.to_qkv(x).reshape(B, L, 3, self.num_heads, -1).unbind(dim=2)
else:
Lkv = context.shape[1]
q = self.to_q(x).reshape(B, L, self.num_heads, -1)
k, v = self.to_kv(context).reshape(B, Lkv, 2, self.num_heads, -1).unbind(dim=2)
if self.qk_rms_norm:
q = self.q_rms_norm(q)
k = self.k_rms_norm(k)
h = attention(q, k, v)
return self.to_out(h.reshape(B, L, -1))
# Octree probability decoder
class LevelEmbedder(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256, max_period=1024,
dtype=None, device=None, operations=None):
super().__init__()
self.mlp = nn.Sequential(
operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
)
self.frequency_embedding_size = frequency_embedding_size
self.max_period = max_period
@staticmethod
def level_embedding(t, dim, max_period=1024):
half = dim // 2
freqs = torch.exp(-np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
args = t[:, None].float() * freqs[None] * 2 * torch.pi
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
emb = self.level_embedding(t, self.frequency_embedding_size, self.max_period)
return self.mlp(emb.to(self.mlp[0].weight.dtype))
class ModulatedTransformerCrossOnlyBlock(nn.Module):
def __init__(self, channels, ctx_channels, num_heads, mlp_ratio=4.0, share_mod=False,
qk_rms_norm_cross=True, qkv_bias=True, dtype=None, device=None, operations=None):
super().__init__()
self.share_mod = share_mod
self.norm1 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.norm2 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.cross_attn = MultiHeadAttention(channels, ctx_channels=ctx_channels, num_heads=num_heads,
type="cross", qkv_bias=qkv_bias,
qk_rms_norm=qk_rms_norm_cross, dtype=dtype, device=device, operations=operations)
self.mlp = MLP(channels, int(channels * mlp_ratio), channels, dtype=dtype, device=device, operations=operations)
if not share_mod:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(), operations.Linear(channels, 6 * channels, bias=True, dtype=dtype, device=device))
def forward(self, x, mod, context):
if self.share_mod:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
else:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
h = torch.addcmul(shift_msa.unsqueeze(1), self.norm1(x), 1 + scale_msa.unsqueeze(1))
x = torch.addcmul(x, self.cross_attn(h, context), gate_msa.unsqueeze(1))
h = torch.addcmul(shift_mlp.unsqueeze(1), self.norm2(x), 1 + scale_mlp.unsqueeze(1))
x = torch.addcmul(x, self.mlp(h), gate_mlp.unsqueeze(1))
return x
class OctreeProbabilityFixedlenDecoder(nn.Module):
# Cross-attention transformer over octree coords -> per-node 8-way child occupancy logits.
def __init__(self, model_channels=1024, cond_channels=16, num_blocks=4, num_heads=16,
num_head_channels=64, mlp_ratio=4.0, share_mod=True,
qk_rms_norm_cross=True, dtype=None, device=None, operations=None):
super().__init__()
self.model_channels = model_channels
self.cond_channels = cond_channels
self.num_blocks = num_blocks
self.num_heads = num_heads or model_channels // num_head_channels
self.mlp_ratio = mlp_ratio
self.share_mod = share_mod
self.qk_rms_norm_cross = qk_rms_norm_cross
self.input_layer = operations.Linear(model_channels, model_channels, dtype=dtype, device=device)
self.l_embedder = LevelEmbedder(model_channels, dtype=dtype, device=device, operations=operations)
if share_mod:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(), operations.Linear(model_channels, 6 * model_channels, bias=True, dtype=dtype, device=device))
if cond_channels is not None:
self.blocks = nn.ModuleList([
ModulatedTransformerCrossOnlyBlock(
model_channels, ctx_channels=cond_channels, num_heads=self.num_heads,
mlp_ratio=self.mlp_ratio, qk_rms_norm_cross=self.qk_rms_norm_cross,
share_mod=self.share_mod, dtype=dtype, device=device, operations=operations)
for _ in range(num_blocks)
])
self.out_proj = operations.Linear(model_channels, 8, dtype=dtype, device=device)
self.in_proj = operations.Linear(3, model_channels, dtype=dtype, device=device)
self.pos_embedder = PcdAbsolutePositionEmbedder(channels=model_channels, in_channels=3, max_res=10, schedule="log2")
def forward(self, x, l, cond):
d = next(self.parameters()).dtype
B, L, _ = x.shape
h = self.in_proj(x.to(d)) + self.pos_embedder(x.reshape(-1, 3)).reshape(B, L, -1).to(d)
h = self.input_layer(h)
l_emb = self.l_embedder(l)
if self.share_mod:
l_emb = self.adaLN_modulation(l_emb)
cond = cond.to(d)
for block in self.blocks:
h = block(h, l_emb, cond)
h = F.layer_norm(h.float(), h.shape[-1:]).to(d)
logits = self.out_proj(h)
return {"logits": logits, "probs": torch.softmax(logits, dim=-1)}
@staticmethod
def sample(model, cond, num_points, level, temperature=1.0, generator=None):
B = cond.shape[0]
device = cond.device
child_offset = torch.tensor([[i, j, k] for k in [0, 1] for j in [0, 1] for i in [0, 1]],
dtype=torch.long, device=device)
prev_coords_int = torch.zeros(B, 1, 3, dtype=torch.long, device=device)
prev_counts = torch.full((B, 1), num_points, dtype=torch.long, device=device)
prev_log_probs = torch.zeros(B, 1, dtype=torch.float32, device=device)
batch_indices_range = torch.arange(B, device=device).unsqueeze(1)
for lv in range(1, level + 1):
res_p = 1 << (lv - 1)
res = 1 << lv
parent_coords_norm = (prev_coords_int.to(torch.float32) + 0.5) / res_p
res_tensor = torch.full((B,), res, dtype=torch.long, device=device)
pred_logits = model(parent_coords_norm, res_tensor, cond)["logits"] / temperature
pred_probs = torch.softmax(pred_logits, dim=-1)
pred_log_probs = torch.log_softmax(pred_logits, dim=-1)
sampled = sample_probs(pred_probs, prev_counts, generator=generator).flatten(1, 2)
pred_log_probs = pred_log_probs.flatten(1, 2)
prev_log_probs_expanded = prev_log_probs.repeat_interleave(8, dim=1)
child_coords_int = (prev_coords_int[:, :, None, :] * 2 + child_offset[None, None, :, :]).flatten(1, 2)
mask = sampled > 0
max_valid = mask.sum(dim=1).max().item()
scatter_indices = mask.cumsum(dim=1) - 1
valid_scatter_indices = scatter_indices[mask]
valid_batch_indices = batch_indices_range.expand_as(mask)[mask]
next_prev_coords_int = torch.zeros(B, max_valid, 3, dtype=child_coords_int.dtype, device=device)
next_prev_coords_int[valid_batch_indices, valid_scatter_indices] = child_coords_int[mask]
next_prev_counts = torch.zeros(B, max_valid, dtype=sampled.dtype, device=device)
next_prev_counts[valid_batch_indices, valid_scatter_indices] = sampled[mask]
next_prev_log_probs = torch.zeros(B, max_valid, dtype=prev_log_probs.dtype, device=device)
next_prev_log_probs[valid_batch_indices, valid_scatter_indices] = (prev_log_probs_expanded + pred_log_probs)[mask]
prev_coords_int = next_prev_coords_int
prev_counts = next_prev_counts
prev_log_probs = next_prev_log_probs
res = 1 << level
prev_log_probs = torch.repeat_interleave(prev_log_probs.flatten(0, 1), prev_counts.flatten(0, 1), dim=0).reshape(B, num_points)
coords_int = torch.repeat_interleave(prev_coords_int.flatten(0, 1), prev_counts.flatten(0, 1), dim=0).reshape(B, num_points, -1)
rand = torch.rand(coords_int.shape, dtype=torch.float32, generator=generator).to(device)
coords_norm = (coords_int.to(torch.float32) + rand) / res
return {"points": coords_norm, "log_probs": prev_log_probs}
# Elastic gaussian decoder
class TransformerCrossBlock(nn.Module):
def __init__(self, channels, ctx_channels, num_heads, mlp_ratio=4.0,
qk_rms_norm=True, qk_rms_norm_cross=True, qkv_bias=True,
dtype=None, device=None, operations=None):
super().__init__()
self.norm1 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.norm2 = operations.LayerNorm(channels, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)
self.norm3 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.self_attn = MultiHeadAttention(channels, num_heads=num_heads, type="self", qkv_bias=qkv_bias,
qk_rms_norm=qk_rms_norm, dtype=dtype, device=device, operations=operations)
self.cross_attn = MultiHeadAttention(channels, ctx_channels=ctx_channels, num_heads=num_heads, type="cross",
qkv_bias=qkv_bias, qk_rms_norm=qk_rms_norm_cross, dtype=dtype, device=device, operations=operations)
self.mlp = MLP(channels, int(channels * mlp_ratio), channels, dtype=dtype, device=device, operations=operations)
def forward(self, x, context):
x = x + self.self_attn(self.norm1(x))
x = x + self.cross_attn(self.norm2(x), context)
x = x + self.mlp(self.norm3(x))
return x
class ElasticGaussianFixedlenDecoder(nn.Module):
# Cross-attention transformer over sampled octree points -> per-point gaussian params.
def __init__(self, in_channels=3, model_channels=1024, cond_channels=16, num_blocks=16, num_heads=16,
num_head_channels=64, mlp_ratio=4.0, *, representation_config=None,
qk_rms_norm=True, qk_rms_norm_cross=True, dtype=None, device=None, operations=None):
super().__init__()
self.rep_config = representation_config or dict(
lr=dict(_xyz=1.0, _features_dc=1.0, _opacity=1.0, _scaling=1.0, _rotation=0.1),
perturb_offset=True, perturbe_size=1.5, offset_scale=0.05, num_gaussians=32,
filter_kernel_size_3d=0.0009, scaling_bias=0.004, opacity_bias=0.1,
scaling_activation="softplus",
)
self.out_channels = self._calc_layout()
self.model_channels = model_channels
self.cond_channels = cond_channels
self.num_blocks = num_blocks
self.num_heads = num_heads or model_channels // num_head_channels
self.mlp_ratio = mlp_ratio
self.input_layer = operations.Linear(model_channels, model_channels, dtype=dtype, device=device)
if cond_channels is not None:
self.blocks = nn.ModuleList([
TransformerCrossBlock(model_channels, ctx_channels=cond_channels,
num_heads=self.num_heads, mlp_ratio=self.mlp_ratio,
qk_rms_norm=qk_rms_norm, qk_rms_norm_cross=qk_rms_norm_cross,
dtype=dtype, device=device, operations=operations)
for _ in range(num_blocks)
])
self.in_proj = operations.Linear(in_channels, model_channels, dtype=dtype, device=device)
self.pos_embedder = PcdAbsolutePositionEmbedder(channels=model_channels, in_channels=3, max_res=10, schedule="log2")
self.out_proj = operations.Linear(model_channels, self.out_channels, dtype=dtype, device=device)
self._build_perturbation()
def _calc_layout(self):
ng = self.rep_config['num_gaussians']
self.layout = {
'_xyz': {'shape': (ng, 3), 'size': ng * 3},
'_features_dc': {'shape': (ng, 1, 3), 'size': ng * 3},
'_scaling': {'shape': (ng, 3), 'size': ng * 3},
'_rotation': {'shape': (ng, 4), 'size': ng * 4},
'_opacity': {'shape': (ng, 1), 'size': ng},
}
self.layout['_offset_scale'] = {'shape': (ng, 1), 'size': ng}
start = 0
for k, v in self.layout.items():
v['range'] = (start, start + v['size'])
start += v['size']
return start
def _build_perturbation(self):
ng = self.rep_config['num_gaussians']
perturbation = torch.tensor([hammersley_sequence(3, i, ng) for i in range(ng)]).float()
perturbation = torch.atanh((perturbation * 2 - 1) / self.rep_config['perturbe_size'])
self.register_buffer('points_offset_perturbation', perturbation)
base = torch.tensor(self.rep_config['offset_scale'])
self.register_buffer('base_offset_scale', torch.log(torch.exp(base) - 1.0))
def _get_offset(self, h):
B = h.shape[0]
r = self.layout['_offset_scale']['range']
_offset_scale = F.softplus(
h[:, :, r[0]:r[1]].reshape(B, -1, *self.layout['_offset_scale']['shape'])
+ comfy.model_management.cast_to(self.base_offset_scale, h.dtype, h.device))
r = self.layout['_xyz']['range']
offset = h[:, :, r[0]:r[1]].reshape(B, -1, *self.layout['_xyz']['shape'])
offset = offset * self.rep_config['lr']['_xyz']
if self.rep_config['perturb_offset']:
offset = offset + comfy.model_management.cast_to(self.points_offset_perturbation, offset.dtype, offset.device)
offset = torch.tanh(offset) * 0.5 * self.rep_config['perturbe_size']
offset = offset * _offset_scale
return offset
def forward(self, x=None, cond=None):
pcd = x["points"]
d = next(self.parameters()).dtype
B, L, _ = pcd.shape
h = self.in_proj(pcd.to(d)) + self.pos_embedder(pcd.reshape(-1, 3)).reshape(B, L, -1).to(d)
h = self.input_layer(h)
cond = cond.to(d)
for block in self.blocks:
h = block(h, cond)
h = F.layer_norm(h.float(), h.shape[-1:]).to(h.dtype)
return {"features": self.out_proj(h)}
# Combined octree gaussian decoder (comfy first-stage model)
class OctreeGaussianDecoder(nn.Module):
_MAX_VOXEL_LEVEL = 8
def __init__(self, dtype=None, device=None, operations=None):
super().__init__()
if operations is None:
operations = comfy.ops.disable_weight_init
self.octree = OctreeProbabilityFixedlenDecoder(dtype=dtype, device=device, operations=operations)
self.gs = ElasticGaussianFixedlenDecoder(dtype=dtype, device=device, operations=operations)
@property
def gaussians_per_point(self) -> int:
return self.gs.rep_config['num_gaussians']
def decode(self, latent: torch.Tensor, num_gaussians: int, level: int = None, generator=None):
# level defaults to the full octree depth, a lower level is cheaper (coarser) for live previews.
# generator (a CPU torch.Generator) makes the octree sampling reproducible without touching global RNG.
level = self._MAX_VOXEL_LEVEL if level is None else level
num_decoder_tokens = max(1, num_gaussians // self.gaussians_per_point)
points_pred = OctreeProbabilityFixedlenDecoder.sample(
self.octree, latent, num_points=num_decoder_tokens, level=level, temperature=1.0, generator=generator,
)
pred = self.gs(x=points_pred, cond=latent)
return build_gaussian_models(self.gs, points_pred, pred) # one GaussianModel per batch item

View File

@ -4,6 +4,7 @@ import dataclasses
import torch
from typing import NamedTuple
import comfy_aimdo.host_buffer
from comfy.quant_ops import QuantizedTensor
@ -17,21 +18,18 @@ class TensorFileSlice(NamedTuple):
def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=None):
if isinstance(tensor, QuantizedTensor):
if not isinstance(destination, QuantizedTensor):
return False
if tensor._layout_cls != destination._layout_cls:
return False
if not read_tensor_file_slice_into(tensor._qdata, destination._qdata, stream=stream,
if not read_tensor_file_slice_into(tensor._qdata,
destination._qdata if destination is not None else None, stream=stream,
destination2=(destination2._qdata if destination2 is not None else None)):
return False
dst_orig_dtype = destination._params.orig_dtype
destination._params.copy_from(tensor._params, non_blocking=False)
destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype)
if destination is not None:
dst_orig_dtype = destination._params.orig_dtype
destination._params.copy_from(tensor._params, non_blocking=False)
destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype)
if destination2 is not None:
dst_orig_dtype = destination2._params.orig_dtype
destination2._params.copy_from(destination._params, non_blocking=True)
destination2._params.copy_from(destination._params if destination is not None else tensor._params, non_blocking=True)
destination2._params = dataclasses.replace(destination2._params, orig_dtype=dst_orig_dtype)
return True
@ -39,10 +37,15 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N
if info is None:
return False
if destination is not None and destination.device.type != "cpu" and destination2 is None:
destination2 = destination
destination = None
file_obj = info.file_ref
if (destination.device.type != "cpu"
or file_obj is None
or destination.numel() * destination.element_size() < info.size
if (file_obj is None
or (destination is None and destination2 is None)
or (destination is not None and (destination.device.type != "cpu" or destination.numel() * destination.element_size() < info.size))
or (destination2 is not None and (destination2.device.type == "cpu" or destination2.numel() * destination2.element_size() < info.size))
or tensor.numel() * tensor.element_size() != info.size
or tensor.storage_offset() != 0
or not tensor.is_contiguous()):
@ -51,6 +54,14 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N
if info.size == 0:
return True
if destination is None:
stream_ptr = getattr(stream, "cuda_stream", 0) if stream is not None else 0
comfy_aimdo.host_buffer.read_file_to_device(file_obj, info.offset, info.size,
stream_ptr, destination2.data_ptr(),
destination2.device.index,
mark_cold=False)
return True
hostbuf = getattr(destination.untyped_storage(), "_comfy_hostbuf", None)
if hostbuf is not None:
stream_ptr = getattr(stream, "cuda_stream", 0) if stream is not None else 0
@ -63,6 +74,9 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N
device=None if destination2 is None else destination2.device.index)
return True
if not hasattr(file_obj, "seek") or not hasattr(file_obj, "readinto"):
return False
buf_type = ctypes.c_ubyte * info.size
view = memoryview(buf_type.from_address(destination.data_ptr()))

View File

@ -46,6 +46,7 @@ import comfy.ldm.wan.model_animate
import comfy.ldm.wan.ar_model
import comfy.ldm.wan.model_wandancer
import comfy.ldm.hunyuan3d.model
import comfy.ldm.triposplat.model
import comfy.ldm.hidream.model
import comfy.ldm.chroma.model
import comfy.ldm.chroma_radiance.model
@ -1818,6 +1819,24 @@ class Hunyuan3Dv2_1(BaseModel):
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
return out
class TripoSplat(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.triposplat.model.LatentSeqMMFlowModel)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None) # DINOv3 token sequence -> cross-attention context.
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
ref_latents = kwargs.get("reference_latents", None) # Flux2 VAE image latent -> additive second conditioning.
if ref_latents is not None:
out['ref_latents'] = comfy.conds.CONDList(list(ref_latents))
latent_shapes = kwargs.get("latent_shapes", None) # {latent, camera} nested latent
if latent_shapes is not None:
out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes)
return out
class HiDream(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hidream.model.HiDreamImageTransformer2DModel)

View File

@ -313,6 +313,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["use_x0"] = True
else:
dit_config["use_x0"] = False
if "{}__sequential__".format(key_prefix) in state_dict_keys: # sequential txt_ids
dit_config["use_sequential_txt_ids"] = True
else:
dit_config["use_sequential_txt_ids"] = False
else:
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
@ -726,6 +730,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["guidance_cond_proj_dim"] = None#f"{key_prefix}t_embedder.cond_proj.weight" in state_dict_keys
return dit_config
if '{}cam_out_layer.weight'.format(key_prefix) in state_dict_keys and '{}repo_layers.0.final_map.weight'.format(key_prefix) in state_dict_keys: # TripoSplat
return {"image_model": "triposplat"}
if '{}t_embedder1.mlp.0.weight'.format(key_prefix) in state_dict_keys and '{}x_embedder.proj1.weight'.format(key_prefix) in state_dict_keys: # HiDream-O1
return {"image_model": "hidream_o1"}

View File

@ -641,14 +641,17 @@ def free_pins(size, evict_active=False):
return freed_total
def ensure_pin_budget(size, evict_active=False):
shortfall = size + comfy.memory_management.RAM_CACHE_HEADROOM / 2 - psutil.virtual_memory().available
if args.fast_disk:
shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY
else:
shortfall = size + max(comfy.memory_management.RAM_CACHE_HEADROOM / 2, 2048 * 1024 ** 2) - psutil.virtual_memory().available
if shortfall <= 0:
return True
to_free = shortfall + PIN_PRESSURE_HYSTERESIS
return free_pins(to_free, evict_active=evict_active) >= shortfall
def ensure_pin_registerable(size, evict_active=False):
def ensure_pin_registerable(size, evict_active=True):
shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY
if MAX_PINNED_MEMORY <= 0:
return False
@ -658,10 +661,17 @@ def ensure_pin_registerable(size, evict_active=False):
shortfall += REGISTERABLE_PIN_HYSTERESIS
for loaded_model in reversed(current_loaded_models):
model = loaded_model.model
if model is not None and model.is_dynamic() and (evict_active or not model.model.dynamic_pins[model.load_device]["active"]):
if model is not None and model.is_dynamic() and not model.model.dynamic_pins[model.load_device]["active"]:
shortfall -= model.unregister_inactive_pins(shortfall)
if shortfall <= 0:
return True
if evict_active:
for loaded_model in current_loaded_models:
model = loaded_model.model
if model is not None and model.is_dynamic() and model.model.dynamic_pins[model.load_device]["active"]:
shortfall -= model.unregister_inactive_pins(shortfall)
if shortfall <= 0:
return True
return shortfall <= REGISTERABLE_PIN_HYSTERESIS
class LoadedModel:
@ -803,9 +813,9 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
for x in can_unload_sorted:
i = x[-1]
memory_to_free = 1e32
if current_loaded_models[i].model.is_dynamic() and (not DISABLE_SMART_MEMORY or device is None):
if not DISABLE_SMART_MEMORY or device is None:
memory_to_free = 0 if device is None else memory_required - get_free_memory(device)
if for_dynamic:
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
#don't actually unload dynamic models for the sake of other dynamic models
#as that works on-demand.
memory_required -= current_loaded_models[i].model.loaded_size()
@ -817,6 +827,10 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
for i in sorted(unloaded_model, reverse=True):
unloaded_models.append(current_loaded_models.pop(i))
if not for_dynamic and pins_required > 0:
ensure_pin_budget(pins_required)
ensure_pin_registerable(pins_required)
if len(unloaded_model) > 0:
soft_empty_cache()
elif device is not None:
@ -879,15 +893,19 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
model_to_unload.model_finalizer.detach()
total_memory_required = {}
total_pins_required = {}
for loaded_model in models_to_load:
device = loaded_model.device
total_memory_required[device] = total_memory_required.get(device, 0) + loaded_model.model_memory_required(device)
if not loaded_model.model.is_dynamic():
total_pins_required[device] = total_pins_required.get(device, 0) + loaded_model.model_memory()
for device in total_memory_required:
if device != torch.device("cpu"):
free_memory(total_memory_required[device] * 1.1 + extra_mem,
device,
for_dynamic=free_for_dynamic)
for_dynamic=free_for_dynamic,
pins_required=total_pins_required.get(device, 0))
for device in total_memory_required:
if device != torch.device("cpu"):
@ -1283,7 +1301,6 @@ STREAM_CAST_BUFFERS = {}
LARGEST_CASTED_WEIGHT = (None, 0)
STREAM_AIMDO_CAST_BUFFERS = {}
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
STREAM_PIN_BUFFERS = {}
DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3
@ -1326,42 +1343,13 @@ def get_aimdo_cast_buffer(offload_stream, device):
STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer
return cast_buffer
def get_pin_buffer(offload_stream):
pin_buffer = STREAM_PIN_BUFFERS.get(offload_stream, None)
if pin_buffer is None:
pin_buffer = comfy_aimdo.host_buffer.HostBuffer(0, 0, pinned_hostbuf_size(8 * 1024**3), mark_cold=False)
STREAM_PIN_BUFFERS[offload_stream] = pin_buffer
elif offload_stream is not None:
event = getattr(pin_buffer, "_comfy_event", None)
if event is not None:
event.synchronize()
delattr(pin_buffer, "_comfy_event")
return pin_buffer
def resize_pin_buffer(pin_buffer, size):
global TOTAL_PINNED_MEMORY
old_size = pin_buffer.size
if size <= old_size:
return True
growth = size - old_size
comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM)
ensure_pin_budget(growth, evict_active=True)
ensure_pin_registerable(growth, evict_active=True)
try:
pin_buffer.extend(size=size, reallocate=True)
except RuntimeError:
return False
TOTAL_PINNED_MEMORY += pin_buffer.size - old_size
return True
def reset_cast_buffers():
global TOTAL_PINNED_MEMORY
global LARGEST_CASTED_WEIGHT
global LARGEST_AIMDO_CASTED_WEIGHT
LARGEST_CASTED_WEIGHT = (None, 0)
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS) | set(STREAM_PIN_BUFFERS):
for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS):
if offload_stream is not None:
offload_stream.synchronize()
synchronize()
@ -1370,20 +1358,24 @@ def reset_cast_buffers():
mmap_obj.bounce()
DIRTY_MMAPS.clear()
for pin_buffer in STREAM_PIN_BUFFERS.values():
TOTAL_PINNED_MEMORY -= pin_buffer.size
TOTAL_PINNED_MEMORY = max(0, TOTAL_PINNED_MEMORY)
for loaded_model in current_loaded_models:
model = loaded_model.model
if model is not None and model.is_dynamic():
model.model.dynamic_pins[model.load_device]["active"] = False
pin_state = model.model.dynamic_pins[model.load_device]
if pin_state["active"]:
*_, buckets = pin_state["weights"]
for size, bucket in list(buckets.items()):
bucket[:] = [ entry for entry in bucket if entry[-1] is not None ]
if not bucket:
del buckets[size]
pin_state["active"] = False
model.partially_unload_ram(1e30, subsets=[ "patches" ])
model.model.dynamic_pins[model.load_device]["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, pinned_hostbuf_size(model.model_size())), [], [-1], [0])
model.model.dynamic_pins[model.load_device]["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, pinned_hostbuf_size(model.model_size())), [], [-1], [0], [0], {})
STREAM_CAST_BUFFERS.clear()
STREAM_AIMDO_CAST_BUFFERS.clear()
STREAM_PIN_BUFFERS.clear()
soft_empty_cache()
def get_offload_stream(device):
@ -1436,7 +1428,7 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None, r2=None):
if hasattr(wf_context, "as_context"):
wf_context = wf_context.as_context(stream)
dest_views = comfy.memory_management.interpret_gathered_like(tensors, r)
dest_views = comfy.memory_management.interpret_gathered_like(tensors, r) if r is not None else [None] * len(tensors)
dest2_views = comfy.memory_management.interpret_gathered_like(tensors, r2) if r2 is not None else None
with wf_context:
for tensor in tensors:
@ -1448,9 +1440,10 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None, r2=None):
continue
storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage()
mark_mmap_dirty(storage)
dest_view.copy_(tensor, non_blocking=non_blocking)
if dest_view is not None:
dest_view.copy_(tensor, non_blocking=non_blocking)
if dest2_view is not None:
dest2_view.copy_(dest_view, non_blocking=non_blocking)
dest2_view.copy_(tensor if dest_view is None else dest_view, non_blocking=non_blocking)
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None):
@ -1723,6 +1716,13 @@ def is_device_xpu(device):
def is_device_cuda(device):
return is_device_type(device, 'cuda')
def set_torch_device(device):
"""Set the current device for the given torch device. Supports CUDA and XPU."""
if is_device_cuda(device):
torch.cuda.set_device(device)
elif is_device_xpu(device):
torch.xpu.set_device(device)
def is_directml_enabled():
global directml_enabled
if directml_enabled:

View File

@ -1721,8 +1721,8 @@ class ModelPatcherDynamic(ModelPatcher):
"""
if device not in self.model.dynamic_pins:
self.model.dynamic_pins[device] = {
"weights": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]),
"patches": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]),
"weights": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0], [0], {}),
"patches": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0], [0], {}),
"hostbufs_initialized": False,
"failed": False,
"active": False,
@ -1799,8 +1799,8 @@ class ModelPatcherDynamic(ModelPatcher):
pin_state = self.model.dynamic_pins[self.load_device]
if not pin_state["hostbufs_initialized"]:
hostbuf_size = comfy.model_management.pinned_hostbuf_size(self.model_size())
pin_state["weights"] = (comfy_aimdo.host_buffer.HostBuffer(0, 64 * 1024 * 1024, hostbuf_size), [], [-1], [0])
pin_state["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, hostbuf_size), [], [-1], [0])
pin_state["weights"] = (comfy_aimdo.host_buffer.HostBuffer(0, 64 * 1024 * 1024, hostbuf_size), [], [-1], [0], [0], {})
pin_state["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, hostbuf_size), [], [-1], [0], [0], {})
pin_state["hostbufs_initialized"] = True
pin_state["failed"] = False
pin_state["active"] = True
@ -1942,18 +1942,16 @@ class ModelPatcherDynamic(ModelPatcher):
return freed
def loaded_ram_size(self):
return (self.model.dynamic_pins[self.load_device]["weights"][0].size +
self.model.dynamic_pins[self.load_device]["patches"][0].size)
return (self.model.dynamic_pins[self.load_device]["weights"][0].size)
def pinned_memory_size(self):
return (self.model.dynamic_pins[self.load_device]["weights"][3][0] +
self.model.dynamic_pins[self.load_device]["patches"][3][0])
return (self.model.dynamic_pins[self.load_device]["weights"][3][0])
def unregister_inactive_pins(self, ram_to_unload, subsets=[ "weights", "patches" ]):
freed = 0
pin_state = self.model.dynamic_pins[self.load_device]
for subset in subsets:
hostbuf, stack, stack_split, pinned_size = pin_state[subset]
hostbuf, stack, stack_split, pinned_size, *_ = pin_state[subset]
split = stack_split[0]
while split >= 0:
module, offset = stack[split]
@ -1978,10 +1976,12 @@ class ModelPatcherDynamic(ModelPatcher):
freed = 0
pin_state = self.model.dynamic_pins[self.load_device]
for subset in subsets:
hostbuf, stack, stack_split, pinned_size = pin_state[subset]
hostbuf, stack, stack_split, pinned_size, *_ = pin_state[subset]
while len(stack) > 0:
module, offset = stack.pop()
size = module._pin.numel() * module._pin.element_size()
module._pin_balancer_entry[-1] = None
del module._pin_balancer_entry
del module._pin
hostbuf.truncate(offset, do_unregister=module._pin_registered)
stack_split[0] = min(stack_split[0], len(stack) - 1)

View File

@ -1,4 +1,5 @@
import comfy_aimdo.model_vbar
import comfy.memory_management
import comfy.model_management
import comfy.ops
@ -50,7 +51,17 @@ def prefetch_queue_pop(queue, device, module):
if hasattr(s, "_v"):
comfy_modules.append(s)
registerable_size = 0
for s in comfy_modules:
registerable_size += comfy.memory_management.vram_aligned_size([s.weight, s.bias])
for param_key in ("weight", "bias"):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
if lowvram_fn is not None:
registerable_size += lowvram_fn.memory_required()
offload_stream = comfy.ops.cast_modules_with_vbar(comfy_modules, None, device, None, True)
if not comfy.model_management.args.fast_disk:
comfy.model_management.ensure_pin_registerable(registerable_size)
comfy.model_management.sync_stream(device, offload_stream)
queue[0] = (offload_stream, (prefetch, comfy_modules))

View File

@ -17,7 +17,7 @@ class MultiGPUThreadPool:
"""Persistent thread pool for multi-GPU work distribution.
Maintains one worker thread per extra GPU device. Each thread calls
torch.cuda.set_device() once at startup so that compiled kernel caches
set_torch_device() once at startup so that compiled kernel caches
(inductor/triton) stay warm across diffusion steps.
"""
@ -37,7 +37,7 @@ class MultiGPUThreadPool:
def _worker_loop(self, device: torch.device, work_q: queue.Queue, result_q: queue.Queue):
try:
torch.cuda.set_device(device)
comfy.model_management.set_torch_device(device)
except Exception as e:
logging.error(f"MultiGPUThreadPool: failed to set device {device}: {e}")
while True:
@ -54,6 +54,8 @@ class MultiGPUThreadPool:
try:
result = fn(*args, **kwargs)
result_q.put((result, None))
except comfy.model_management.InterruptProcessingException as e:
result_q.put((None, e))
except Exception as e:
result_q.put((None, e))

View File

@ -76,8 +76,6 @@ except:
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
STREAM_PIN_BUFFER_HEADROOM = 8 * 1024 * 1024
def cast_to_input(weight, input, non_blocking=False, copy=True):
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
@ -94,9 +92,6 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
offload_stream = None
cast_buffer = None
cast_buffer_offset = 0
stream_pin_hostbuf = None
stream_pin_offset = 0
stream_pin_queue = []
def ensure_offload_stream(module, required_size, check_largest):
nonlocal offload_stream
@ -130,22 +125,6 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
cast_buffer_offset += buffer_size
return buffer
def get_stream_pin_buffer_offset(buffer_size):
nonlocal stream_pin_hostbuf
nonlocal stream_pin_offset
if buffer_size == 0 or offload_stream is None:
return None
if stream_pin_hostbuf is None:
stream_pin_hostbuf = comfy.model_management.get_pin_buffer(offload_stream)
if stream_pin_hostbuf is None:
return None
offset = stream_pin_offset
stream_pin_offset += buffer_size
return offset
for s in comfy_modules:
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
@ -184,12 +163,18 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
if xfer_dest is None:
xfer_dest = get_cast_buffer(dest_size)
def cast_maybe_lowvram_patch(xfer_source, xfer_dest, stream):
def cast_maybe_lowvram_patch(xfer_source, xfer_dest, stream, xfer_dest2=None):
if xfer_source is not None:
if getattr(xfer_source, "is_lowvram_patch", False):
xfer_source.prepare(xfer_dest, stream, copy=True, commit=False)
else:
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=stream)
if xfer_dest is not None:
xfer_source.prepare(xfer_dest, stream, copy=True, commit=False)
xfer_source = [ xfer_dest ]
xfer_dest = xfer_dest2
xfer_dest2 = None
elif xfer_dest2 is not None:
xfer_source.prepare(xfer_dest2, stream, copy=True, commit=False)
return
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=stream, r2=xfer_dest2)
def handle_pin(m, pin, source, dest, subset="weights", size=None):
if pin is not None:
@ -198,19 +183,7 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
if signature is None:
comfy.pinned_memory.pin_memory(m, subset=subset, size=size)
pin = comfy.pinned_memory.get_pin(m, subset=subset)
if pin is not None:
if isinstance(source, list):
comfy.model_management.cast_to_gathered(source, pin, non_blocking=non_blocking, stream=offload_stream, r2=dest)
else:
cast_maybe_lowvram_patch(source, pin, None)
cast_maybe_lowvram_patch([ pin ], dest, offload_stream)
return
if pin is None:
pin_offset = get_stream_pin_buffer_offset(size)
if pin_offset is not None:
stream_pin_queue.append((source, pin_offset, size, dest))
return
cast_maybe_lowvram_patch(source, dest, offload_stream)
cast_maybe_lowvram_patch(source, pin, offload_stream, xfer_dest2=dest)
handle_pin(s, pin, xfer_source, xfer_dest, size=dest_size)
@ -232,23 +205,6 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
prefetch["needs_cast"] = needs_cast
s._prefetch = prefetch
if stream_pin_offset > 0:
if stream_pin_hostbuf.size < stream_pin_offset:
if not comfy.model_management.resize_pin_buffer(stream_pin_hostbuf, stream_pin_offset + STREAM_PIN_BUFFER_HEADROOM):
for xfer_source, _, _, xfer_dest in stream_pin_queue:
cast_maybe_lowvram_patch(xfer_source, xfer_dest, offload_stream)
return offload_stream
stream_pin_tensor = comfy_aimdo.torch.hostbuf_to_tensor(stream_pin_hostbuf)
stream_pin_tensor.untyped_storage()._comfy_hostbuf = stream_pin_hostbuf
for xfer_source, pin_offset, pin_size, xfer_dest in stream_pin_queue:
pin = stream_pin_tensor[pin_offset:pin_offset + pin_size]
if isinstance(xfer_source, list):
comfy.model_management.cast_to_gathered(xfer_source, pin, non_blocking=non_blocking, stream=offload_stream, r2=xfer_dest)
else:
cast_maybe_lowvram_patch(xfer_source, pin, None)
comfy.model_management.cast_to_gathered([ pin ], xfer_dest, non_blocking=non_blocking, stream=offload_stream)
stream_pin_hostbuf._comfy_event = offload_stream.record_event()
return offload_stream

View File

@ -1,17 +1,55 @@
import bisect
import comfy.model_management
import comfy.memory_management
import comfy.utils
import comfy_aimdo.host_buffer
import comfy_aimdo.torch
import torch
from comfy.cli_args import args
def _add_to_bucket(module, buckets, size, priority):
bucket = buckets.setdefault(size, [])
entry = [-priority, 0, module]
entry[1] = id(entry)
bisect.insort(bucket, entry)
module._pin_balancer_entry = entry
def _steal_pin(module, stack, buckets, size, priority):
bucket = buckets.get(size)
if bucket is None:
return False
while bucket and bucket[-1][-1] is None:
bucket.pop()
if not bucket:
del buckets[size]
return False
if priority <= -bucket[-1][0]:
return False
*_, victim = bucket.pop()
module._pin = victim._pin
module._pin_registered = victim._pin_registered
module._pin_stack_index = victim._pin_stack_index
stack[module._pin_stack_index] = (module, stack[module._pin_stack_index][1])
victim._pin_registered = False
del victim._pin
del victim._pin_stack_index
del victim._pin_balancer_entry
_add_to_bucket(module, buckets, size, priority)
return True
def get_pin(module, subset="weights"):
pin = getattr(module, "_pin", None)
if pin is None or module._pin_registered or args.disable_pinned_memory:
return pin
_, _, stack_split, pinned_size = module._pin_state[subset]
_, _, stack_split, pinned_size, *_ = module._pin_state[subset]
size = pin.nbytes
comfy.model_management.ensure_pin_registerable(size)
@ -31,26 +69,30 @@ def pin_memory(module, subset="weights", size=None):
return
pin = get_pin(module, subset)
if pin is not None or pin_state["failed"]:
if pin is not None:
return
hostbuf, stack, stack_split, pinned_size = pin_state[subset]
hostbuf, stack, stack_split, pinned_size, counter, buckets = pin_state[subset]
if size is None:
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
offset = hostbuf.size
registerable_size = size + max(0, hostbuf.size - pinned_size[0])
registerable_size = size
priority = getattr(module, "_pin_balancer_priority", None)
if priority is None:
priority = comfy.utils.bit_reverse_range(counter[0], 16)
counter[0] += 1
module._pin_balancer_priority = priority
comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM)
if (not comfy.model_management.ensure_pin_budget(size) or
not comfy.model_management.ensure_pin_registerable(registerable_size)):
pin_state["failed"] = True
return False
return _steal_pin(module, stack, buckets, size, priority)
try:
hostbuf.extend(size=size)
except RuntimeError:
pin_state["failed"] = True
return False
return _steal_pin(module, stack, buckets, size, priority)
module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)[offset:offset + size]
module._pin.untyped_storage()._comfy_hostbuf = hostbuf
@ -60,4 +102,5 @@ def pin_memory(module, subset="weights", size=None):
stack_split[0] = max(stack_split[0], module._pin_stack_index)
comfy.model_management.TOTAL_PINNED_MEMORY += size
pinned_size[0] += size
_add_to_bucket(module, buckets, size, priority)
return True

View File

@ -464,10 +464,7 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t
def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]):
try:
# TODO: non-NVIDIA support -- guard with `if device.type == "cuda":` once
# we extend multigpu QA beyond CUDA. Unconditional call crashes on
# XPU/NPU/MPS/CPU/DirectML backends.
torch.cuda.set_device(device)
comfy.model_management.set_torch_device(device)
model_current: BaseModel = model_options["multigpu_clones"][device].model
# run every hooked_to_run separately
with torch.no_grad():

View File

@ -18,6 +18,7 @@ import comfy.ldm.wan.vae
import comfy.ldm.wan.vae2_2
import comfy.ldm.hunyuan3d.vae
import comfy.ldm.seedvr.vae
import comfy.ldm.triposplat.vae
import comfy.ldm.ace.vae.music_dcae_pipeline
import comfy.ldm.cogvideo.vae
import comfy.ldm.hunyuan_video.vae
@ -943,6 +944,16 @@ class VAE:
#Force cast it for --disable-dynamic-vram users until there is a true core fix.
if not comfy.memory_management.aimdo_enabled:
self.disable_offload = True
elif "gs.base_offset_scale" in sd and "octree.out_proj.weight" in sd: # TripoSplat octree gaussian decoder
self.first_stage_model = comfy.ldm.triposplat.vae.OctreeGaussianDecoder()
self.latent_channels = 16
self.latent_dim = 1
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
# The generic VAE.encode/decode path isn't used: VAEDecodeTripoSplat calls the gaussian
# decoder directly (structured GaussianSplat objects, not a tensor and reserves VRAM itself from num_gaussians.
def _no_generic_io(*args, **kwargs):
raise RuntimeError("TripoSplat gaussian decoder: use the 'TripoSplat Decode' (VAEDecodeTripoSplat)")
self.memory_used_encode = self.memory_used_decode = _no_generic_io
else:
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
self.first_stage_model = None

View File

@ -1538,6 +1538,30 @@ class Hunyuan3Dv2mini(Hunyuan3Dv2):
latent_format = latent_formats.Hunyuan3Dv2mini
class TripoSplat(supported_models_base.BASE):
# Image -> 3D gaussian splat flow denoiser
unet_config = {
"image_model": "triposplat",
}
unet_extra_config = {}
sampling_settings = {
"shift": 3.0,
}
memory_usage_factor = 0.6
latent_format = latent_formats.TripoSplat
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
def get_model(self, state_dict, prefix="", device=None):
return model_base.TripoSplat(self, device=device)
def clip_target(self, state_dict={}):
return None
class HiDream(supported_models_base.BASE):
unet_config = {
"image_model": "hidream",
@ -2228,6 +2252,7 @@ models = [
Hunyuan3Dv2mini,
Hunyuan3Dv2,
Hunyuan3Dv2_1,
TripoSplat,
HiDream,
HiDreamO1,
Chroma,

View File

@ -85,9 +85,9 @@ _TYPES = {
def load_safetensors(ckpt):
import comfy_aimdo.model_mmap
f = open(ckpt, "rb", buffering=0)
file_lock = threading.Lock()
model_mmap = comfy_aimdo.model_mmap.ModelMMAP(ckpt)
f = model_mmap.get_file_handle()
file_size = os.path.getsize(ckpt)
mv = memoryview((ctypes.c_uint8 * file_size).from_address(model_mmap.get()))
@ -1452,3 +1452,10 @@ def deepcopy_list_dict(obj, memo=None):
memo[obj_id] = res
return res
def bit_reverse_range(index, bits):
result = 0
for _ in range(bits):
result = (result << 1) | (index & 1)
index >>= 1
return result

View File

@ -5,7 +5,7 @@ from comfy_api.internal.singleton import ProxiedSingleton
from comfy_api.internal.async_to_sync import create_sync_class
from ._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
from ._input_impl import VideoFromFile, VideoFromComponents
from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL, File3D
from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL, SPLAT, File3D
from . import _io_public as io
from . import _ui_public as ui
from comfy_execution.utils import get_executing_context
@ -143,6 +143,7 @@ class Types:
VideoComponents = VideoComponents
MESH = MESH
VOXEL = VOXEL
SPLAT = SPLAT
File3D = File3D

View File

@ -65,6 +65,12 @@ class VideoInput(ABC):
buffer.seek(0)
return buffer
def get_active_trim_window(self) -> tuple[float, float]:
"""Return the active trim as ``(start_time, duration)`` in seconds (start_time normalized
to ``>= 0``; ``duration == 0`` means "until the end"). Default: no trim; trimmable subclasses override.
"""
return 0.0, 0.0
# Provide a default implementation, but subclasses can provide optimized versions
# if possible.
def get_dimensions(self) -> tuple[int, int]:

View File

@ -75,6 +75,12 @@ class VideoFromFile(VideoInput):
self.__file.seek(0)
return self.__file
def get_active_trim_window(self) -> tuple[float, float]:
start_time = self.__start_time
if start_time < 0:
start_time = max(self._get_raw_duration() + start_time, 0.0)
return float(start_time), float(self.__duration)
def get_dimensions(self) -> tuple[int, int]:
"""
Returns the dimensions of the video input.

View File

@ -28,7 +28,7 @@ if TYPE_CHECKING:
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
prune_dict, shallow_clone_class)
from comfy_execution.graph_utils import ExecutionBlocker
from ._util import MESH, VOXEL, SVG as _SVG, File3D
from ._util import MESH, VOXEL, SPLAT, SVG as _SVG, File3D
class FolderType(str, Enum):
@ -684,6 +684,10 @@ class Voxel(ComfyTypeIO):
class Mesh(ComfyTypeIO):
Type = MESH
@comfytype(io_type="SPLAT")
class Splat(ComfyTypeIO):
Type = SPLAT
@comfytype(io_type="FILE_3D")
class File3DAny(ComfyTypeIO):
@ -727,6 +731,30 @@ class File3DUSDZ(ComfyTypeIO):
Type = File3D
@comfytype(io_type="FILE_3D_PLY")
class File3DPLY(ComfyTypeIO):
"""PLY format 3D file - point cloud or Gaussian splat."""
Type = File3D
@comfytype(io_type="FILE_3D_SPLAT")
class File3DSPLAT(ComfyTypeIO):
"""SPLAT format 3D file - 3D Gaussian splat."""
Type = File3D
@comfytype(io_type="FILE_3D_SPZ")
class File3DSPZ(ComfyTypeIO):
"""SPZ format 3D file - compressed 3D Gaussian splat."""
Type = File3D
@comfytype(io_type="FILE_3D_KSPLAT")
class File3DKSPLAT(ComfyTypeIO):
"""KSPLAT format 3D file - 3D Gaussian splat."""
Type = File3D
@comfytype(io_type="HOOKS")
class Hooks(ComfyTypeIO):
if TYPE_CHECKING:
@ -777,6 +805,17 @@ class Load3DCamera(ComfyTypeIO):
Type = CameraInfo
@comfytype(io_type="LOAD3D_MODEL_INFO")
class Load3DModelInfo(ComfyTypeIO):
class Model3DTransform(TypedDict):
# Coordinate system: right-handed, Y-up, world space
position: dict[str, float | int] # scene units
quaternion: dict[str, float | int] # normalized, dimensionless; world rotation
scale: dict[str, float | int] # dimensionless multiplier
Type = list[Model3DTransform]
@comfytype(io_type="LOAD_3D")
class Load3D(ComfyTypeIO):
"""3D models are stored as a dictionary."""
@ -786,6 +825,7 @@ class Load3D(ComfyTypeIO):
normal: str
camera_info: Load3DCamera.CameraInfo
recording: NotRequired[str]
model_3d_info: NotRequired[list[Load3DModelInfo.Model3DTransform]]
Type = Model3DDict
@ -2284,6 +2324,7 @@ __all__ = [
"LossMap",
"Voxel",
"Mesh",
"Splat",
"File3DAny",
"File3DGLB",
"File3DGLTF",
@ -2291,6 +2332,10 @@ __all__ = [
"File3DOBJ",
"File3DSTL",
"File3DUSDZ",
"File3DPLY",
"File3DSPLAT",
"File3DSPZ",
"File3DKSPLAT",
"Hooks",
"HookKeyframes",
"TimestepsRange",
@ -2298,6 +2343,7 @@ __all__ = [
"FlowControl",
"Accumulation",
"Load3DCamera",
"Load3DModelInfo",
"Load3D",
"Load3DAnimation",
"Photomaker",

View File

@ -452,6 +452,16 @@ class PreviewUI3D(_UIOutput):
return {"result": [self.model_file, self.camera_info, self.bg_image_path]}
class PreviewUI3DAdvanced(_UIOutput):
def __init__(self, model_file, camera_info, model_3d_info):
self.model_file = model_file
self.camera_info = camera_info
self.model_3d_info = model_3d_info
def as_dict(self):
return {"result": [self.model_file, self.camera_info, self.model_3d_info]}
class PreviewText(_UIOutput):
def __init__(self, value: str, **kwargs):
self.value = value
@ -471,5 +481,6 @@ __all__ = [
"PreviewAudio",
"PreviewVideo",
"PreviewUI3D",
"PreviewUI3DAdvanced",
"PreviewText",
]

View File

@ -1,5 +1,5 @@
from .video_types import VideoContainer, VideoCodec, VideoComponents
from .geometry_types import VOXEL, MESH, File3D
from .geometry_types import VOXEL, MESH, SPLAT, File3D
from .image_types import SVG
__all__ = [
@ -9,6 +9,7 @@ __all__ = [
"VideoComponents",
"VOXEL",
"MESH",
"SPLAT",
"File3D",
"SVG",
]

View File

@ -11,13 +11,32 @@ class VOXEL:
self.data = data
class SPLAT:
"""A batch of 3D Gaussian splats in render-ready (activated, world-space) form.
Tensors are (B, N, ...) and zero-padded to a common N across the batch; `counts` (B,) holds the
real per-item lengths (None when rows are uniform and no slicing is needed). SH coefficients are
stored as (B, N, K, 3) with K = (sh_degree + 1)**2; the DC (diffuse) term is sh[..., 0, :].
"""
def __init__(self, positions: torch.Tensor, scales: torch.Tensor, rotations: torch.Tensor,
opacities: torch.Tensor, sh: torch.Tensor, counts: torch.Tensor | None = None):
self.positions = positions # (B, N, 3) world-space centers
self.scales = scales # (B, N, 3) linear (positive) per-axis std
self.rotations = rotations # (B, N, 4) quaternion wxyz (normalized)
self.opacities = opacities # (B, N, 1) in [0, 1]
self.sh = sh # (B, N, K, 3) spherical-harmonic color coefficients
self.counts = counts # (B,) real lengths, or None
class MESH:
def __init__(self, vertices: torch.Tensor, faces: torch.Tensor,
uvs: torch.Tensor | None = None,
vertex_colors: torch.Tensor | None = None,
texture: torch.Tensor | None = None,
vertex_counts: torch.Tensor | None = None,
face_counts: torch.Tensor | None = None):
face_counts: torch.Tensor | None = None,
unlit: bool = False):
assert (vertex_counts is None) == (face_counts is None), \
"vertex_counts and face_counts must be provided together (both or neither)"
@ -30,6 +49,8 @@ class MESH:
# these hold the real per-item lengths (B,). None means rows are uniform and no slicing is needed.
self.vertex_counts = vertex_counts
self.face_counts = face_counts
# Render flat / emissive (no scene lighting) when saved, e.g. for gaussian-splat-derived meshes.
self.unlit = unlit
class File3D:

View File

@ -1,71 +1,71 @@
from enum import Enum
from typing import Any, Dict, Optional
from typing import Any
from pydantic import BaseModel, Field, confloat, conint
class BFLOutputFormat(str, Enum):
png = 'png'
jpeg = 'jpeg'
from pydantic import BaseModel, Field
class BFLFluxExpandImageRequest(BaseModel):
prompt: str = Field(..., description='The description of the changes you want to make. This text guides the expansion process, allowing you to specify features, styles, or modifications for the expanded areas.')
prompt_upsampling: Optional[bool] = Field(
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
)
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
top: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the top of the image')
bottom: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the bottom of the image')
left: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the left side of the image')
right: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the right side of the image')
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
guidance: confloat(ge=1.5, le=100) = Field(..., description='Guidance strength for the image generation process')
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
)
output_format: Optional[BFLOutputFormat] = Field(
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
)
image: str = Field(None, description='A Base64-encoded string representing the image you wish to expand')
prompt: str = Field(...)
prompt_upsampling: bool | None = Field(None)
seed: int | None = Field(None)
top: int = Field(...)
bottom: int = Field(...)
left: int = Field(...)
right: int = Field(...)
steps: int = Field(...)
guidance: float = Field(...)
safety_tolerance: int = Field(6)
output_format: str = Field("png")
image: str = Field(None, description="A Base64-encoded string representing the image you wish to expand")
class BFLFluxFillImageRequest(BaseModel):
prompt: str = Field(..., description='The description of the changes you want to make. This text guides the expansion process, allowing you to specify features, styles, or modifications for the expanded areas.')
prompt_upsampling: Optional[bool] = Field(
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
prompt: str = Field(...)
prompt_upsampling: bool | None = Field(None)
seed: int | None = Field(None)
steps: int = Field(...)
guidance: float = Field(...)
safety_tolerance: int = Field(6)
output_format: str = Field("png")
image: str = Field(
None, description="Base64-encoded string representing the image to modify. Can contain alpha mask if desired.",
)
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
guidance: confloat(ge=1.5, le=100) = Field(..., description='Guidance strength for the image generation process')
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
mask: str = Field(
None, description="Base64-encoded string representing the mask of the areas you wish to modify."
)
output_format: Optional[BFLOutputFormat] = Field(
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
class BFLFluxEraseRequest(BaseModel):
image: str = Field(..., description="A Base64-encoded string representing the image to erase from.")
mask: str = Field(
...,
description="A Base64-encoded black/white mask matching the input dimensions; "
"white (255) marks areas to remove, black (0) marks areas to preserve.",
)
image: str = Field(None, description='A Base64-encoded string representing the image you wish to modify. Can contain alpha mask if desired.')
mask: str = Field(None, description='A Base64-encoded string representing the mask of the areas you with to modify.')
dilate_pixels: int = Field(10)
output_format: str = Field("png")
class BFLFluxVTORequest(BaseModel):
prompt: str = Field(
..., description="Natural-language styling instruction. Required field, but may be an empty string."
)
person: str = Field(..., description="A Base64-encoded string representing the person image.")
garment: str = Field(..., description="A Base64-encoded string representing the garment reference image.")
seed: int | None = Field(None)
safety_tolerance: int = Field(5)
output_format: str = Field("png")
class BFLFluxProGenerateRequest(BaseModel):
prompt: str = Field(..., description='The text prompt for image generation.')
prompt_upsampling: Optional[bool] = Field(
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
)
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
width: conint(ge=256, le=1440) = Field(1024, description='Width of the generated image in pixels. Must be a multiple of 32.')
height: conint(ge=256, le=1440) = Field(768, description='Height of the generated image in pixels. Must be a multiple of 32.')
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
)
output_format: Optional[BFLOutputFormat] = Field(
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
)
image_prompt: Optional[str] = Field(None, description='Optional image to remix in base64 format')
# image_prompt_strength: Optional[confloat(ge=0.0, le=1.0)] = Field(
# None, description='Blend between the prompt and the image prompt.'
# )
prompt: str = Field(...)
prompt_upsampling: bool | None = Field(None)
seed: int | None = Field(None)
width: int = Field(1024, description="Must be a multiple of 32.")
height: int = Field(768, description="Must be a multiple of 32.")
safety_tolerance: int = Field(6)
output_format: str = Field("png")
image_prompt: str | None = Field(None, description="Optional image to remix in base64 format")
class Flux2ProGenerateRequest(BaseModel):
@ -83,55 +83,37 @@ class Flux2ProGenerateRequest(BaseModel):
input_image_7: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
input_image_8: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
input_image_9: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
safety_tolerance: int | None = Field(
5, description="Tolerance level for input and output moderation. Value 0 being most strict.", ge=0, le=5
)
output_format: str | None = Field(
"png", description="Output format for the generated image. Can be 'jpeg' or 'png'."
)
safety_tolerance: int = Field(5)
output_format: str = Field("png")
class BFLFluxKontextProGenerateRequest(BaseModel):
prompt: str = Field(..., description='The text prompt for what you wannt to edit.')
input_image: Optional[str] = Field(None, description='Image to edit in base64 format')
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
guidance: confloat(ge=0.1, le=99.0) = Field(..., description='Guidance strength for the image generation process')
steps: conint(ge=1, le=150) = Field(..., description='Number of steps for the image generation process')
safety_tolerance: Optional[conint(ge=0, le=2)] = Field(
2, description='Tolerance level for input and output moderation. Between 0 and 2, 0 being most strict, 6 being least strict. Defaults to 2.'
)
output_format: Optional[BFLOutputFormat] = Field(
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
)
aspect_ratio: Optional[str] = Field(None, description='Aspect ratio of the image between 21:9 and 9:21.')
prompt_upsampling: Optional[bool] = Field(
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
)
prompt: str = Field(...)
input_image: str | None = Field(None, description="Image to edit in base64 format")
seed: int | None = Field(None)
guidance: float = Field(...)
steps: int = Field(...)
safety_tolerance: int = Field(2)
output_format: str = Field("png")
aspect_ratio: str | None = Field(None)
prompt_upsampling: bool | None = Field(None)
class BFLFluxProUltraGenerateRequest(BaseModel):
prompt: str = Field(..., description='The text prompt for image generation.')
prompt_upsampling: Optional[bool] = Field(
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
)
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
aspect_ratio: Optional[str] = Field(None, description='Aspect ratio of the image between 21:9 and 9:21.')
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
)
output_format: Optional[BFLOutputFormat] = Field(
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
)
raw: Optional[bool] = Field(None, description='Generate less processed, more natural-looking images.')
image_prompt: Optional[str] = Field(None, description='Optional image to remix in base64 format')
image_prompt_strength: Optional[confloat(ge=0.0, le=1.0)] = Field(
None, description='Blend between the prompt and the image prompt.'
)
prompt: str = Field(...)
prompt_upsampling: bool | None = Field(None)
seed: int | None = Field(None)
aspect_ratio: str | None = Field(None)
safety_tolerance: int = Field(6)
output_format: str = Field("png")
raw: bool | None = Field(None)
image_prompt: str | None = Field(None, description="Optional image to remix in base64 format")
image_prompt_strength: float | None = Field(None)
class BFLFluxProGenerateResponse(BaseModel):
id: str = Field(..., description="The unique identifier for the generation task.")
polling_url: str = Field(..., description="URL to poll for the generation result.")
id: str = Field(...)
polling_url: str = Field(...)
cost: float | None = Field(None, description="Price in cents")
@ -145,7 +127,7 @@ class BFLStatus(str, Enum):
class BFLFluxStatusResponse(BaseModel):
id: str = Field(..., description="The unique identifier for the generation task.")
status: BFLStatus = Field(..., description="The status of the task.")
result: Optional[Dict[str, Any]] = Field(None, description="The result of the task (null if not completed).")
progress: Optional[float] = Field(None, description="The progress of the task (0.0 to 1.0).", ge=0.0, le=1.0)
id: str = Field(...)
status: BFLStatus = Field(...)
result: dict[str, Any] | None = Field(None)
progress: float | None = Field(None, ge=0.0, le=1.0)

View File

@ -1,25 +1,25 @@
from enum import Enum
from typing import Optional, Any
from typing import Any
from pydantic import BaseModel, Field, RootModel
class TripoModelVersion(str, Enum):
v3_1_20260211 = 'v3.1-20260211'
v3_0_20250812 = 'v3.0-20250812'
v2_5_20250123 = 'v2.5-20250123'
v2_0_20240919 = 'v2.0-20240919'
v1_4_20240625 = 'v1.4-20240625'
v3_1_20260211 = "v3.1-20260211"
v3_0_20250812 = "v3.0-20250812"
v2_5_20250123 = "v2.5-20250123"
v2_0_20240919 = "v2.0-20240919"
v1_4_20240625 = "v1.4-20240625"
class TripoGeometryQuality(str, Enum):
standard = 'standard'
detailed = 'detailed'
standard = "standard"
detailed = "detailed"
class TripoTextureQuality(str, Enum):
standard = 'standard'
detailed = 'detailed'
standard = "standard"
detailed = "detailed"
class TripoStyle(str, Enum):
@ -33,6 +33,7 @@ class TripoStyle(str, Enum):
ANCIENT_BRONZE = "ancient_bronze"
NONE = "None"
class TripoTaskType(str, Enum):
TEXT_TO_MODEL = "text_to_model"
IMAGE_TO_MODEL = "image_to_model"
@ -45,26 +46,27 @@ class TripoTaskType(str, Enum):
STYLIZE_MODEL = "stylize_model"
CONVERT_MODEL = "convert_model"
class TripoTextureAlignment(str, Enum):
ORIGINAL_IMAGE = "original_image"
GEOMETRY = "geometry"
class TripoOrientation(str, Enum):
ALIGN_IMAGE = "align_image"
DEFAULT = "default"
class TripoOutFormat(str, Enum):
GLB = "glb"
FBX = "fbx"
class TripoTopology(str, Enum):
BIP = "bip"
QUAD = "quad"
class TripoSpec(str, Enum):
MIXAMO = "mixamo"
TRIPO = "tripo"
class TripoAnimation(str, Enum):
IDLE = "preset:idle"
WALK = "preset:walk"
@ -83,11 +85,6 @@ class TripoAnimation(str, Enum):
SERPENTINE_MARCH = "preset:serpentine:march"
AQUATIC_MARCH = "preset:aquatic:march"
class TripoStylizeStyle(str, Enum):
LEGO = "lego"
VOXEL = "voxel"
VORONOI = "voronoi"
MINECRAFT = "minecraft"
class TripoConvertFormat(str, Enum):
GLTF = "GLTF"
@ -97,6 +94,7 @@ class TripoConvertFormat(str, Enum):
STL = "STL"
_3MF = "3MF"
class TripoTextureFormat(str, Enum):
BMP = "BMP"
DPX = "DPX"
@ -108,6 +106,7 @@ class TripoTextureFormat(str, Enum):
TIFF = "TIFF"
WEBP = "WEBP"
class TripoTaskStatus(str, Enum):
QUEUED = "queued"
RUNNING = "running"
@ -118,183 +117,223 @@ class TripoTaskStatus(str, Enum):
BANNED = "banned"
EXPIRED = "expired"
class TripoFbxPreset(str, Enum):
BLENDER = "blender"
MIXAMO = "mixamo"
_3DSMAX = "3dsmax"
class TripoFileTokenReference(BaseModel):
type: Optional[str] = Field(None, description='The type of the reference')
type: str | None = Field(None, description="The type of the reference")
file_token: str
class TripoUrlReference(BaseModel):
type: Optional[str] = Field(None, description='The type of the reference')
type: str | None = Field(None, description="The type of the reference")
url: str
class TripoObjectStorage(BaseModel):
bucket: str
key: str
class TripoObjectReference(BaseModel):
type: str
object: TripoObjectStorage
class TripoFileEmptyReference(BaseModel):
pass
class TripoFileReference(RootModel):
root: TripoFileTokenReference | TripoUrlReference | TripoObjectReference | TripoFileEmptyReference
class TripoGetStsTokenRequest(BaseModel):
format: str = Field(..., description='The format of the image')
class TripoTextToModelRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.TEXT_TO_MODEL, description='Type of task')
prompt: str = Field(..., description='The text prompt describing the model to generate', max_length=1024)
negative_prompt: Optional[str] = Field(None, description='The negative text prompt', max_length=1024)
model_version: Optional[TripoModelVersion] = TripoModelVersion.v2_5_20250123
face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to')
texture: Optional[bool] = Field(True, description='Whether to apply texture to the generated model')
pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the generated model')
image_seed: Optional[int] = Field(None, description='The seed for the text')
model_seed: Optional[int] = Field(None, description='The seed for the model')
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard
style: Optional[TripoStyle] = None
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model')
type: TripoTaskType = Field(TripoTaskType.TEXT_TO_MODEL, description="Type of task")
prompt: str = Field(..., description="The text prompt describing the model to generate", max_length=1024)
negative_prompt: str | None = Field(None, description="The negative text prompt", max_length=1024)
model_version: TripoModelVersion | None = TripoModelVersion.v2_5_20250123
face_limit: int | None = Field(None, description="The number of faces to limit the generation to")
texture: bool | None = Field(True, description="Whether to apply texture to the generated model")
pbr: bool | None = Field(True, description="Whether to apply PBR to the generated model")
image_seed: int | None = Field(None, description="The seed for the text")
model_seed: int | None = Field(None, description="The seed for the model")
texture_seed: int | None = Field(None, description="The seed for the texture")
texture_quality: TripoTextureQuality | None = TripoTextureQuality.standard
geometry_quality: TripoGeometryQuality | None = TripoGeometryQuality.standard
style: TripoStyle | None = None
auto_size: bool | None = Field(False, description="Whether to auto-size the model")
quad: bool | None = Field(False, description="Whether to apply quad to the generated model")
class TripoImageToModelRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.IMAGE_TO_MODEL, description='Type of task')
file: TripoFileReference = Field(..., description='The file reference to convert to a model')
model_version: Optional[TripoModelVersion] = Field(None, description='The model version to use for generation')
face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to')
texture: Optional[bool] = Field(True, description='Whether to apply texture to the generated model')
pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the generated model')
model_seed: Optional[int] = Field(None, description='The seed for the model')
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard
texture_alignment: Optional[TripoTextureAlignment] = Field(TripoTextureAlignment.ORIGINAL_IMAGE, description='The texture alignment method')
style: Optional[TripoStyle] = Field(None, description='The style to apply to the generated model')
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
orientation: Optional[TripoOrientation] = TripoOrientation.DEFAULT
quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model')
type: TripoTaskType = Field(TripoTaskType.IMAGE_TO_MODEL, description="Type of task")
file: TripoFileReference = Field(..., description="The file reference to convert to a model")
model_version: TripoModelVersion | None = Field(None, description="The model version to use for generation")
face_limit: int | None = Field(None, description="The number of faces to limit the generation to")
texture: bool | None = Field(True, description="Whether to apply texture to the generated model")
pbr: bool | None = Field(True, description="Whether to apply PBR to the generated model")
model_seed: int | None = Field(None, description="The seed for the model")
texture_seed: int | None = Field(None, description="The seed for the texture")
texture_quality: TripoTextureQuality | None = TripoTextureQuality.standard
geometry_quality: TripoGeometryQuality | None = TripoGeometryQuality.standard
texture_alignment: TripoTextureAlignment | None = Field(
TripoTextureAlignment.ORIGINAL_IMAGE, description="The texture alignment method"
)
style: TripoStyle | None = Field(None, description="The style to apply to the generated model")
auto_size: bool | None = Field(False, description="Whether to auto-size the model")
orientation: TripoOrientation | None = TripoOrientation.DEFAULT
quad: bool | None = Field(False, description="Whether to apply quad to the generated model")
class TripoMultiviewToModelRequest(BaseModel):
type: TripoTaskType = TripoTaskType.MULTIVIEW_TO_MODEL
files: list[TripoFileReference] = Field(..., description='The file references to convert to a model')
model_version: Optional[TripoModelVersion] = Field(None, description='The model version to use for generation')
orthographic_projection: Optional[bool] = Field(False, description='Whether to use orthographic projection')
face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to')
texture: Optional[bool] = Field(True, description='Whether to apply texture to the generated model')
pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the generated model')
model_seed: Optional[int] = Field(None, description='The seed for the model')
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard
texture_alignment: Optional[TripoTextureAlignment] = TripoTextureAlignment.ORIGINAL_IMAGE
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
orientation: Optional[TripoOrientation] = Field(TripoOrientation.DEFAULT, description='The orientation for the model')
quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model')
files: list[TripoFileReference] = Field(..., description="The file references to convert to a model")
model_version: TripoModelVersion | None = Field(None, description="The model version to use for generation")
orthographic_projection: bool | None = Field(False, description="Whether to use orthographic projection")
face_limit: int | None = Field(None, description="The number of faces to limit the generation to")
texture: bool | None = Field(True, description="Whether to apply texture to the generated model")
pbr: bool | None = Field(True, description="Whether to apply PBR to the generated model")
model_seed: int | None = Field(None, description="The seed for the model")
texture_seed: int | None = Field(None, description="The seed for the texture")
texture_quality: TripoTextureQuality | None = TripoTextureQuality.standard
geometry_quality: TripoGeometryQuality | None = TripoGeometryQuality.standard
texture_alignment: TripoTextureAlignment | None = TripoTextureAlignment.ORIGINAL_IMAGE
auto_size: bool | None = Field(False, description="Whether to auto-size the model")
orientation: TripoOrientation | None = Field(TripoOrientation.DEFAULT, description="The orientation for the model")
quad: bool | None = Field(False, description="Whether to apply quad to the generated model")
class TripoTextureModelRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.TEXTURE_MODEL, description='Type of task')
original_model_task_id: str = Field(..., description='The task ID of the original model')
texture: Optional[bool] = Field(True, description='Whether to apply texture to the model')
pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the model')
model_seed: Optional[int] = Field(None, description='The seed for the model')
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
texture_quality: Optional[TripoTextureQuality] = Field(None, description='The quality of the texture')
texture_alignment: Optional[TripoTextureAlignment] = Field(TripoTextureAlignment.ORIGINAL_IMAGE, description='The texture alignment method')
type: TripoTaskType = Field(TripoTaskType.TEXTURE_MODEL, description="Type of task")
original_model_task_id: str = Field(..., description="The task ID of the original model")
texture: bool | None = Field(True, description="Whether to apply texture to the model")
pbr: bool | None = Field(True, description="Whether to apply PBR to the model")
model_seed: int | None = Field(None, description="The seed for the model")
texture_seed: int | None = Field(None, description="The seed for the texture")
texture_quality: TripoTextureQuality | None = Field(None, description="The quality of the texture")
texture_alignment: TripoTextureAlignment | None = Field(
TripoTextureAlignment.ORIGINAL_IMAGE, description="The texture alignment method"
)
class TripoRefineModelRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.REFINE_MODEL, description='Type of task')
draft_model_task_id: str = Field(..., description='The task ID of the draft model')
type: TripoTaskType = Field(TripoTaskType.REFINE_MODEL, description="Type of task")
draft_model_task_id: str = Field(..., description="The task ID of the draft model")
class TripoAnimatePrerigcheckRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.ANIMATE_PRERIGCHECK, description='Type of task')
original_model_task_id: str = Field(..., description='The task ID of the original model')
class TripoAnimateRigRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.ANIMATE_RIG, description='Type of task')
original_model_task_id: str = Field(..., description='The task ID of the original model')
out_format: Optional[TripoOutFormat] = Field(TripoOutFormat.GLB, description='The output format')
spec: Optional[TripoSpec] = Field(TripoSpec.TRIPO, description='The specification for rigging')
type: TripoTaskType = Field(TripoTaskType.ANIMATE_RIG, description="Type of task")
original_model_task_id: str = Field(..., description="The task ID of the original model")
out_format: TripoOutFormat | None = Field(TripoOutFormat.GLB, description="The output format")
spec: TripoSpec | None = Field(TripoSpec.TRIPO, description="The specification for rigging")
class TripoAnimateRetargetRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.ANIMATE_RETARGET, description='Type of task')
original_model_task_id: str = Field(..., description='The task ID of the original model')
animation: TripoAnimation = Field(..., description='The animation to apply')
out_format: Optional[TripoOutFormat] = Field(TripoOutFormat.GLB, description='The output format')
bake_animation: Optional[bool] = Field(True, description='Whether to bake the animation')
type: TripoTaskType = Field(TripoTaskType.ANIMATE_RETARGET, description="Type of task")
original_model_task_id: str = Field(..., description="The task ID of the original model")
animation: TripoAnimation = Field(..., description="The animation to apply")
out_format: TripoOutFormat | None = Field(TripoOutFormat.GLB, description="The output format")
bake_animation: bool | None = Field(True, description="Whether to bake the animation")
class TripoStylizeModelRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.STYLIZE_MODEL, description='Type of task')
style: TripoStylizeStyle = Field(..., description='The style to apply to the model')
original_model_task_id: str = Field(..., description='The task ID of the original model')
block_size: Optional[int] = Field(80, description='The block size for stylization')
class TripoConvertModelRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.CONVERT_MODEL, description='Type of task')
format: TripoConvertFormat = Field(..., description='The format to convert to')
original_model_task_id: str = Field(..., description='The task ID of the original model')
quad: Optional[bool] = Field(None, description='Whether to apply quad to the model')
force_symmetry: Optional[bool] = Field(None, description='Whether to force symmetry')
face_limit: Optional[int] = Field(None, description='The number of faces to limit the conversion to')
flatten_bottom: Optional[bool] = Field(None, description='Whether to flatten the bottom of the model')
flatten_bottom_threshold: Optional[float] = Field(None, description='The threshold for flattening the bottom')
texture_size: Optional[int] = Field(None, description='The size of the texture')
texture_format: Optional[TripoTextureFormat] = Field(TripoTextureFormat.JPEG, description='The format of the texture')
pivot_to_center_bottom: Optional[bool] = Field(None, description='Whether to pivot to the center bottom')
scale_factor: Optional[float] = Field(None, description='The scale factor for the model')
with_animation: Optional[bool] = Field(None, description='Whether to include animations')
pack_uv: Optional[bool] = Field(None, description='Whether to pack the UVs')
bake: Optional[bool] = Field(None, description='Whether to bake the model')
part_names: Optional[list[str]] = Field(None, description='The names of the parts to include')
fbx_preset: Optional[TripoFbxPreset] = Field(None, description='The preset for the FBX export')
export_vertex_colors: Optional[bool] = Field(None, description='Whether to export the vertex colors')
export_orientation: Optional[TripoOrientation] = Field(None, description='The orientation for the export')
animate_in_place: Optional[bool] = Field(None, description='Whether to animate in place')
type: TripoTaskType = Field(TripoTaskType.CONVERT_MODEL, description="Type of task")
format: TripoConvertFormat = Field(..., description="The format to convert to")
original_model_task_id: str = Field(..., description="The task ID of the original model")
quad: bool | None = Field(None, description="Whether to apply quad to the model")
force_symmetry: bool | None = Field(None, description="Whether to force symmetry")
face_limit: int | None = Field(None, description="The number of faces to limit the conversion to")
flatten_bottom: bool | None = Field(None, description="Whether to flatten the bottom of the model")
flatten_bottom_threshold: float | None = Field(None, description="The threshold for flattening the bottom")
texture_size: int | None = Field(None, description="The size of the texture")
texture_format: TripoTextureFormat | None = Field(TripoTextureFormat.JPEG, description="The format of the texture")
pivot_to_center_bottom: bool | None = Field(None, description="Whether to pivot to the center bottom")
scale_factor: float | None = Field(None, description="The scale factor for the model")
with_animation: bool | None = Field(None, description="Whether to include animations")
pack_uv: bool | None = Field(None, description="Whether to pack the UVs")
bake: bool | None = Field(None, description="Whether to bake the model")
part_names: list[str] | None = Field(None, description="The names of the parts to include")
fbx_preset: TripoFbxPreset | None = Field(None, description="The preset for the FBX export")
export_vertex_colors: bool | None = Field(None, description="Whether to export the vertex colors")
export_orientation: TripoOrientation | None = Field(None, description="The orientation for the export")
animate_in_place: bool | None = Field(None, description="Whether to animate in place")
class TripoP1CommonRequest(BaseModel):
"""Fields supported by Tripo P1 across all input types."""
model_version: str = Field("P1-20260311")
model_seed: int | None = Field(None, description="Random seed for geometry generation")
face_limit: int | None = Field(None, ge=48, le=20000, description="Target face count (48-20000)")
texture: bool | None = Field(None, description="Enable texturing; pbr=True forces this true")
pbr: bool | None = Field(None, description="Enable PBR maps; when true, texture is also enabled")
texture_seed: int | None = Field(None, description="Random seed for texture generation")
texture_quality: str | None = Field(None, description='"standard" or "detailed"')
auto_size: bool | None = Field(None, description="Scale to real-world meters")
compress: str | None = Field(None, description='Only "geometry" is supported')
export_uv: bool | None = Field(None, description="Perform UV unwrapping during generation")
class TripoP1TextToModelRequest(TripoP1CommonRequest):
type: str = "text_to_model"
prompt: str = Field(..., max_length=1024)
negative_prompt: str | None = Field(None, max_length=255)
image_seed: int | None = None
class TripoP1ImageToModelRequest(TripoP1CommonRequest):
type: str = "image_to_model"
file: TripoFileReference
enable_image_autofix: bool | None = None
texture_alignment: str | None = Field(None, description='"original_image" or "geometry"')
orientation: str | None = Field(None, description='"default" or "align_image"; needs texture=true')
class TripoP1MultiviewToModelRequest(TripoP1CommonRequest):
"""P1 multiview generation.
Tripo requires `files` to be exactly four entries in [front, left, back, right] order with `{}`
(TripoFileEmptyReference) for omitted slots; front is required and at least two images total must be provided.
"""
type: str = "multiview_to_model"
files: list[TripoFileReference]
texture_alignment: str | None = None
orientation: str | None = None
class TripoTaskOutput(BaseModel):
model: Optional[str] = Field(None, description='URL to the model')
base_model: Optional[str] = Field(None, description='URL to the base model')
pbr_model: Optional[str] = Field(None, description='URL to the PBR model')
rendered_image: Optional[str] = Field(None, description='URL to the rendered image')
riggable: Optional[bool] = Field(None, description='Whether the model is riggable')
model: str | None = Field(None, description="URL to the model")
base_model: str | None = Field(None, description="URL to the base model")
pbr_model: str | None = Field(None, description="URL to the PBR model")
rendered_image: str | None = Field(None, description="URL to the rendered image")
riggable: bool | None = Field(None, description="Whether the model is riggable")
class TripoTask(BaseModel):
task_id: str = Field(..., description='The task ID')
type: Optional[str] = Field(None, description='The type of task')
status: Optional[TripoTaskStatus] = Field(None, description='The status of the task')
input: Optional[dict[str, Any]] = Field(None, description='The input parameters for the task')
output: Optional[TripoTaskOutput] = Field(None, description='The output of the task')
progress: Optional[int] = Field(None, description='The progress of the task', ge=0, le=100)
create_time: Optional[int] = Field(None, description='The creation time of the task')
running_left_time: Optional[int] = Field(None, description='The estimated time left for the task')
queue_position: Optional[int] = Field(None, description='The position in the queue')
task_id: str = Field(..., description="The task ID")
type: str | None = Field(None, description="The type of task")
status: TripoTaskStatus | None = Field(None, description="The status of the task")
input: dict[str, Any] | None = Field(None, description="The input parameters for the task")
output: TripoTaskOutput | None = Field(None, description="The output of the task")
progress: int | None = Field(None, description="The progress of the task", ge=0, le=100)
create_time: int | None = Field(None, description="The creation time of the task")
running_left_time: int | None = Field(None, description="The estimated time left for the task")
queue_position: int | None = Field(None, description="The position in the queue")
consumed_credit: int | None = Field(None)
class TripoTaskResponse(BaseModel):
code: int = Field(0, description='The response code')
data: TripoTask = Field(..., description='The task data')
code: int = Field(0, description="The response code")
data: TripoTask = Field(..., description="The task data")
class TripoGeneralResponse(BaseModel):
code: int = Field(0, description='The response code')
data: dict[str, str] = Field(..., description='The task ID data')
class TripoBalanceData(BaseModel):
balance: float = Field(..., description='The account balance')
frozen: float = Field(..., description='The frozen balance')
class TripoBalanceResponse(BaseModel):
code: int = Field(0, description='The response code')
data: TripoBalanceData = Field(..., description='The balance data')
class TripoErrorResponse(BaseModel):
code: int = Field(..., description='The error code')
message: str = Field(..., description='The error message')
suggestion: str = Field(..., description='The suggestion for fixing the error')
code: int = Field(..., description="The error code")
message: str = Field(..., description="The error message")
suggestion: str = Field(..., description="The suggestion for fixing the error")

View File

@ -155,7 +155,7 @@ class ClaudeNode(IO.ComfyNode):
return IO.Schema(
node_id="ClaudeNode",
display_name="Anthropic Claude",
category="text/partner/Anthropic",
category="partner/text/Anthropic",
essentials_category="Text Generation",
description="Generate text responses with Anthropic's Claude models. "
"Provide a text prompt and optionally one or more images for multimodal context.",

View File

@ -206,7 +206,7 @@ class BeebleSwitchXVideoEdit(IO.ComfyNode):
return IO.Schema(
node_id="BeebleSwitchXVideoEdit",
display_name="Beeble SwitchX Video Edit",
category="video/partner/Beeble",
category="partner/video/Beeble",
description=(
"Edit a video with Beeble SwitchX. Switches anything in the scene (background, "
"lighting, costume) while preserving the original subject's pixels and motion. "
@ -302,7 +302,7 @@ class BeebleSwitchXImageEdit(IO.ComfyNode):
return IO.Schema(
node_id="BeebleSwitchXImageEdit",
display_name="Beeble SwitchX Image Edit",
category="image/partner/Beeble",
category="partner/image/Beeble",
description=(
"Edit a single image with Beeble SwitchX. Switches anything in the scene "
"(background, lighting, costume) while preserving the original subject's pixels. "

View File

@ -4,17 +4,20 @@ from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.bfl import (
BFLFluxEraseRequest,
BFLFluxExpandImageRequest,
BFLFluxFillImageRequest,
BFLFluxKontextProGenerateRequest,
BFLFluxProGenerateResponse,
BFLFluxProUltraGenerateRequest,
BFLFluxStatusResponse,
BFLFluxVTORequest,
BFLStatus,
Flux2ProGenerateRequest,
)
from comfy_api_nodes.util import (
ApiEndpoint,
convert_mask_to_image,
download_url_to_image_tensor,
get_number_of_images,
poll_op,
@ -22,19 +25,11 @@ from comfy_api_nodes.util import (
sync_op,
tensor_to_base64_string,
validate_aspect_ratio_string,
validate_image_dimensions,
validate_string,
)
def convert_mask_to_image(mask: Input.Image):
"""
Make mask have the expected amount of dims (4) and channels (3) to be recognized as an image.
"""
mask = mask.unsqueeze(-1)
mask = torch.cat([mask] * 3, dim=-1)
return mask
class FluxProUltraImageNode(IO.ComfyNode):
@classmethod
@ -42,7 +37,7 @@ class FluxProUltraImageNode(IO.ComfyNode):
return IO.Schema(
node_id="FluxProUltraImageNode",
display_name="Flux 1.1 [pro] Ultra Image",
category="image/partner/BFL",
category="partner/image/BFL",
description="Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.",
inputs=[
IO.String.Input(
@ -160,7 +155,7 @@ class FluxKontextProImageNode(IO.ComfyNode):
return IO.Schema(
node_id=cls.NODE_ID,
display_name=cls.DISPLAY_NAME,
category="image/partner/BFL",
category="partner/image/BFL",
description="Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.",
inputs=[
IO.String.Input(
@ -282,7 +277,7 @@ class FluxProExpandNode(IO.ComfyNode):
return IO.Schema(
node_id="FluxProExpandNode",
display_name="Flux.1 Expand Image",
category="image/partner/BFL",
category="partner/image/BFL",
description="Outpaints image based on prompt.",
inputs=[
IO.Image.Input("image"),
@ -419,7 +414,7 @@ class FluxProFillNode(IO.ComfyNode):
return IO.Schema(
node_id="FluxProFillNode",
display_name="Flux.1 Fill Image",
category="image/partner/BFL",
category="partner/image/BFL",
description="Inpaints image based on mask and prompt.",
inputs=[
IO.Image.Input("image"),
@ -519,6 +514,163 @@ class FluxProFillNode(IO.ComfyNode):
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
class FluxEraseNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="FluxEraseNode",
display_name="Flux Erase Image",
category="partner/image/BFL",
description="Removes the masked object from an image and reconstructs the background. "
"Paint the mask over what you want to erase.",
inputs=[
IO.Image.Input("image"),
IO.Mask.Input("mask", tooltip="White areas are removed; black areas are preserved."),
IO.Int.Input(
"dilate_pixels",
default=10,
min=0,
max=25,
tooltip="Expands the mask boundaries to ensure clean coverage of the object's edges.",
),
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"range_usd","min_usd":0.03,"max_usd":0.06,"format":{"approximate":true}}""",
),
)
@classmethod
async def execute(
cls,
image: Input.Image,
mask: Input.Image,
dilate_pixels: int = 10,
) -> IO.NodeOutput:
validate_image_dimensions(image, min_width=256, min_height=256)
mask = resize_mask_to_image(mask, image)
mask = tensor_to_base64_string(convert_mask_to_image(mask))
initial_response = await sync_op(
cls,
ApiEndpoint(path="/proxy/bfl/v1/flux-tools/erase-v1", method="POST"),
response_model=BFLFluxProGenerateResponse,
data=BFLFluxEraseRequest(
image=tensor_to_base64_string(image[:, :, :, :3]), # make sure image will have alpha channel removed
mask=mask,
dilate_pixels=dilate_pixels,
),
)
def price_extractor(_r: BaseModel) -> float | None:
return None if initial_response.cost is None else initial_response.cost / 100
response = await poll_op(
cls,
ApiEndpoint(initial_response.polling_url),
response_model=BFLFluxStatusResponse,
status_extractor=lambda r: r.status,
progress_extractor=lambda r: r.progress,
price_extractor=price_extractor,
completed_statuses=[BFLStatus.ready],
failed_statuses=[
BFLStatus.request_moderated,
BFLStatus.content_moderated,
BFLStatus.error,
BFLStatus.task_not_found,
],
queued_statuses=[],
)
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
class FluxVTONode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="FluxVTONode",
display_name="Flux Virtual Try-On",
category="partner/image/BFL",
description="Virtual try-on: dresses the person in the provided garment.",
inputs=[
IO.Image.Input("person", tooltip="Image of the person to dress."),
IO.Image.Input("garment", tooltip="Image of the garment to apply."),
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Optional natural-language styling instruction (e.g. how the garment should fit).",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="The random seed used for creating the noise.",
),
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"range_usd","min_usd":0.0375,"max_usd":0.075,"format":{"approximate":true}}""",
),
)
@classmethod
async def execute(
cls,
person: Input.Image,
garment: Input.Image,
prompt: str = "",
seed: int = 0,
) -> IO.NodeOutput:
initial_response = await sync_op(
cls,
ApiEndpoint(path="/proxy/bfl/v1/flux-tools/vto-v1", method="POST"),
response_model=BFLFluxProGenerateResponse,
data=BFLFluxVTORequest(
prompt=prompt,
person=tensor_to_base64_string(person[:, :, :, :3]),
garment=tensor_to_base64_string(garment[:, :, :, :3]),
seed=seed,
),
)
def price_extractor(_r: BaseModel) -> float | None:
return None if initial_response.cost is None else initial_response.cost / 100
response = await poll_op(
cls,
ApiEndpoint(initial_response.polling_url),
response_model=BFLFluxStatusResponse,
status_extractor=lambda r: r.status,
progress_extractor=lambda r: r.progress,
price_extractor=price_extractor,
completed_statuses=[BFLStatus.ready],
failed_statuses=[
BFLStatus.request_moderated,
BFLStatus.content_moderated,
BFLStatus.error,
BFLStatus.task_not_found,
],
queued_statuses=[],
)
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
class Flux2ProImageNode(IO.ComfyNode):
NODE_ID = "Flux2ProImageNode"
@ -545,7 +697,7 @@ class Flux2ProImageNode(IO.ComfyNode):
return IO.Schema(
node_id=cls.NODE_ID,
display_name=cls.DISPLAY_NAME,
category="image/partner/BFL",
category="partner/image/BFL",
description="Generates images synchronously based on prompt and resolution.",
inputs=[
IO.String.Input(
@ -716,7 +868,7 @@ class Flux2ImageNode(IO.ComfyNode):
return IO.Schema(
node_id="Flux2ImageNode",
display_name="Flux.2 Image",
category="image/partner/BFL",
category="partner/image/BFL",
description="Generate images via Flux.2 [pro] or Flux.2 [max] from a prompt and optional reference images.",
inputs=[
IO.String.Input(
@ -853,6 +1005,8 @@ class BFLExtension(ComfyExtension):
FluxKontextMaxImageNode,
FluxProExpandNode,
FluxProFillNode,
FluxEraseNode,
FluxVTONode,
Flux2ProImageNode,
Flux2MaxImageNode,
Flux2ImageNode,

View File

@ -31,7 +31,7 @@ class BriaImageEditNode(IO.ComfyNode):
return IO.Schema(
node_id="BriaImageEditNode",
display_name="Bria FIBO Image Edit",
category="image/partner/Bria",
category="partner/image/Bria",
description="Edit images using Bria latest model",
inputs=[
IO.Combo.Input("model", options=["FIBO"]),
@ -169,7 +169,7 @@ class BriaRemoveImageBackground(IO.ComfyNode):
return IO.Schema(
node_id="BriaRemoveImageBackground",
display_name="Bria Remove Image Background",
category="image/partner/Bria",
category="partner/image/Bria",
description="Remove the background from an image using Bria RMBG 2.0.",
inputs=[
IO.Image.Input("image"),
@ -245,7 +245,7 @@ class BriaRemoveVideoBackground(IO.ComfyNode):
return IO.Schema(
node_id="BriaRemoveVideoBackground",
display_name="Bria Remove Video Background",
category="video/partner/Bria",
category="partner/video/Bria",
description="Remove the background from a video using Bria. ",
inputs=[
IO.Video.Input("video"),

View File

@ -368,7 +368,7 @@ class ByteDanceImageNode(IO.ComfyNode):
return IO.Schema(
node_id="ByteDanceImageNode",
display_name="ByteDance Image",
category="image/partner/ByteDance",
category="partner/image/ByteDance",
description="Generate images using ByteDance models via api based on prompt",
inputs=[
IO.Combo.Input("model", options=["seedream-3-0-t2i-250415"]),
@ -492,7 +492,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
return IO.Schema(
node_id="ByteDanceSeedreamNode",
display_name="ByteDance Seedream 4.5 & 5.0",
category="image/partner/ByteDance",
category="partner/image/ByteDance",
description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.",
inputs=[
IO.Combo.Input(
@ -754,7 +754,7 @@ class ByteDanceSeedreamNodeV2(IO.ComfyNode):
return IO.Schema(
node_id="ByteDanceSeedreamNodeV2",
display_name="ByteDance Seedream 4.5 & 5.0",
category="image/partner/ByteDance",
category="partner/image/ByteDance",
description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.",
inputs=[
IO.String.Input(
@ -920,7 +920,7 @@ class ByteDanceTextToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="ByteDanceTextToVideoNode",
display_name="ByteDance Text to Video",
category="video/partner/ByteDance",
category="partner/video/ByteDance",
description="Generate video using ByteDance models via api based on prompt",
inputs=[
IO.Combo.Input(
@ -1048,7 +1048,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="ByteDanceImageToVideoNode",
display_name="ByteDance Image to Video",
category="video/partner/ByteDance",
category="partner/video/ByteDance",
description="Generate video using ByteDance models via api based on image and prompt",
inputs=[
IO.Combo.Input(
@ -1185,7 +1185,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
return IO.Schema(
node_id="ByteDanceFirstLastFrameNode",
display_name="ByteDance First-Last-Frame to Video",
category="video/partner/ByteDance",
category="partner/video/ByteDance",
description="Generate video using prompt and first and last frames.",
inputs=[
IO.Combo.Input(
@ -1333,7 +1333,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
return IO.Schema(
node_id="ByteDanceImageReferenceNode",
display_name="ByteDance Reference Images to Video",
category="video/partner/ByteDance",
category="partner/video/ByteDance",
description="Generate video using prompt and reference images.",
inputs=[
IO.Combo.Input(
@ -1576,7 +1576,7 @@ class ByteDance2TextToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="ByteDance2TextToVideoNode",
display_name="ByteDance Seedance 2.0 Text to Video",
category="video/partner/ByteDance",
category="partner/video/ByteDance",
description="Generate video using Seedance 2.0 models based on a text prompt.",
inputs=[
IO.DynamicCombo.Input(
@ -1677,7 +1677,7 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
return IO.Schema(
node_id="ByteDance2FirstLastFrameNode",
display_name="ByteDance Seedance 2.0 First-Last-Frame to Video",
category="video/partner/ByteDance",
category="partner/video/ByteDance",
description="Generate video using Seedance 2.0 from a first frame image and optional last frame image.",
inputs=[
IO.DynamicCombo.Input(
@ -1944,7 +1944,7 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
return IO.Schema(
node_id="ByteDance2ReferenceNode",
display_name="ByteDance Seedance 2.0 Reference to Video",
category="video/partner/ByteDance",
category="partner/video/ByteDance",
description="Generate, edit, or extend video using Seedance 2.0 with reference images, "
"videos, and audio. Supports multimodal reference, video editing, and video extension.",
inputs=[
@ -2241,7 +2241,7 @@ class ByteDanceCreateImageAsset(IO.ComfyNode):
return IO.Schema(
node_id="ByteDanceCreateImageAsset",
display_name="ByteDance Create Image Asset",
category="image/partner/ByteDance",
category="partner/image/ByteDance",
description=(
"Create a Seedance 2.0 personal image asset. Uploads the input image and "
"registers it in the given asset group. If group_id is empty, runs a real-person "
@ -2308,7 +2308,7 @@ class ByteDanceCreateVideoAsset(IO.ComfyNode):
return IO.Schema(
node_id="ByteDanceCreateVideoAsset",
display_name="ByteDance Create Video Asset",
category="video/partner/ByteDance",
category="partner/video/ByteDance",
description=(
"Create a Seedance 2.0 personal video asset. Uploads the input video and "
"registers it in the given asset group. If group_id is empty, runs a real-person "

View File

@ -144,7 +144,7 @@ class ByteDanceSeedNode(IO.ComfyNode):
return IO.Schema(
node_id="ByteDanceSeedNode",
display_name="ByteDance Seed",
category="text/partner/ByteDance",
category="partner/text/ByteDance",
essentials_category="Text Generation",
description="Generate text responses with ByteDance's Seed 2.0 models. "
"Provide a text prompt and optionally one or more images or videos for multimodal context.",

View File

@ -69,7 +69,7 @@ class ElevenLabsSpeechToText(IO.ComfyNode):
return IO.Schema(
node_id="ElevenLabsSpeechToText",
display_name="ElevenLabs Speech to Text",
category="audio/partner/ElevenLabs",
category="partner/audio/ElevenLabs",
description="Transcribe audio to text. "
"Supports automatic language detection, speaker diarization, and audio event tagging.",
inputs=[
@ -210,7 +210,7 @@ class ElevenLabsVoiceSelector(IO.ComfyNode):
return IO.Schema(
node_id="ElevenLabsVoiceSelector",
display_name="ElevenLabs Voice Selector",
category="audio/partner/ElevenLabs",
category="partner/audio/ElevenLabs",
description="Select a predefined ElevenLabs voice for text-to-speech generation.",
inputs=[
IO.Combo.Input(
@ -239,7 +239,7 @@ class ElevenLabsTextToSpeech(IO.ComfyNode):
return IO.Schema(
node_id="ElevenLabsTextToSpeech",
display_name="ElevenLabs Text to Speech",
category="audio/partner/ElevenLabs",
category="partner/audio/ElevenLabs",
description="Convert text to speech.",
inputs=[
IO.Custom(ELEVENLABS_VOICE).Input(
@ -414,7 +414,7 @@ class ElevenLabsAudioIsolation(IO.ComfyNode):
return IO.Schema(
node_id="ElevenLabsAudioIsolation",
display_name="ElevenLabs Voice Isolation",
category="audio/partner/ElevenLabs",
category="partner/audio/ElevenLabs",
description="Remove background noise from audio, isolating vocals or speech.",
inputs=[
IO.Audio.Input(
@ -459,7 +459,7 @@ class ElevenLabsTextToSoundEffects(IO.ComfyNode):
return IO.Schema(
node_id="ElevenLabsTextToSoundEffects",
display_name="ElevenLabs Text to Sound Effects",
category="audio/partner/ElevenLabs",
category="partner/audio/ElevenLabs",
description="Generate sound effects from text descriptions.",
inputs=[
IO.String.Input(
@ -555,7 +555,7 @@ class ElevenLabsInstantVoiceClone(IO.ComfyNode):
return IO.Schema(
node_id="ElevenLabsInstantVoiceClone",
display_name="ElevenLabs Instant Voice Clone",
category="audio/partner/ElevenLabs",
category="partner/audio/ElevenLabs",
description="Create a cloned voice from audio samples. "
"Provide 1-8 audio recordings of the voice to clone.",
inputs=[
@ -658,7 +658,7 @@ class ElevenLabsSpeechToSpeech(IO.ComfyNode):
return IO.Schema(
node_id="ElevenLabsSpeechToSpeech",
display_name="ElevenLabs Speech to Speech",
category="audio/partner/ElevenLabs",
category="partner/audio/ElevenLabs",
description="Transform speech from one voice to another while preserving the original content and emotion.",
inputs=[
IO.Custom(ELEVENLABS_VOICE).Input(
@ -793,7 +793,7 @@ class ElevenLabsTextToDialogue(IO.ComfyNode):
return IO.Schema(
node_id="ElevenLabsTextToDialogue",
display_name="ElevenLabs Text to Dialogue",
category="audio/partner/ElevenLabs",
category="partner/audio/ElevenLabs",
description="Generate multi-speaker dialogue from text. Each dialogue entry has its own text and voice.",
inputs=[
IO.Float.Input(

View File

@ -300,7 +300,7 @@ class GeminiNode(IO.ComfyNode):
return IO.Schema(
node_id="GeminiNode",
display_name="Google Gemini",
category="text/partner/Gemini",
category="partner/text/Gemini",
description="Generate text responses with Google's Gemini AI model. "
"You can provide multiple types of inputs (text, images, audio, video) "
"as context for generating more relevant and meaningful responses.",
@ -541,7 +541,7 @@ class GeminiInputFiles(IO.ComfyNode):
return IO.Schema(
node_id="GeminiInputFiles",
display_name="Gemini Input Files",
category="text/partner/Gemini",
category="partner/text/Gemini",
description="Loads and prepares input files to include as inputs for Gemini LLM nodes. "
"The files will be read by the Gemini model when generating a response. "
"The contents of the text file count toward the token limit. "
@ -598,7 +598,7 @@ class GeminiImage(IO.ComfyNode):
return IO.Schema(
node_id="GeminiImageNode",
display_name="Nano Banana (Google Gemini Image)",
category="image/partner/Gemini",
category="partner/image/Gemini",
description="Edit images synchronously via Google API.",
inputs=[
IO.String.Input(
@ -731,7 +731,7 @@ class GeminiImage2(IO.ComfyNode):
return IO.Schema(
node_id="GeminiImage2Node",
display_name="Nano Banana Pro (Google Gemini Image)",
category="image/partner/Gemini",
category="partner/image/Gemini",
description="Generate or edit images synchronously via Google Vertex API.",
inputs=[
IO.String.Input(
@ -869,7 +869,7 @@ class GeminiNanoBanana2(IO.ComfyNode):
return IO.Schema(
node_id="GeminiNanoBanana2",
display_name="Nano Banana 2",
category="image/partner/Gemini",
category="partner/image/Gemini",
description="Generate or edit images synchronously via Google Vertex API.",
inputs=[
IO.String.Input(
@ -1085,7 +1085,7 @@ class GeminiNanoBanana2V2(IO.ComfyNode):
return IO.Schema(
node_id="GeminiNanoBanana2V2",
display_name="Nano Banana 2",
category="image/partner/Gemini",
category="partner/image/Gemini",
description="Generate or edit images synchronously via Google Vertex API.",
inputs=[
IO.String.Input(

View File

@ -29,6 +29,11 @@ from comfy_api_nodes.util import (
)
_GROK_VIDEO_MODEL_API_IDS = {
"grok-imagine-video-1.5": "grok-imagine-video-1.5-preview",
}
def _extract_grok_price(response) -> float | None:
if response.usage and response.usage.cost_in_usd_ticks is not None:
return response.usage.cost_in_usd_ticks / 10_000_000_000
@ -49,7 +54,7 @@ class GrokImageNode(IO.ComfyNode):
return IO.Schema(
node_id="GrokImageNode",
display_name="Grok Image",
category="image/partner/Grok",
category="partner/image/Grok",
description="Generate images using Grok based on a text prompt",
inputs=[
IO.Combo.Input(
@ -58,7 +63,6 @@ class GrokImageNode(IO.ComfyNode):
"grok-imagine-image-quality",
"grok-imagine-image-pro",
"grok-imagine-image",
"grok-imagine-image-beta",
],
),
IO.String.Input(
@ -224,7 +228,7 @@ class GrokImageEditNode(IO.ComfyNode):
return IO.Schema(
node_id="GrokImageEditNode",
display_name="Grok Image Edit",
category="image/partner/Grok",
category="partner/image/Grok",
description="Modify an existing image based on a text prompt",
inputs=[
IO.Combo.Input(
@ -233,7 +237,6 @@ class GrokImageEditNode(IO.ComfyNode):
"grok-imagine-image-quality",
"grok-imagine-image-pro",
"grok-imagine-image",
"grok-imagine-image-beta",
],
),
IO.Image.Input("image", display_name="images"),
@ -366,7 +369,7 @@ class GrokImageEditNodeV2(IO.ComfyNode):
return IO.Schema(
node_id="GrokImageEditNodeV2",
display_name="Grok Image Edit",
category="image/partner/Grok",
category="partner/image/Grok",
description="Modify an existing image based on a text prompt",
inputs=[
IO.String.Input(
@ -503,10 +506,14 @@ class GrokVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="GrokVideoNode",
display_name="Grok Video",
category="video/partner/Grok",
category="partner/video/Grok",
description="Generate video from a prompt or an image",
inputs=[
IO.Combo.Input("model", options=["grok-imagine-video", "grok-imagine-video-beta"]),
IO.Combo.Input(
"model",
options=["grok-imagine-video", "grok-imagine-video-1.5"],
tooltip="grok-imagine-video-1.5 currently always requires an input image.",
),
IO.String.Input(
"prompt",
multiline=True,
@ -542,7 +549,11 @@ class GrokVideoNode(IO.ComfyNode):
tooltip="Seed to determine if node should re-run; "
"actual results are nondeterministic regardless of seed.",
),
IO.Image.Input("image", optional=True),
IO.Image.Input(
"image",
optional=True,
tooltip="Optional starting image for grok-imagine-video. Required for grok-imagine-video-1.5.",
),
],
outputs=[
IO.Video.Output(),
@ -554,12 +565,16 @@ class GrokVideoNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"], inputs=["image"]),
depends_on=IO.PriceBadgeDepends(widgets=["model", "duration", "resolution"], inputs=["image"]),
expr="""
(
$rate := widgets.resolution = "720p" ? 0.07 : 0.05;
$is15 := $contains(widgets.model, "1.5");
$rate := $is15
? (widgets.resolution = "720p" ? 0.2002 : 0.1144)
: (widgets.resolution = "720p" ? 0.07 : 0.05);
$imgCost := $is15 ? 0.0143 : 0.002;
$base := $rate * widgets.duration;
{"type":"usd","usd": inputs.image.connected ? $base + 0.002 : $base}
{"type":"usd","usd": inputs.image.connected ? $base + $imgCost : $base}
)
""",
),
@ -576,8 +591,8 @@ class GrokVideoNode(IO.ComfyNode):
seed: int,
image: Input.Image | None = None,
) -> IO.NodeOutput:
if model == "grok-imagine-video-beta":
model = "grok-imagine-video"
if image is None and model == "grok-imagine-video-1.5":
raise ValueError(f"The '{model}' model requires an input image; connect one to the 'image' input.")
image_url = None
if image is not None:
if get_number_of_images(image) != 1:
@ -588,7 +603,7 @@ class GrokVideoNode(IO.ComfyNode):
cls,
ApiEndpoint(path="/proxy/xai/v1/videos/generations", method="POST"),
data=VideoGenerationRequest(
model=model,
model=_GROK_VIDEO_MODEL_API_IDS.get(model, model),
image=image_url,
prompt=prompt,
resolution=resolution,
@ -603,7 +618,7 @@ class GrokVideoNode(IO.ComfyNode):
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
status_extractor=lambda r: r.status if r.status is not None else "complete",
response_model=VideoStatusResponse,
price_extractor=_extract_grok_price,
price_extractor=_extract_grok_video_price if model == "grok-imagine-video-1.5" else _extract_grok_price,
)
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
@ -615,10 +630,10 @@ class GrokVideoEditNode(IO.ComfyNode):
return IO.Schema(
node_id="GrokVideoEditNode",
display_name="Grok Video Edit",
category="video/partner/Grok",
category="partner/video/Grok",
description="Edit an existing video based on a text prompt.",
inputs=[
IO.Combo.Input("model", options=["grok-imagine-video", "grok-imagine-video-beta"]),
IO.Combo.Input("model", options=["grok-imagine-video"]),
IO.String.Input(
"prompt",
multiline=True,
@ -693,7 +708,7 @@ class GrokVideoReferenceNode(IO.ComfyNode):
return IO.Schema(
node_id="GrokVideoReferenceNode",
display_name="Grok Reference-to-Video",
category="video/partner/Grok",
category="partner/video/Grok",
description="Generate video guided by reference images as style and content references.",
inputs=[
IO.String.Input(
@ -826,7 +841,7 @@ class GrokVideoExtendNode(IO.ComfyNode):
return IO.Schema(
node_id="GrokVideoExtendNode",
display_name="Grok Video Extend",
category="video/partner/Grok",
category="partner/video/Grok",
description="Extend an existing video with a seamless continuation based on a text prompt.",
inputs=[
IO.String.Input(

View File

@ -71,7 +71,7 @@ class HitPawGeneralImageEnhance(IO.ComfyNode):
return IO.Schema(
node_id="HitPawGeneralImageEnhance",
display_name="HitPaw General Image Enhance",
category="image/partner/HitPaw",
category="partner/image/HitPaw",
description="Upscale low-resolution images to super-resolution, eliminate artifacts and noise. "
f"Maximum output: {MAX_MP_GENERATIVE} megapixels.",
inputs=[
@ -201,7 +201,7 @@ class HitPawVideoEnhance(IO.ComfyNode):
return IO.Schema(
node_id="HitPawVideoEnhance",
display_name="HitPaw Video Enhance",
category="video/partner/HitPaw",
category="partner/video/HitPaw",
description="Upscale low-resolution videos to high resolution, eliminate artifacts and noise. "
"Prices shown are per second of video.",
inputs=[

View File

@ -123,7 +123,7 @@ class TencentTextToModelNode(IO.ComfyNode):
return IO.Schema(
node_id="TencentTextToModelNode",
display_name="Hunyuan3D: Text to Model",
category="3d/partner/Tencent",
category="partner/3d/Tencent",
essentials_category="3D",
inputs=[
IO.Combo.Input(
@ -242,7 +242,7 @@ class TencentImageToModelNode(IO.ComfyNode):
return IO.Schema(
node_id="TencentImageToModelNode",
display_name="Hunyuan3D: Image(s) to Model",
category="3d/partner/Tencent",
category="partner/3d/Tencent",
essentials_category="3D",
inputs=[
IO.Combo.Input(
@ -415,7 +415,7 @@ class TencentModelTo3DUVNode(IO.ComfyNode):
return IO.Schema(
node_id="TencentModelTo3DUVNode",
display_name="Hunyuan3D: Model to UV",
category="3d/partner/Tencent",
category="partner/3d/Tencent",
description="Perform UV unfolding on a 3D model to generate UV texture. "
"Input model must have less than 30000 faces.",
inputs=[
@ -505,7 +505,7 @@ class Tencent3DTextureEditNode(IO.ComfyNode):
return IO.Schema(
node_id="Tencent3DTextureEditNode",
display_name="Hunyuan3D: 3D Texture Edit",
category="3d/partner/Tencent",
category="partner/3d/Tencent",
description="After inputting the 3D model, perform 3D model texture redrawing.",
inputs=[
IO.MultiType.Input(
@ -594,7 +594,7 @@ class Tencent3DPartNode(IO.ComfyNode):
return IO.Schema(
node_id="Tencent3DPartNode",
display_name="Hunyuan3D: 3D Part",
category="3d/partner/Tencent",
category="partner/3d/Tencent",
description="Automatically perform component identification and generation based on the model structure.",
inputs=[
IO.MultiType.Input(
@ -666,7 +666,7 @@ class TencentSmartTopologyNode(IO.ComfyNode):
return IO.Schema(
node_id="TencentSmartTopologyNode",
display_name="Hunyuan3D: Smart Topology",
category="3d/partner/Tencent",
category="partner/3d/Tencent",
description="Perform smart retopology on a 3D model. "
"Supports GLB/OBJ formats; max 200MB; recommended for high-poly models.",
inputs=[

View File

@ -234,7 +234,7 @@ class IdeogramV1(IO.ComfyNode):
return IO.Schema(
node_id="IdeogramV1",
display_name="Ideogram V1",
category="image/partner/Ideogram",
category="partner/image/Ideogram",
description="Generates images using the Ideogram V1 model.",
inputs=[
IO.String.Input(
@ -360,7 +360,7 @@ class IdeogramV2(IO.ComfyNode):
return IO.Schema(
node_id="IdeogramV2",
display_name="Ideogram V2",
category="image/partner/Ideogram",
category="partner/image/Ideogram",
description="Generates images using the Ideogram V2 model.",
inputs=[
IO.String.Input(
@ -526,7 +526,7 @@ class IdeogramV3(IO.ComfyNode):
return IO.Schema(
node_id="IdeogramV3",
display_name="Ideogram V3",
category="image/partner/Ideogram",
category="partner/image/Ideogram",
description="Generates images using the Ideogram V3 model. "
"Supports both regular image generation from text prompts and image editing with mask.",
inputs=[

View File

@ -642,7 +642,7 @@ class KlingCameraControls(IO.ComfyNode):
return IO.Schema(
node_id="KlingCameraControls",
display_name="Kling Camera Controls",
category="video/partner/Kling",
category="partner/video/Kling",
description="Allows specifying configuration options for Kling Camera Controls and motion control effects.",
inputs=[
IO.Combo.Input("camera_control_type", options=KlingCameraControlType),
@ -762,7 +762,7 @@ class KlingTextToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingTextToVideoNode",
display_name="Kling Text to Video",
category="video/partner/Kling",
category="partner/video/Kling",
description="Kling Text to Video Node",
inputs=[
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"),
@ -849,7 +849,7 @@ class OmniProTextToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingOmniProTextToVideoNode",
display_name="Kling 3.0 Omni Text to Video",
category="video/partner/Kling",
category="partner/video/Kling",
description="Use text prompts to generate videos with the latest Kling model.",
inputs=[
IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),
@ -998,7 +998,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingOmniProFirstLastFrameNode",
display_name="Kling 3.0 Omni First-Last-Frame to Video",
category="video/partner/Kling",
category="partner/video/Kling",
description="Use a start frame, an optional end frame, or reference images with the latest Kling model.",
inputs=[
IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),
@ -1205,7 +1205,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingOmniProImageToVideoNode",
display_name="Kling 3.0 Omni Image to Video",
category="video/partner/Kling",
category="partner/video/Kling",
description="Use up to 7 reference images to generate a video with the latest Kling model.",
inputs=[
IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),
@ -1374,7 +1374,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingOmniProVideoToVideoNode",
display_name="Kling 3.0 Omni Video to Video",
category="video/partner/Kling",
category="partner/video/Kling",
description="Use a video and up to 4 reference images to generate a video with the latest Kling model.",
inputs=[
IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),
@ -1485,7 +1485,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingOmniProEditVideoNode",
display_name="Kling 3.0 Omni Edit Video",
category="video/partner/Kling",
category="partner/video/Kling",
essentials_category="Video Generation",
description="Edit an existing video with the latest model from Kling.",
inputs=[
@ -1593,7 +1593,7 @@ class OmniProImageNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingOmniProImageNode",
display_name="Kling 3.0 Omni Image",
category="image/partner/Kling",
category="partner/image/Kling",
description="Create or edit images with the latest model from Kling.",
inputs=[
IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-image-o1"]),
@ -1721,7 +1721,7 @@ class KlingCameraControlT2VNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingCameraControlT2VNode",
display_name="Kling Text to Video (Camera Control)",
category="video/partner/Kling",
category="partner/video/Kling",
description="Transform text into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original text.",
inputs=[
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"),
@ -1783,7 +1783,7 @@ class KlingImage2VideoNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingImage2VideoNode",
display_name="Kling Image(First Frame) to Video",
category="video/partner/Kling",
category="partner/video/Kling",
inputs=[
IO.Image.Input("start_frame", tooltip="The reference image used to generate the video."),
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"),
@ -1882,7 +1882,7 @@ class KlingCameraControlI2VNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingCameraControlI2VNode",
display_name="Kling Image to Video (Camera Control)",
category="video/partner/Kling",
category="partner/video/Kling",
description="Transform still images into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original image.",
inputs=[
IO.Image.Input(
@ -1953,7 +1953,7 @@ class KlingStartEndFrameNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingStartEndFrameNode",
display_name="Kling Start-End Frame to Video",
category="video/partner/Kling",
category="partner/video/Kling",
description="Generate a video sequence that transitions between your provided start and end images. The node creates all frames in between, producing a smooth transformation from the first frame to the last.",
inputs=[
IO.Image.Input(
@ -2047,7 +2047,7 @@ class KlingVideoExtendNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingVideoExtendNode",
display_name="Kling Video Extend",
category="video/partner/Kling",
category="partner/video/Kling",
description="Kling Video Extend Node. Extend videos made by other Kling nodes. The video_id is created by using other Kling Nodes.",
inputs=[
IO.String.Input(
@ -2128,7 +2128,7 @@ class KlingDualCharacterVideoEffectNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingDualCharacterVideoEffectNode",
display_name="Kling Dual Character Video Effects",
category="video/partner/Kling",
category="partner/video/Kling",
description="Achieve different special effects when generating a video based on the effect_scene. First image will be positioned on left side, second on right side of the composite.",
inputs=[
IO.Image.Input("image_left", tooltip="Left side image"),
@ -2218,7 +2218,7 @@ class KlingSingleImageVideoEffectNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingSingleImageVideoEffectNode",
display_name="Kling Video Effects",
category="video/partner/Kling",
category="partner/video/Kling",
description="Achieve different special effects when generating a video based on the effect_scene.",
inputs=[
IO.Image.Input(
@ -2291,7 +2291,7 @@ class KlingLipSyncAudioToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingLipSyncAudioToVideoNode",
display_name="Kling Lip Sync Video with Audio",
category="video/partner/Kling",
category="partner/video/Kling",
essentials_category="Video Generation",
description="Kling Lip Sync Audio to Video Node. Syncs mouth movements in a video file to the audio content of an audio file. When using, ensure that the audio contains clearly distinguishable vocals and that the video contains a distinct face. The audio file should not be larger than 5MB. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length.",
inputs=[
@ -2343,7 +2343,7 @@ class KlingLipSyncTextToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingLipSyncTextToVideoNode",
display_name="Kling Lip Sync Video with Text",
category="video/partner/Kling",
category="partner/video/Kling",
description="Kling Lip Sync Text to Video Node. Syncs mouth movements in a video file to a text prompt. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length.",
inputs=[
IO.Video.Input("video"),
@ -2411,7 +2411,7 @@ class KlingVirtualTryOnNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingVirtualTryOnNode",
display_name="Kling Virtual Try On",
category="image/partner/Kling",
category="partner/image/Kling",
description="Kling Virtual Try On Node. Input a human image and a cloth image to try on the cloth on the human. You can merge multiple clothing item pictures into one image with a white background.",
inputs=[
IO.Image.Input("human_image"),
@ -2478,7 +2478,7 @@ class KlingImageGenerationNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingImageGenerationNode",
display_name="Kling 3.0 Image",
category="image/partner/Kling",
category="partner/image/Kling",
description="Kling Image Generation Node. Generate an image from a text prompt with an optional reference image.",
inputs=[
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"),
@ -2615,7 +2615,7 @@ class TextToVideoWithAudio(IO.ComfyNode):
return IO.Schema(
node_id="KlingTextToVideoWithAudio",
display_name="Kling 2.6 Text to Video with Audio",
category="video/partner/Kling",
category="partner/video/Kling",
inputs=[
IO.Combo.Input("model_name", options=["kling-v2-6"]),
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt."),
@ -2683,7 +2683,7 @@ class ImageToVideoWithAudio(IO.ComfyNode):
return IO.Schema(
node_id="KlingImageToVideoWithAudio",
display_name="Kling 2.6 Image(First Frame) to Video with Audio",
category="video/partner/Kling",
category="partner/video/Kling",
inputs=[
IO.Combo.Input("model_name", options=["kling-v2-6"]),
IO.Image.Input("start_frame"),
@ -2753,7 +2753,7 @@ class MotionControl(IO.ComfyNode):
return IO.Schema(
node_id="KlingMotionControl",
display_name="Kling Motion Control",
category="video/partner/Kling",
category="partner/video/Kling",
inputs=[
IO.String.Input("prompt", multiline=True),
IO.Image.Input("reference_image"),
@ -2854,7 +2854,7 @@ class KlingVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingVideoNode",
display_name="Kling 3.0 Video",
category="video/partner/Kling",
category="partner/video/Kling",
description="Generate videos with Kling V3. "
"Supports text-to-video and image-to-video with optional storyboard multi-prompt and audio generation.",
inputs=[
@ -3077,7 +3077,7 @@ class KlingFirstLastFrameNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingFirstLastFrameNode",
display_name="Kling 3.0 First-Last-Frame to Video",
category="video/partner/Kling",
category="partner/video/Kling",
description="Generate videos with Kling V3 using first and last frames.",
inputs=[
IO.String.Input("prompt", multiline=True, default=""),
@ -3202,7 +3202,7 @@ class KlingAvatarNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingAvatarNode",
display_name="Kling Avatar 2.0",
category="video/partner/Kling",
category="partner/video/Kling",
description="Generate broadcast-style digital human videos from a single photo and an audio file.",
inputs=[
IO.Image.Input(

View File

@ -106,7 +106,7 @@ class Krea2ImageNode(IO.ComfyNode):
return IO.Schema(
node_id="Krea2ImageNode",
display_name="Krea 2 Image",
category="image/partner/Krea",
category="partner/image/Krea",
description=(
"Generate images via Krea 2 — pick Medium (expressive illustrations) or "
"Large (expressive photorealism). Supports an optional moodboard and up "
@ -229,7 +229,7 @@ class Krea2StyleReferenceNode(IO.ComfyNode):
return IO.Schema(
node_id="Krea2StyleReferenceNode",
display_name="Krea 2 Style Reference",
category="image/partner/Krea",
category="partner/image/Krea",
description=(
"Add an image style reference to a Krea 2 generation. Chain multiple Krea 2 "
"Style Reference nodes (max 10) and feed the final `style_reference` output "

View File

@ -50,7 +50,7 @@ class TextToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="LtxvApiTextToVideo",
display_name="LTXV Text To Video",
category="video/partner/LTXV",
category="partner/video/LTXV",
description="Professional-quality videos with customizable duration and resolution.",
inputs=[
IO.Combo.Input("model", options=list(MODELS_MAP.keys())),
@ -127,7 +127,7 @@ class ImageToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="LtxvApiImageToVideo",
display_name="LTXV Image To Video",
category="video/partner/LTXV",
category="partner/video/LTXV",
description="Professional-quality videos with customizable duration and resolution based on start image.",
inputs=[
IO.Image.Input("image", tooltip="First frame to be used for the video."),

View File

@ -46,7 +46,7 @@ class LumaReferenceNode(IO.ComfyNode):
return IO.Schema(
node_id="LumaReferenceNode",
display_name="Luma Reference",
category="image/partner/Luma",
category="partner/image/Luma",
description="Holds an image and weight for use with Luma Generate Image node.",
inputs=[
IO.Image.Input(
@ -85,7 +85,7 @@ class LumaConceptsNode(IO.ComfyNode):
return IO.Schema(
node_id="LumaConceptsNode",
display_name="Luma Concepts",
category="video/partner/Luma",
category="partner/video/Luma",
description="Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes.",
inputs=[
IO.Combo.Input(
@ -134,7 +134,7 @@ class LumaImageGenerationNode(IO.ComfyNode):
return IO.Schema(
node_id="LumaImageNode",
display_name="Luma Text to Image",
category="image/partner/Luma",
category="partner/image/Luma",
description="Generates images synchronously based on prompt and aspect ratio.",
inputs=[
IO.String.Input(
@ -278,7 +278,7 @@ class LumaImageModifyNode(IO.ComfyNode):
return IO.Schema(
node_id="LumaImageModifyNode",
display_name="Luma Image to Image",
category="image/partner/Luma",
category="partner/image/Luma",
description="Modifies images synchronously based on prompt and aspect ratio.",
inputs=[
IO.Image.Input(
@ -371,7 +371,7 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode):
return IO.Schema(
node_id="LumaVideoNode",
display_name="Luma Text to Video",
category="video/partner/Luma",
category="partner/video/Luma",
description="Generates videos synchronously based on prompt and output_size.",
inputs=[
IO.String.Input(
@ -472,7 +472,7 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode):
return IO.Schema(
node_id="LumaImageToVideoNode",
display_name="Luma Image to Video",
category="video/partner/Luma",
category="partner/video/Luma",
description="Generates videos synchronously based on prompt, input images, and output_size.",
inputs=[
IO.String.Input(
@ -724,7 +724,7 @@ class LumaImageNode(IO.ComfyNode):
return IO.Schema(
node_id="LumaImageNode2",
display_name="Luma UNI-1 Image",
category="image/partner/Luma",
category="partner/image/Luma",
description="Generate images from text using the Luma UNI-1 model.",
inputs=[
IO.String.Input(
@ -853,7 +853,7 @@ class LumaImageEditNode(IO.ComfyNode):
return IO.Schema(
node_id="LumaImageEditNode2",
display_name="Luma UNI-1 Image Edit",
category="image/partner/Luma",
category="partner/image/Luma",
description="Edit an existing image with a text prompt using the Luma UNI-1 model.",
inputs=[
IO.Image.Input(

View File

@ -61,7 +61,7 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
return IO.Schema(
node_id="MagnificImageUpscalerCreativeNode",
display_name="Magnific Image Upscale (Creative)",
category="image/partner/Magnific",
category="partner/image/Magnific",
description="Promptguided enhancement, stylization, and 2x/4x/8x/16x upscaling. "
"Maximum output: 25.3 megapixels.",
inputs=[
@ -240,7 +240,7 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
return IO.Schema(
node_id="MagnificImageUpscalerPreciseV2Node",
display_name="Magnific Image Upscale (Precise V2)",
category="image/partner/Magnific",
category="partner/image/Magnific",
description="High-fidelity upscaling with fine control over sharpness, grain, and detail. "
"Maximum output: 10060×10060 pixels.",
inputs=[
@ -400,7 +400,7 @@ class MagnificImageStyleTransferNode(IO.ComfyNode):
return IO.Schema(
node_id="MagnificImageStyleTransferNode",
display_name="Magnific Image Style Transfer",
category="image/partner/Magnific",
category="partner/image/Magnific",
description="Transfer the style from a reference image to your input image.",
inputs=[
IO.Image.Input("image", tooltip="The image to apply style transfer to."),
@ -549,7 +549,7 @@ class MagnificImageRelightNode(IO.ComfyNode):
return IO.Schema(
node_id="MagnificImageRelightNode",
display_name="Magnific Image Relight",
category="image/partner/Magnific",
category="partner/image/Magnific",
description="Relight an image with lighting adjustments and optional reference-based light transfer.",
inputs=[
IO.Image.Input("image", tooltip="The image to relight."),
@ -789,7 +789,7 @@ class MagnificImageSkinEnhancerNode(IO.ComfyNode):
return IO.Schema(
node_id="MagnificImageSkinEnhancerNode",
display_name="Magnific Image Skin Enhancer",
category="image/partner/Magnific",
category="partner/image/Magnific",
description="Skin enhancement for portraits with multiple processing modes.",
inputs=[
IO.Image.Input("image", tooltip="The portrait image to enhance."),

View File

@ -33,7 +33,7 @@ class MeshyTextToModelNode(IO.ComfyNode):
return IO.Schema(
node_id="MeshyTextToModelNode",
display_name="Meshy: Text to Model",
category="3d/partner/Meshy",
category="partner/3d/Meshy",
inputs=[
IO.Combo.Input("model", options=["latest"]),
IO.String.Input("prompt", multiline=True, default=""),
@ -145,7 +145,7 @@ class MeshyRefineNode(IO.ComfyNode):
return IO.Schema(
node_id="MeshyRefineNode",
display_name="Meshy: Refine Draft Model",
category="3d/partner/Meshy",
category="partner/3d/Meshy",
description="Refine a previously created draft model.",
inputs=[
IO.Combo.Input("model", options=["latest"]),
@ -240,7 +240,7 @@ class MeshyImageToModelNode(IO.ComfyNode):
return IO.Schema(
node_id="MeshyImageToModelNode",
display_name="Meshy: Image to Model",
category="3d/partner/Meshy",
category="partner/3d/Meshy",
inputs=[
IO.Combo.Input("model", options=["latest"]),
IO.Image.Input("image"),
@ -405,7 +405,7 @@ class MeshyMultiImageToModelNode(IO.ComfyNode):
return IO.Schema(
node_id="MeshyMultiImageToModelNode",
display_name="Meshy: Multi-Image to Model",
category="3d/partner/Meshy",
category="partner/3d/Meshy",
inputs=[
IO.Combo.Input("model", options=["latest"]),
IO.Autogrow.Input(
@ -575,7 +575,7 @@ class MeshyRigModelNode(IO.ComfyNode):
return IO.Schema(
node_id="MeshyRigModelNode",
display_name="Meshy: Rig Model",
category="3d/partner/Meshy",
category="partner/3d/Meshy",
description="Provides a rigged character in standard formats. "
"Auto-rigging is currently not suitable for untextured meshes, non-humanoid assets, "
"or humanoid assets with unclear limb and body structure.",
@ -656,7 +656,7 @@ class MeshyAnimateModelNode(IO.ComfyNode):
return IO.Schema(
node_id="MeshyAnimateModelNode",
display_name="Meshy: Animate Model",
category="3d/partner/Meshy",
category="partner/3d/Meshy",
description="Apply a specific animation action to a previously rigged character.",
inputs=[
IO.Custom("MESHY_RIGGED_TASK_ID").Input("rig_task_id"),
@ -722,7 +722,7 @@ class MeshyTextureNode(IO.ComfyNode):
return IO.Schema(
node_id="MeshyTextureNode",
display_name="Meshy: Texture Model",
category="3d/partner/Meshy",
category="partner/3d/Meshy",
inputs=[
IO.Combo.Input("model", options=["latest"]),
IO.Custom("MESHY_TASK_ID").Input("meshy_task_id"),

View File

@ -101,7 +101,7 @@ class MinimaxTextToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="MinimaxTextToVideoNode",
display_name="MiniMax Text to Video",
category="video/partner/MiniMax",
category="partner/video/MiniMax",
description="Generates videos synchronously based on a prompt, and optional parameters.",
inputs=[
IO.String.Input(
@ -163,7 +163,7 @@ class MinimaxImageToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="MinimaxImageToVideoNode",
display_name="MiniMax Image to Video",
category="video/partner/MiniMax",
category="partner/video/MiniMax",
description="Generates videos synchronously based on an image and prompt, and optional parameters.",
inputs=[
IO.Image.Input(
@ -230,7 +230,7 @@ class MinimaxSubjectToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="MinimaxSubjectToVideoNode",
display_name="MiniMax Subject to Video",
category="video/partner/MiniMax",
category="partner/video/MiniMax",
description="Generates videos synchronously based on an image and prompt, and optional parameters.",
inputs=[
IO.Image.Input(
@ -294,7 +294,7 @@ class MinimaxHailuoVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="MinimaxHailuoVideoNode",
display_name="MiniMax Hailuo Video",
category="video/partner/MiniMax",
category="partner/video/MiniMax",
description="Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model.",
inputs=[
IO.String.Input(

View File

@ -99,7 +99,7 @@ class OpenAIDalle2(IO.ComfyNode):
return IO.Schema(
node_id="OpenAIDalle2",
display_name="OpenAI DALL·E 2",
category="image/partner/OpenAI",
category="partner/image/OpenAI",
description="Generates images synchronously via OpenAI's DALL·E 2 endpoint.",
inputs=[
IO.String.Input(
@ -249,7 +249,7 @@ class OpenAIDalle3(IO.ComfyNode):
return IO.Schema(
node_id="OpenAIDalle3",
display_name="OpenAI DALL·E 3",
category="image/partner/OpenAI",
category="partner/image/OpenAI",
description="Generates images synchronously via OpenAI's DALL·E 3 endpoint.",
inputs=[
IO.String.Input(
@ -371,7 +371,7 @@ class OpenAIGPTImage1(IO.ComfyNode):
return IO.Schema(
node_id="OpenAIGPTImage1",
display_name="OpenAI GPT Image 2",
category="image/partner/OpenAI",
category="partner/image/OpenAI",
description="Generates images synchronously via OpenAI's GPT Image endpoint.",
is_deprecated=True,
inputs=[
@ -695,7 +695,7 @@ class OpenAIGPTImageNodeV2(IO.ComfyNode):
return IO.Schema(
node_id="OpenAIGPTImageNodeV2",
display_name="OpenAI GPT Image 2",
category="image/partner/OpenAI",
category="partner/image/OpenAI",
description="Generates images via OpenAI's GPT Image endpoint.",
inputs=[
IO.String.Input(
@ -962,7 +962,7 @@ class OpenAIChatNode(IO.ComfyNode):
return IO.Schema(
node_id="OpenAIChatNode",
display_name="OpenAI ChatGPT",
category="text/partner/OpenAI",
category="partner/text/OpenAI",
essentials_category="Text Generation",
description="Generate text responses from an OpenAI model.",
inputs=[
@ -1201,7 +1201,7 @@ class OpenAIInputFiles(IO.ComfyNode):
return IO.Schema(
node_id="OpenAIInputFiles",
display_name="OpenAI ChatGPT Input Files",
category="text/partner/OpenAI",
category="partner/text/OpenAI",
description="Loads and prepares input files (text, pdf, etc.) to include as inputs for the OpenAI Chat Node. The files will be read by the OpenAI model when generating a response. 🛈 TIP: Can be chained together with other OpenAI Input File nodes.",
inputs=[
IO.Combo.Input(
@ -1248,7 +1248,7 @@ class OpenAIChatConfig(IO.ComfyNode):
return IO.Schema(
node_id="OpenAIChatConfig",
display_name="OpenAI ChatGPT Advanced Options",
category="text/partner/OpenAI",
category="partner/text/OpenAI",
description="Allows specifying advanced configuration options for the OpenAI Chat Nodes.",
inputs=[
IO.Combo.Input(

View File

@ -265,7 +265,7 @@ class OpenRouterLLMNode(IO.ComfyNode):
return IO.Schema(
node_id="OpenRouterLLMNode",
display_name="OpenRouter LLM",
category="text/partner/OpenRouter",
category="partner/text/OpenRouter",
essentials_category="Text Generation",
description=(
"Generate text responses through OpenRouter. Routes to a curated set of popular "

View File

@ -53,7 +53,7 @@ class PixverseTemplateNode(IO.ComfyNode):
return IO.Schema(
node_id="PixverseTemplateNode",
display_name="PixVerse Template",
category="video/partner/PixVerse",
category="partner/video/PixVerse",
inputs=[
IO.Combo.Input("template", options=list(pixverse_templates.keys())),
],
@ -74,7 +74,7 @@ class PixverseTextToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="PixverseTextToVideoNode",
display_name="PixVerse Text to Video",
category="video/partner/PixVerse",
category="partner/video/PixVerse",
description="Generates videos based on prompt and output_size.",
inputs=[
IO.String.Input(
@ -192,7 +192,7 @@ class PixverseImageToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="PixverseImageToVideoNode",
display_name="PixVerse Image to Video",
category="video/partner/PixVerse",
category="partner/video/PixVerse",
description="Generates videos based on prompt and output_size.",
inputs=[
IO.Image.Input("image"),
@ -310,7 +310,7 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="PixverseTransitionVideoNode",
display_name="PixVerse Transition Video",
category="video/partner/PixVerse",
category="partner/video/PixVerse",
description="Generates videos based on prompt and output_size.",
inputs=[
IO.Image.Input("first_frame"),

View File

@ -62,7 +62,7 @@ class QuiverTextToSVGNode(IO.ComfyNode):
return IO.Schema(
node_id="QuiverTextToSVGNode",
display_name="Quiver Text to SVG",
category="image/partner/Quiver",
category="partner/image/Quiver",
description="Generate an SVG from a text prompt using Quiver AI.",
inputs=[
IO.String.Input(
@ -177,7 +177,7 @@ class QuiverImageToSVGNode(IO.ComfyNode):
return IO.Schema(
node_id="QuiverImageToSVGNode",
display_name="Quiver Image to SVG",
category="image/partner/Quiver",
category="partner/image/Quiver",
description="Vectorize a raster image into SVG using Quiver AI.",
inputs=[
IO.Image.Input(

View File

@ -178,7 +178,7 @@ class RecraftColorRGBNode(IO.ComfyNode):
return IO.Schema(
node_id="RecraftColorRGB",
display_name="Recraft Color RGB",
category="image/partner/Recraft",
category="partner/image/Recraft",
description="Create Recraft Color by choosing specific RGB values.",
inputs=[
IO.Int.Input("r", default=0, min=0, max=255, tooltip="Red value of color."),
@ -204,7 +204,7 @@ class RecraftControlsNode(IO.ComfyNode):
return IO.Schema(
node_id="RecraftControls",
display_name="Recraft Controls",
category="image/partner/Recraft",
category="partner/image/Recraft",
description="Create Recraft Controls for customizing Recraft generation.",
inputs=[
IO.Custom(RecraftIO.COLOR).Input("colors", optional=True),
@ -228,7 +228,7 @@ class RecraftStyleV3RealisticImageNode(IO.ComfyNode):
return IO.Schema(
node_id="RecraftStyleV3RealisticImage",
display_name="Recraft Style - Realistic Image",
category="image/partner/Recraft",
category="partner/image/Recraft",
description="Select realistic_image style and optional substyle.",
inputs=[
IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE)),
@ -253,7 +253,7 @@ class RecraftStyleV3DigitalIllustrationNode(RecraftStyleV3RealisticImageNode):
return IO.Schema(
node_id="RecraftStyleV3DigitalIllustration",
display_name="Recraft Style - Digital Illustration",
category="image/partner/Recraft",
category="partner/image/Recraft",
description="Select realistic_image style and optional substyle.",
inputs=[
IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE)),
@ -272,7 +272,7 @@ class RecraftStyleV3VectorIllustrationNode(RecraftStyleV3RealisticImageNode):
return IO.Schema(
node_id="RecraftStyleV3VectorIllustrationNode",
display_name="Recraft Style - Realistic Image",
category="image/partner/Recraft",
category="partner/image/Recraft",
description="Select realistic_image style and optional substyle.",
inputs=[
IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE)),
@ -291,7 +291,7 @@ class RecraftStyleV3LogoRasterNode(RecraftStyleV3RealisticImageNode):
return IO.Schema(
node_id="RecraftStyleV3LogoRaster",
display_name="Recraft Style - Logo Raster",
category="image/partner/Recraft",
category="partner/image/Recraft",
description="Select realistic_image style and optional substyle.",
inputs=[
IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE, include_none=False)),
@ -308,7 +308,7 @@ class RecraftStyleInfiniteStyleLibrary(IO.ComfyNode):
return IO.Schema(
node_id="RecraftStyleV3InfiniteStyleLibrary",
display_name="Recraft Style - Infinite Style Library",
category="image/partner/Recraft",
category="partner/image/Recraft",
description="Choose style based on preexisting UUID from Recraft's Infinite Style Library.",
inputs=[
IO.String.Input("style_id", default="", tooltip="UUID of style from Infinite Style Library."),
@ -331,7 +331,7 @@ class RecraftCreateStyleNode(IO.ComfyNode):
return IO.Schema(
node_id="RecraftCreateStyleNode",
display_name="Recraft Create Style",
category="image/partner/Recraft",
category="partner/image/Recraft",
description="Create a custom style from reference images. "
"Upload 1-5 images to use as style references. "
"Total size of all images is limited to 5 MB.",
@ -400,7 +400,7 @@ class RecraftTextToImageNode(IO.ComfyNode):
return IO.Schema(
node_id="RecraftTextToImageNode",
display_name="Recraft Text to Image",
category="image/partner/Recraft",
category="partner/image/Recraft",
description="Generates images synchronously based on prompt and resolution.",
inputs=[
IO.String.Input("prompt", multiline=True, default="", tooltip="Prompt for the image generation."),
@ -512,7 +512,7 @@ class RecraftImageToImageNode(IO.ComfyNode):
return IO.Schema(
node_id="RecraftImageToImageNode",
display_name="Recraft Image to Image",
category="image/partner/Recraft",
category="partner/image/Recraft",
description="Modify image based on prompt and strength.",
inputs=[
IO.Image.Input("image"),
@ -630,7 +630,7 @@ class RecraftImageInpaintingNode(IO.ComfyNode):
return IO.Schema(
node_id="RecraftImageInpaintingNode",
display_name="Recraft Image Inpainting",
category="image/partner/Recraft",
category="partner/image/Recraft",
description="Modify image based on prompt and mask.",
inputs=[
IO.Image.Input("image"),
@ -732,7 +732,7 @@ class RecraftTextToVectorNode(IO.ComfyNode):
return IO.Schema(
node_id="RecraftTextToVectorNode",
display_name="Recraft Text to Vector",
category="image/partner/Recraft",
category="partner/image/Recraft",
description="Generates SVG synchronously based on prompt and resolution.",
inputs=[
IO.String.Input("prompt", default="", tooltip="Prompt for the image generation.", multiline=True),
@ -832,7 +832,7 @@ class RecraftVectorizeImageNode(IO.ComfyNode):
return IO.Schema(
node_id="RecraftVectorizeImageNode",
display_name="Recraft Vectorize Image",
category="image/partner/Recraft",
category="partner/image/Recraft",
essentials_category="Image Tools",
description="Generates SVG synchronously from an input image.",
inputs=[
@ -876,7 +876,7 @@ class RecraftReplaceBackgroundNode(IO.ComfyNode):
return IO.Schema(
node_id="RecraftReplaceBackgroundNode",
display_name="Recraft Replace Background",
category="image/partner/Recraft",
category="partner/image/Recraft",
description="Replace background on image, based on provided prompt.",
inputs=[
IO.Image.Input("image"),
@ -963,7 +963,7 @@ class RecraftRemoveBackgroundNode(IO.ComfyNode):
return IO.Schema(
node_id="RecraftRemoveBackgroundNode",
display_name="Recraft Remove Background",
category="image/partner/Recraft",
category="partner/image/Recraft",
essentials_category="Image Tools",
description="Remove background from image, and return processed image and mask.",
inputs=[
@ -1012,7 +1012,7 @@ class RecraftCrispUpscaleNode(IO.ComfyNode):
return IO.Schema(
node_id="RecraftCrispUpscaleNode",
display_name="Recraft Crisp Upscale Image",
category="image/partner/Recraft",
category="partner/image/Recraft",
description="Upscale image synchronously.\n"
"Enhances a given raster image using crisp upscale tool, "
"increasing image resolution, making the image sharper and cleaner.",
@ -1058,7 +1058,7 @@ class RecraftCreativeUpscaleNode(RecraftCrispUpscaleNode):
return IO.Schema(
node_id="RecraftCreativeUpscaleNode",
display_name="Recraft Creative Upscale Image",
category="image/partner/Recraft",
category="partner/image/Recraft",
description="Upscale image synchronously.\n"
"Enhances a given raster image using creative upscale tool, "
"boosting resolution with a focus on refining small details and faces.",
@ -1086,7 +1086,7 @@ class RecraftV4TextToImageNode(IO.ComfyNode):
return IO.Schema(
node_id="RecraftV4TextToImageNode",
display_name="Recraft V4 Text to Image",
category="image/partner/Recraft",
category="partner/image/Recraft",
description="Generates images using Recraft V4 or V4 Pro models.",
inputs=[
IO.String.Input(
@ -1210,7 +1210,7 @@ class RecraftV4TextToVectorNode(IO.ComfyNode):
return IO.Schema(
node_id="RecraftV4TextToVectorNode",
display_name="Recraft V4 Text to Vector",
category="image/partner/Recraft",
category="partner/image/Recraft",
description="Generates SVG using Recraft V4 or V4 Pro models.",
inputs=[
IO.String.Input(

View File

@ -109,7 +109,7 @@ class ReveImageCreateNode(IO.ComfyNode):
return IO.Schema(
node_id="ReveImageCreateNode",
display_name="Reve Image Create",
category="image/partner/Reve",
category="partner/image/Reve",
description="Generate images from text descriptions using Reve.",
inputs=[
IO.String.Input(
@ -200,7 +200,7 @@ class ReveImageEditNode(IO.ComfyNode):
return IO.Schema(
node_id="ReveImageEditNode",
display_name="Reve Image Edit",
category="image/partner/Reve",
category="partner/image/Reve",
description="Edit images using natural language instructions with Reve.",
inputs=[
IO.Image.Input("image", tooltip="The image to edit."),
@ -300,7 +300,7 @@ class ReveImageRemixNode(IO.ComfyNode):
return IO.Schema(
node_id="ReveImageRemixNode",
display_name="Reve Image Remix",
category="image/partner/Reve",
category="partner/image/Reve",
description="Combine reference images with text prompts to create new images using Reve.",
inputs=[
IO.Autogrow.Input(

View File

@ -230,7 +230,7 @@ class Rodin3D_Regular(IO.ComfyNode):
return IO.Schema(
node_id="Rodin3D_Regular",
display_name="Rodin 3D Generate - Regular Generate",
category="3d/partner/Rodin",
category="partner/3d/Rodin",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Image.Input("Images"),
@ -289,7 +289,7 @@ class Rodin3D_Detail(IO.ComfyNode):
return IO.Schema(
node_id="Rodin3D_Detail",
display_name="Rodin 3D Generate - Detail Generate",
category="3d/partner/Rodin",
category="partner/3d/Rodin",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Image.Input("Images"),
@ -348,7 +348,7 @@ class Rodin3D_Smooth(IO.ComfyNode):
return IO.Schema(
node_id="Rodin3D_Smooth",
display_name="Rodin 3D Generate - Smooth Generate",
category="3d/partner/Rodin",
category="partner/3d/Rodin",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Image.Input("Images"),
@ -406,7 +406,7 @@ class Rodin3D_Sketch(IO.ComfyNode):
return IO.Schema(
node_id="Rodin3D_Sketch",
display_name="Rodin 3D Generate - Sketch Generate",
category="3d/partner/Rodin",
category="partner/3d/Rodin",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Image.Input("Images"),
@ -468,7 +468,7 @@ class Rodin3D_Gen2(IO.ComfyNode):
return IO.Schema(
node_id="Rodin3D_Gen2",
display_name="Rodin 3D Generate - Gen-2 Generate",
category="3d/partner/Rodin",
category="partner/3d/Rodin",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Image.Input("Images"),
@ -941,7 +941,7 @@ class Rodin3D_Gen25_Image(IO.ComfyNode):
return IO.Schema(
node_id="Rodin3D_Gen25_Image",
display_name="Rodin 3D Gen-2.5 - Image to 3D",
category="3d/partner/Rodin",
category="partner/3d/Rodin",
description=(
"Generate a 3D model from 1-5 reference images via Rodin Gen-2.5. "
"Pick a mode (Fast / Regular / Extreme-High) to tune quality vs. cost."
@ -1035,7 +1035,7 @@ class Rodin3D_Gen25_Text(IO.ComfyNode):
return IO.Schema(
node_id="Rodin3D_Gen25_Text",
display_name="Rodin 3D Gen-2.5 - Text to 3D",
category="3d/partner/Rodin",
category="partner/3d/Rodin",
description=(
"Generate a 3D model from a text prompt via Rodin Gen-2.5. "
"Pick a mode (Fast / Regular / Extreme-High) to tune quality vs. cost."

View File

@ -140,7 +140,7 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode):
return IO.Schema(
node_id="RunwayImageToVideoNodeGen3a",
display_name="Runway Image to Video (Gen3a Turbo)",
category="video/partner/Runway",
category="partner/video/Runway",
description="Generate a video from a single starting frame using Gen3a Turbo model. "
"Before diving in, review these best practices to ensure that "
"your input selections will set your generation up for success: "
@ -234,7 +234,7 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode):
return IO.Schema(
node_id="RunwayImageToVideoNodeGen4",
display_name="Runway Image to Video (Gen4 Turbo)",
category="video/partner/Runway",
category="partner/video/Runway",
description="Generate a video from a single starting frame using Gen4 Turbo model. "
"Before diving in, review these best practices to ensure that "
"your input selections will set your generation up for success: "
@ -329,7 +329,7 @@ class RunwayFirstLastFrameNode(IO.ComfyNode):
return IO.Schema(
node_id="RunwayFirstLastFrameNode",
display_name="Runway First-Last-Frame to Video",
category="video/partner/Runway",
category="partner/video/Runway",
description="Upload first and last keyframes, draft a prompt, and generate a video. "
"More complex transitions, such as cases where the Last frame is completely different "
"from the First frame, may benefit from the longer 10s duration. "
@ -440,7 +440,7 @@ class RunwayTextToImageNode(IO.ComfyNode):
return IO.Schema(
node_id="RunwayTextToImageNode",
display_name="Runway Text to Image",
category="image/partner/Runway",
category="partner/image/Runway",
description="Generate an image from a text prompt using Runway's Gen 4 model. "
"You can also include reference image to guide the generation.",
inputs=[

View File

@ -34,7 +34,7 @@ class SoniloVideoToMusic(IO.ComfyNode):
return IO.Schema(
node_id="SoniloVideoToMusic",
display_name="Sonilo Video to Music",
category="audio/partner/Sonilo",
category="partner/audio/Sonilo",
description="Generate music from video content using Sonilo's AI model. "
"Analyzes the video and creates matching music.",
inputs=[
@ -99,7 +99,7 @@ class SoniloTextToMusic(IO.ComfyNode):
return IO.Schema(
node_id="SoniloTextToMusic",
display_name="Sonilo Text to Music",
category="audio/partner/Sonilo",
category="partner/audio/Sonilo",
description="Generate music from a text prompt using Sonilo's AI model. "
"Leave duration at 0 to let the model infer it from the prompt.",
inputs=[

View File

@ -34,7 +34,7 @@ class OpenAIVideoSora2(IO.ComfyNode):
return IO.Schema(
node_id="OpenAIVideoSora2",
display_name="OpenAI Sora - Video (DEPRECATED)",
category="video/partner/Sora",
category="partner/video/Sora",
description=(
"OpenAI video and audio generation.\n\n"
"DEPRECATION NOTICE: OpenAI will stop serving the Sora v2 API in September 2026. "

View File

@ -62,7 +62,7 @@ class StabilityStableImageUltraNode(IO.ComfyNode):
return IO.Schema(
node_id="StabilityStableImageUltraNode",
display_name="Stability AI Stable Image Ultra",
category="image/partner/Stability AI",
category="partner/image/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.String.Input(
@ -197,7 +197,7 @@ class StabilityStableImageSD_3_5Node(IO.ComfyNode):
return IO.Schema(
node_id="StabilityStableImageSD_3_5Node",
display_name="Stability AI Stable Diffusion 3.5 Image",
category="image/partner/Stability AI",
category="partner/image/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.String.Input(
@ -354,7 +354,7 @@ class StabilityUpscaleConservativeNode(IO.ComfyNode):
return IO.Schema(
node_id="StabilityUpscaleConservativeNode",
display_name="Stability AI Upscale Conservative",
category="image/partner/Stability AI",
category="partner/image/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Image.Input("image"),
@ -457,7 +457,7 @@ class StabilityUpscaleCreativeNode(IO.ComfyNode):
return IO.Schema(
node_id="StabilityUpscaleCreativeNode",
display_name="Stability AI Upscale Creative",
category="image/partner/Stability AI",
category="partner/image/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Image.Input("image"),
@ -578,7 +578,7 @@ class StabilityUpscaleFastNode(IO.ComfyNode):
return IO.Schema(
node_id="StabilityUpscaleFastNode",
display_name="Stability AI Upscale Fast",
category="image/partner/Stability AI",
category="partner/image/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Image.Input("image"),
@ -630,7 +630,7 @@ class StabilityTextToAudio(IO.ComfyNode):
return IO.Schema(
node_id="StabilityTextToAudio",
display_name="Stability AI Text To Audio",
category="audio/partner/Stability AI",
category="partner/audio/Stability AI",
essentials_category="Audio",
description=cleandoc(cls.__doc__ or ""),
inputs=[
@ -708,7 +708,7 @@ class StabilityAudioToAudio(IO.ComfyNode):
return IO.Schema(
node_id="StabilityAudioToAudio",
display_name="Stability AI Audio To Audio",
category="audio/partner/Stability AI",
category="partner/audio/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Combo.Input(
@ -802,7 +802,7 @@ class StabilityAudioInpaint(IO.ComfyNode):
return IO.Schema(
node_id="StabilityAudioInpaint",
display_name="Stability AI Audio Inpaint",
category="audio/partner/Stability AI",
category="partner/audio/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Combo.Input(

View File

@ -52,7 +52,7 @@ class TopazImageEnhance(IO.ComfyNode):
return IO.Schema(
node_id="TopazImageEnhance",
display_name="Topaz Image Enhance",
category="image/partner/Topaz",
category="partner/image/Topaz",
description="Industry-standard upscaling and image enhancement.",
inputs=[
IO.Combo.Input("model", options=["Reimagine"]),
@ -235,7 +235,7 @@ class TopazVideoEnhance(IO.ComfyNode):
return IO.Schema(
node_id="TopazVideoEnhance",
display_name="Topaz Video Enhance (Legacy)",
category="video/partner/Topaz",
category="partner/video/Topaz",
description="Breathe new life into video with powerful upscaling and recovery technology.",
inputs=[
IO.Video.Input("video"),
@ -475,7 +475,7 @@ class TopazVideoEnhanceV2(IO.ComfyNode):
return IO.Schema(
node_id="TopazVideoEnhanceV2",
display_name="Topaz Video Enhance",
category="video/partner/Topaz",
category="partner/video/Topaz",
description="Breathe new life into video with powerful upscaling and recovery technology.",
inputs=[
IO.Video.Input("video"),

View File

@ -11,6 +11,9 @@ from comfy_api_nodes.apis.tripo import (
TripoModelVersion,
TripoMultiviewToModelRequest,
TripoOrientation,
TripoP1ImageToModelRequest,
TripoP1MultiviewToModelRequest,
TripoP1TextToModelRequest,
TripoRefineModelRequest,
TripoStyle,
TripoTaskResponse,
@ -80,7 +83,7 @@ class TripoTextToModelNode(IO.ComfyNode):
return IO.Schema(
node_id="TripoTextToModelNode",
display_name="Tripo: Text to Model",
category="3d/partner/Tripo",
category="partner/3d/Tripo",
inputs=[
IO.String.Input("prompt", multiline=True),
IO.String.Input("negative_prompt", multiline=True, optional=True),
@ -93,10 +96,22 @@ class TripoTextToModelNode(IO.ComfyNode):
IO.Int.Input("image_seed", default=42, optional=True, advanced=True),
IO.Int.Input("model_seed", default=42, optional=True, advanced=True),
IO.Int.Input("texture_seed", default=42, optional=True, advanced=True),
IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True, advanced=True),
IO.Combo.Input(
"texture_quality",
default="standard",
options=["standard", "detailed"],
optional=True,
advanced=True,
),
IO.Int.Input("face_limit", default=-1, min=-1, max=2000000, optional=True, advanced=True),
IO.Boolean.Input("quad", default=False, optional=True, advanced=True),
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True, advanced=True),
IO.Combo.Input(
"geometry_quality",
default="standard",
options=["standard", "detailed"],
optional=True,
advanced=True,
),
],
outputs=[
IO.String.Output(display_name="model_file"), # for backward compatibility only
@ -195,7 +210,7 @@ class TripoImageToModelNode(IO.ComfyNode):
return IO.Schema(
node_id="TripoImageToModelNode",
display_name="Tripo: Image to Model",
category="3d/partner/Tripo",
category="partner/3d/Tripo",
inputs=[
IO.Image.Input("image"),
IO.Combo.Input(
@ -209,16 +224,36 @@ class TripoImageToModelNode(IO.ComfyNode):
IO.Boolean.Input("pbr", default=True, optional=True),
IO.Int.Input("model_seed", default=42, optional=True, advanced=True),
IO.Combo.Input(
"orientation", options=TripoOrientation, default=TripoOrientation.DEFAULT, optional=True, advanced=True
"orientation",
options=TripoOrientation,
default=TripoOrientation.DEFAULT,
optional=True,
advanced=True,
),
IO.Int.Input("texture_seed", default=42, optional=True, advanced=True),
IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True, advanced=True),
IO.Combo.Input(
"texture_alignment", default="original_image", options=["original_image", "geometry"], optional=True, advanced=True
"texture_quality",
default="standard",
options=["standard", "detailed"],
optional=True,
advanced=True,
),
IO.Combo.Input(
"texture_alignment",
default="original_image",
options=["original_image", "geometry"],
optional=True,
advanced=True,
),
IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True, advanced=True),
IO.Boolean.Input("quad", default=False, optional=True, advanced=True),
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True, advanced=True),
IO.Combo.Input(
"geometry_quality",
default="standard",
options=["standard", "detailed"],
optional=True,
advanced=True,
),
],
outputs=[
IO.String.Output(display_name="model_file"), # for backward compatibility only
@ -323,7 +358,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
return IO.Schema(
node_id="TripoMultiviewToModelNode",
display_name="Tripo: Multiview to Model",
category="3d/partner/Tripo",
category="partner/3d/Tripo",
inputs=[
IO.Image.Input("image"),
IO.Image.Input("image_left", optional=True),
@ -346,13 +381,35 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
IO.Boolean.Input("pbr", default=True, optional=True),
IO.Int.Input("model_seed", default=42, optional=True, advanced=True),
IO.Int.Input("texture_seed", default=42, optional=True, advanced=True),
IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True, advanced=True),
IO.Combo.Input(
"texture_alignment", default="original_image", options=["original_image", "geometry"], optional=True, advanced=True
"texture_quality",
default="standard",
options=["standard", "detailed"],
optional=True,
advanced=True,
),
IO.Combo.Input(
"texture_alignment",
default="original_image",
options=["original_image", "geometry"],
optional=True,
advanced=True,
),
IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True, advanced=True),
IO.Boolean.Input("quad", default=False, optional=True, advanced=True, tooltip="This parameter is deprecated and does nothing."),
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True, advanced=True),
IO.Boolean.Input(
"quad",
default=False,
optional=True,
advanced=True,
tooltip="This parameter is deprecated and does nothing.",
),
IO.Combo.Input(
"geometry_quality",
default="standard",
options=["standard", "detailed"],
optional=True,
advanced=True,
),
],
outputs=[
IO.String.Output(display_name="model_file"), # for backward compatibility only
@ -461,15 +518,25 @@ class TripoTextureNode(IO.ComfyNode):
return IO.Schema(
node_id="TripoTextureNode",
display_name="Tripo: Texture model",
category="3d/partner/Tripo",
category="partner/3d/Tripo",
inputs=[
IO.Custom("MODEL_TASK_ID").Input("model_task_id"),
IO.Boolean.Input("texture", default=True, optional=True),
IO.Boolean.Input("pbr", default=True, optional=True),
IO.Int.Input("texture_seed", default=42, optional=True, advanced=True),
IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True, advanced=True),
IO.Combo.Input(
"texture_alignment", default="original_image", options=["original_image", "geometry"], optional=True, advanced=True
"texture_quality",
default="standard",
options=["standard", "detailed"],
optional=True,
advanced=True,
),
IO.Combo.Input(
"texture_alignment",
default="original_image",
options=["original_image", "geometry"],
optional=True,
advanced=True,
),
],
outputs=[
@ -528,7 +595,7 @@ class TripoRefineNode(IO.ComfyNode):
return IO.Schema(
node_id="TripoRefineNode",
display_name="Tripo: Refine Draft model",
category="3d/partner/Tripo",
category="partner/3d/Tripo",
description="Refine a draft model created by v1.4 Tripo models only.",
inputs=[
IO.Custom("MODEL_TASK_ID").Input("model_task_id", tooltip="Must be a v1.4 Tripo model"),
@ -568,7 +635,7 @@ class TripoRigNode(IO.ComfyNode):
return IO.Schema(
node_id="TripoRigNode",
display_name="Tripo: Rig model",
category="3d/partner/Tripo",
category="partner/3d/Tripo",
inputs=[IO.Custom("MODEL_TASK_ID").Input("original_model_task_id")],
outputs=[
IO.String.Output(display_name="model_file"), # for backward compatibility only
@ -605,7 +672,7 @@ class TripoRetargetNode(IO.ComfyNode):
return IO.Schema(
node_id="TripoRetargetNode",
display_name="Tripo: Retarget rigged model",
category="3d/partner/Tripo",
category="partner/3d/Tripo",
inputs=[
IO.Custom("RIG_TASK_ID").Input("original_model_task_id"),
IO.Combo.Input(
@ -626,7 +693,7 @@ class TripoRetargetNode(IO.ComfyNode):
"preset:hexapod:walk",
"preset:octopod:walk",
"preset:serpentine:march",
"preset:aquatic:march"
"preset:aquatic:march",
],
),
],
@ -670,7 +737,7 @@ class TripoConversionNode(IO.ComfyNode):
return IO.Schema(
node_id="TripoConversionNode",
display_name="Tripo: Convert model",
category="3d/partner/Tripo",
category="partner/3d/Tripo",
inputs=[
IO.Custom("MODEL_TASK_ID,RIG_TASK_ID,RETARGET_TASK_ID").Input("original_model_task_id"),
IO.Combo.Input("format", options=["GLTF", "USDZ", "FBX", "OBJ", "STL", "3MF"]),
@ -817,7 +884,7 @@ class TripoConversionNode(IO.ComfyNode):
# Parse part_names from comma-separated string to list
part_names_list = None
if part_names and part_names.strip():
part_names_list = [name.strip() for name in part_names.split(',') if name.strip()]
part_names_list = [name.strip() for name in part_names.split(",") if name.strip()]
response = await sync_op(
cls,
@ -848,6 +915,373 @@ class TripoConversionNode(IO.ComfyNode):
return await poll_until_finished(cls, response, average_duration=30)
def _p1_price_expr(*, geometry_credits: int, textured_credits: int, detailed_credits: int) -> str:
return (
"("
" $mode := widgets.output_mode;"
' $detailed := $lookup(widgets, "output_mode.texture_quality") = "detailed";'
f' $credits := $mode = "geometry only" ? {geometry_credits} : ($detailed ? {detailed_credits} : {textured_credits});'
' {"type":"usd","usd": $credits * 0.01, "format": {"approximate": true}}'
")"
)
def _p1_textured_inputs(*, include_image_alignment: bool) -> list:
"""Inputs shown inside the 'Textured' branch of the P1 output_mode DynamicCombo."""
inputs: list = [
IO.Boolean.Input("pbr", default=True, tooltip="Include PBR maps. When on, base texture is forced on too."),
IO.Combo.Input("texture_quality", options=["standard", "detailed"], default="standard"),
]
if include_image_alignment:
inputs.extend(
[
IO.Combo.Input(
"texture_alignment",
options=["original_image", "geometry"],
default="original_image",
tooltip="Prioritize visual fidelity to the source image, or alignment to the mesh geometry.",
),
IO.Combo.Input(
"orientation",
options=["default", "align_image"],
default="default",
tooltip="Rotate the output to match the source image. Only applies when textured.",
),
]
)
inputs.append(IO.Int.Input("texture_seed", default=42, advanced=True))
return inputs
def _build_p1_output_mode(*, include_image_alignment: bool) -> IO.DynamicCombo.Input:
return IO.DynamicCombo.Input(
"output_mode",
options=[
IO.DynamicCombo.Option("Geometry only", []),
IO.DynamicCombo.Option("Textured", _p1_textured_inputs(include_image_alignment=include_image_alignment)),
],
tooltip='"Geometry only" returns an untextured mesh. "Textured" adds color/PBR maps.',
)
def _resolve_p1_texture_fields(output_mode: dict) -> dict:
"""Translate the output_mode DynamicCombo payload into P1 request fields.
pbr=true forces texture=true server-side, but we send both explicitly so the
intent is visible in the request body and logs.
"""
mode = output_mode["output_mode"]
if mode == "Geometry only":
return {"texture": False, "pbr": False}
out = {
"texture": True,
"pbr": bool(output_mode.get("pbr", True)),
"texture_quality": output_mode.get("texture_quality", "standard"),
"texture_seed": output_mode.get("texture_seed"),
}
if "texture_alignment" in output_mode:
out["texture_alignment"] = output_mode["texture_alignment"]
if "orientation" in output_mode:
out["orientation"] = output_mode["orientation"]
return out
def _p1_common_inputs() -> list:
"""Inputs shared by all P1 nodes (placed after output_mode)."""
return [
IO.Int.Input(
"face_limit",
default=-1,
min=-1,
max=20000,
optional=True,
advanced=True,
tooltip="Target face count, 48-20000. -1 lets Tripo pick adaptively.",
),
IO.Int.Input("model_seed", default=42, optional=True, advanced=True),
IO.Boolean.Input(
"auto_size",
default=False,
optional=True,
advanced=True,
tooltip="Scale the output to approximate real-world meters.",
),
IO.Boolean.Input(
"export_uv",
default=True,
optional=True,
advanced=True,
tooltip="UV unwrap during generation. Turn off for faster geometry-only runs.",
),
IO.Boolean.Input(
"compress_geometry",
default=False,
optional=True,
advanced=True,
tooltip="Apply geometry-based compression. Decompress before editing.",
),
]
def _build_p1_request_kwargs(
*,
output_mode: dict,
face_limit: int,
model_seed: int,
auto_size: bool,
export_uv: bool,
compress_geometry: bool,
) -> dict:
"""Common P1 request fields shared by all three node types."""
kwargs: dict = {
"model_seed": model_seed,
"face_limit": face_limit if face_limit != -1 else None,
"auto_size": auto_size,
"export_uv": export_uv,
"compress": "geometry" if compress_geometry else None,
}
kwargs.update(_resolve_p1_texture_fields(output_mode))
return kwargs
class TripoP1TextToModelNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="TripoP1TextToModelNode",
display_name="Tripo P1: Text to Model",
category="partner/3d/Tripo",
description="Tripo P1 text-to-3D. Optimized for low-poly, game-ready meshes with stable topology.",
inputs=[
IO.String.Input("prompt", multiline=True, tooltip="Up to 1024 characters."),
IO.String.Input("negative_prompt", multiline=True, optional=True, tooltip="Up to 255 characters."),
_build_p1_output_mode(include_image_alignment=False),
IO.Int.Input("image_seed", default=42, optional=True, advanced=True),
*_p1_common_inputs(),
],
outputs=[
IO.String.Output(display_name="model_file"), # for backward compatibility only
IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"),
IO.File3DGLB.Output(display_name="GLB"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["output_mode", "output_mode.texture_quality"]),
expr=_p1_price_expr(geometry_credits=30, textured_credits=40, detailed_credits=50),
),
)
@classmethod
async def execute(
cls,
prompt: str,
output_mode: dict,
negative_prompt: str | None = None,
image_seed: int | None = None,
face_limit: int = -1,
model_seed: int | None = None,
auto_size: bool = False,
export_uv: bool = True,
compress_geometry: bool = False,
) -> IO.NodeOutput:
if not prompt:
raise RuntimeError("Prompt is required")
common = _build_p1_request_kwargs(
output_mode=output_mode,
face_limit=face_limit,
model_seed=model_seed,
auto_size=auto_size,
export_uv=export_uv,
compress_geometry=compress_geometry,
)
request = TripoP1TextToModelRequest(
prompt=prompt,
negative_prompt=negative_prompt or None,
image_seed=image_seed,
**common,
)
response = await sync_op(
cls,
endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"),
response_model=TripoTaskResponse,
data=request,
)
return await poll_until_finished(cls, response, average_duration=60)
class TripoP1ImageToModelNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="TripoP1ImageToModelNode",
display_name="Tripo P1: Image to Model",
category="partner/3d/Tripo",
description="Tripo P1 image-to-3D. Optimized for low-poly, game-ready meshes.",
inputs=[
IO.Image.Input("image"),
_build_p1_output_mode(include_image_alignment=True),
IO.Boolean.Input(
"enable_image_autofix",
default=False,
optional=True,
advanced=True,
tooltip="Pre-process the input image for better generation quality.",
),
*_p1_common_inputs(),
],
outputs=[
IO.String.Output(display_name="model_file"), # for backward compatibility only
IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"),
IO.File3DGLB.Output(display_name="GLB"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["output_mode", "output_mode.texture_quality"]),
expr=_p1_price_expr(geometry_credits=40, textured_credits=50, detailed_credits=60),
),
)
@classmethod
async def execute(
cls,
image: Input.Image,
output_mode: dict,
enable_image_autofix: bool = False,
face_limit: int = -1,
model_seed: int | None = None,
auto_size: bool = False,
export_uv: bool = True,
compress_geometry: bool = False,
) -> IO.NodeOutput:
if image is None:
raise RuntimeError("Image is required")
tripo_file = TripoFileReference(
root=TripoUrlReference(
url=(await upload_images_to_comfyapi(cls, image, max_images=1))[0],
type="jpeg",
)
)
common = _build_p1_request_kwargs(
output_mode=output_mode,
face_limit=face_limit,
model_seed=model_seed,
auto_size=auto_size,
export_uv=export_uv,
compress_geometry=compress_geometry,
)
request = TripoP1ImageToModelRequest(
file=tripo_file,
enable_image_autofix=enable_image_autofix,
**common,
)
response = await sync_op(
cls,
endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"),
response_model=TripoTaskResponse,
data=request,
)
return await poll_until_finished(cls, response, average_duration=60)
class TripoP1MultiviewToModelNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="TripoP1MultiviewToModelNode",
display_name="Tripo P1: Multiview to Model",
category="partner/3d/Tripo",
description="Tripo P1 multiview-to-3D from 2-4 reference images in [front, left, back, right] order. "
"Front is required; any combination of the other three may be omitted.",
inputs=[
IO.Image.Input("image", tooltip="Front view (0°). Required."),
IO.Image.Input(
"image_left",
optional=True,
tooltip="Left view (90°), i.e. the subject's left side.",
),
IO.Image.Input("image_back", optional=True, tooltip="Back view (180°)."),
IO.Image.Input(
"image_right",
optional=True,
tooltip="Right view (270°), i.e. the subject's right side.",
),
_build_p1_output_mode(include_image_alignment=True),
*_p1_common_inputs(),
],
outputs=[
IO.String.Output(display_name="model_file"), # for backward compatibility only
IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"),
IO.File3DGLB.Output(display_name="GLB"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["output_mode", "output_mode.texture_quality"]),
expr=_p1_price_expr(geometry_credits=40, textured_credits=50, detailed_credits=60),
),
)
@classmethod
async def execute(
cls,
image: Input.Image,
output_mode: dict,
image_left: Input.Image | None = None,
image_back: Input.Image | None = None,
image_right: Input.Image | None = None,
face_limit: int = -1,
model_seed: int | None = None,
auto_size: bool = False,
export_uv: bool = True,
compress_geometry: bool = False,
) -> IO.NodeOutput:
views = [image, image_left, image_back, image_right]
if sum(1 for v in views if v is not None) < 2:
raise RuntimeError("Tripo P1 multiview requires at least 2 images (front plus one of left/back/right).")
files: list[TripoFileReference] = []
for view in views:
if view is None:
files.append(TripoFileReference(root=TripoFileEmptyReference()))
continue
url = (await upload_images_to_comfyapi(cls, view, max_images=1))[0]
files.append(TripoFileReference(root=TripoUrlReference(url=url, type="jpeg")))
common = _build_p1_request_kwargs(
output_mode=output_mode,
face_limit=face_limit,
model_seed=model_seed,
auto_size=auto_size,
export_uv=export_uv,
compress_geometry=compress_geometry,
)
request = TripoP1MultiviewToModelRequest(files=files, **common)
response = await sync_op(
cls,
endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"),
response_model=TripoTaskResponse,
data=request,
)
return await poll_until_finished(cls, response, average_duration=80)
class TripoExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@ -855,6 +1289,9 @@ class TripoExtension(ComfyExtension):
TripoTextToModelNode,
TripoImageToModelNode,
TripoMultiviewToModelNode,
TripoP1TextToModelNode,
TripoP1ImageToModelNode,
TripoP1MultiviewToModelNode,
TripoTextureNode,
TripoRefineNode,
TripoRigNode,

View File

@ -45,7 +45,7 @@ class VeoVideoGenerationNode(IO.ComfyNode):
return IO.Schema(
node_id="VeoVideoGenerationNode",
display_name="Google Veo 2 Video Generation",
category="video/partner/Veo",
category="partner/video/Veo",
description="Generates videos from text prompts using Google's Veo 2 API",
inputs=[
IO.String.Input(
@ -256,7 +256,7 @@ class Veo3VideoGenerationNode(IO.ComfyNode):
return IO.Schema(
node_id="Veo3VideoGenerationNode",
display_name="Google Veo 3 Video Generation",
category="video/partner/Veo",
category="partner/video/Veo",
description="Generates videos from text prompts using Google's Veo 3 API",
inputs=[
IO.String.Input(
@ -468,7 +468,7 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
return IO.Schema(
node_id="Veo3FirstLastFrameNode",
display_name="Google Veo 3 First-Last-Frame to Video",
category="video/partner/Veo",
category="partner/video/Veo",
description="Generate video using prompt and first and last frames.",
inputs=[
IO.String.Input(

View File

@ -71,7 +71,7 @@ class ViduTextToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="ViduTextToVideoNode",
display_name="Vidu Text To Video Generation",
category="video/partner/Vidu",
category="partner/video/Vidu",
description="Generate video from a text prompt",
inputs=[
IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"),
@ -169,7 +169,7 @@ class ViduImageToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="ViduImageToVideoNode",
display_name="Vidu Image To Video Generation",
category="video/partner/Vidu",
category="partner/video/Vidu",
description="Generate video from image and optional prompt",
inputs=[
IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"),
@ -273,7 +273,7 @@ class ViduReferenceVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="ViduReferenceVideoNode",
display_name="Vidu Reference To Video Generation",
category="video/partner/Vidu",
category="partner/video/Vidu",
description="Generate video from multiple images and a prompt",
inputs=[
IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"),
@ -388,7 +388,7 @@ class ViduStartEndToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="ViduStartEndToVideoNode",
display_name="Vidu Start End To Video Generation",
category="video/partner/Vidu",
category="partner/video/Vidu",
description="Generate a video from start and end frames and a prompt",
inputs=[
IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"),
@ -492,7 +492,7 @@ class Vidu2TextToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="Vidu2TextToVideoNode",
display_name="Vidu2 Text-to-Video Generation",
category="video/partner/Vidu",
category="partner/video/Vidu",
description="Generate video from a text prompt",
inputs=[
IO.Combo.Input("model", options=["viduq2"]),
@ -584,7 +584,7 @@ class Vidu2ImageToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="Vidu2ImageToVideoNode",
display_name="Vidu2 Image-to-Video Generation",
category="video/partner/Vidu",
category="partner/video/Vidu",
description="Generate a video from an image and an optional prompt.",
inputs=[
IO.Combo.Input("model", options=["viduq2-pro-fast", "viduq2-pro", "viduq2-turbo"]),
@ -714,7 +714,7 @@ class Vidu2ReferenceVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="Vidu2ReferenceVideoNode",
display_name="Vidu2 Reference-to-Video Generation",
category="video/partner/Vidu",
category="partner/video/Vidu",
description="Generate a video from multiple reference images and a prompt.",
inputs=[
IO.Combo.Input("model", options=["viduq2"]),
@ -849,7 +849,7 @@ class Vidu2StartEndToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="Vidu2StartEndToVideoNode",
display_name="Vidu2 Start/End Frame-to-Video Generation",
category="video/partner/Vidu",
category="partner/video/Vidu",
description="Generate a video from a start frame, an end frame, and a prompt.",
inputs=[
IO.Combo.Input("model", options=["viduq2-pro-fast", "viduq2-pro", "viduq2-turbo"]),
@ -969,7 +969,7 @@ class ViduExtendVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="ViduExtendVideoNode",
display_name="Vidu Video Extension",
category="video/partner/Vidu",
category="partner/video/Vidu",
description="Extend an existing video by generating additional frames.",
inputs=[
IO.DynamicCombo.Input(
@ -1138,7 +1138,7 @@ class ViduMultiFrameVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="ViduMultiFrameVideoNode",
display_name="Vidu Multi-Frame Video Generation",
category="video/partner/Vidu",
category="partner/video/Vidu",
description="Generate a video with multiple keyframe transitions.",
inputs=[
IO.Combo.Input("model", options=["viduq2-pro", "viduq2-turbo"]),
@ -1284,7 +1284,7 @@ class Vidu3TextToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="Vidu3TextToVideoNode",
display_name="Vidu Q3 Text-to-Video Generation",
category="video/partner/Vidu",
category="partner/video/Vidu",
description="Generate video from a text prompt.",
inputs=[
IO.DynamicCombo.Input(
@ -1429,7 +1429,7 @@ class Vidu3ImageToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="Vidu3ImageToVideoNode",
display_name="Vidu Q3 Image-to-Video Generation",
category="video/partner/Vidu",
category="partner/video/Vidu",
description="Generate a video from an image and an optional prompt.",
inputs=[
IO.DynamicCombo.Input(
@ -1571,7 +1571,7 @@ class Vidu3StartEndToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="Vidu3StartEndToVideoNode",
display_name="Vidu Q3 Start/End Frame-to-Video Generation",
category="video/partner/Vidu",
category="partner/video/Vidu",
description="Generate a video from a start frame, an end frame, and a prompt.",
inputs=[
IO.DynamicCombo.Input(

View File

@ -61,7 +61,7 @@ class WanTextToImageApi(IO.ComfyNode):
return IO.Schema(
node_id="WanTextToImageApi",
display_name="Wan Text to Image",
category="image/partner/Wan",
category="partner/image/Wan",
description="Generates an image based on a text prompt.",
inputs=[
IO.Combo.Input(
@ -184,7 +184,7 @@ class WanImageToImageApi(IO.ComfyNode):
return IO.Schema(
node_id="WanImageToImageApi",
display_name="Wan Image to Image",
category="image/partner/Wan",
category="partner/image/Wan",
description="Generates an image from one or two input images and a text prompt. "
"The output image is currently fixed at 1.6 MP, and its aspect ratio matches the input image(s).",
inputs=[
@ -312,7 +312,7 @@ class WanTextToVideoApi(IO.ComfyNode):
return IO.Schema(
node_id="WanTextToVideoApi",
display_name="Wan Text to Video",
category="video/partner/Wan",
category="partner/video/Wan",
description="Generates a video based on a text prompt.",
inputs=[
IO.Combo.Input(
@ -495,7 +495,7 @@ class WanImageToVideoApi(IO.ComfyNode):
return IO.Schema(
node_id="WanImageToVideoApi",
display_name="Wan Image to Video",
category="video/partner/Wan",
category="partner/video/Wan",
description="Generates a video from the first frame and a text prompt.",
inputs=[
IO.Combo.Input(
@ -674,7 +674,7 @@ class WanReferenceVideoApi(IO.ComfyNode):
return IO.Schema(
node_id="WanReferenceVideoApi",
display_name="Wan Reference to Video",
category="video/partner/Wan",
category="partner/video/Wan",
description="Use the character and voice from input videos, combined with a prompt, "
"to generate a new video that maintains character consistency.",
inputs=[
@ -828,7 +828,7 @@ class Wan2TextToVideoApi(IO.ComfyNode):
return IO.Schema(
node_id="Wan2TextToVideoApi",
display_name="Wan 2.7 Text to Video",
category="video/partner/Wan",
category="partner/video/Wan",
description="Generates a video based on a text prompt using the Wan 2.7 model.",
inputs=[
IO.DynamicCombo.Input(
@ -981,7 +981,7 @@ class Wan2ImageToVideoApi(IO.ComfyNode):
return IO.Schema(
node_id="Wan2ImageToVideoApi",
display_name="Wan 2.7 Image to Video",
category="video/partner/Wan",
category="partner/video/Wan",
description="Generate a video from a first-frame image, with optional last-frame image and audio.",
inputs=[
IO.DynamicCombo.Input(
@ -1152,7 +1152,7 @@ class Wan2VideoContinuationApi(IO.ComfyNode):
return IO.Schema(
node_id="Wan2VideoContinuationApi",
display_name="Wan 2.7 Video Continuation",
category="video/partner/Wan",
category="partner/video/Wan",
description="Continue a video from where it left off, with optional last-frame control.",
inputs=[
IO.DynamicCombo.Input(
@ -1319,7 +1319,7 @@ class Wan2VideoEditApi(IO.ComfyNode):
return IO.Schema(
node_id="Wan2VideoEditApi",
display_name="Wan 2.7 Video Edit",
category="video/partner/Wan",
category="partner/video/Wan",
description="Edit a video using text instructions, reference images, or style transfer.",
inputs=[
IO.DynamicCombo.Input(
@ -1477,7 +1477,7 @@ class Wan2ReferenceVideoApi(IO.ComfyNode):
return IO.Schema(
node_id="Wan2ReferenceVideoApi",
display_name="Wan 2.7 Reference to Video",
category="video/partner/Wan",
category="partner/video/Wan",
description="Generate a video featuring a person or object from reference materials. "
"Supports single-character performances and multi-character interactions.",
inputs=[
@ -1651,7 +1651,7 @@ class HappyHorseTextToVideoApi(IO.ComfyNode):
return IO.Schema(
node_id="HappyHorseTextToVideoApi",
display_name="HappyHorse Text to Video",
category="video/partner/Wan",
category="partner/video/Wan",
description="Generates a video based on a text prompt using the HappyHorse model.",
inputs=[
IO.DynamicCombo.Input(
@ -1775,7 +1775,7 @@ class HappyHorseImageToVideoApi(IO.ComfyNode):
return IO.Schema(
node_id="HappyHorseImageToVideoApi",
display_name="HappyHorse Image to Video",
category="video/partner/Wan",
category="partner/video/Wan",
description="Generate a video from a first-frame image using the HappyHorse model.",
inputs=[
IO.DynamicCombo.Input(
@ -1905,7 +1905,7 @@ class HappyHorseVideoEditApi(IO.ComfyNode):
return IO.Schema(
node_id="HappyHorseVideoEditApi",
display_name="HappyHorse Video Edit",
category="video/partner/Wan",
category="partner/video/Wan",
description="Edit a video using text instructions or reference images with the HappyHorse model. "
"Output duration is 3-15s and matches the input video; inputs longer than 15s are truncated.",
inputs=[
@ -2046,7 +2046,7 @@ class HappyHorseReferenceVideoApi(IO.ComfyNode):
return IO.Schema(
node_id="HappyHorseReferenceVideoApi",
display_name="HappyHorse Reference to Video",
category="video/partner/Wan",
category="partner/video/Wan",
description="Generate a video featuring a person or object from reference materials with the HappyHorse "
"model. Supports single-character performances and multi-character interactions.",
inputs=[

View File

@ -27,7 +27,7 @@ class WavespeedFlashVSRNode(IO.ComfyNode):
return IO.Schema(
node_id="WavespeedFlashVSRNode",
display_name="FlashVSR Video Upscale",
category="video/partner/WaveSpeed",
category="partner/video/WaveSpeed",
description="Fast, high-quality video upscaler that "
"boosts resolution and restores clarity for low-resolution or blurry footage.",
inputs=[
@ -98,7 +98,7 @@ class WavespeedImageUpscaleNode(IO.ComfyNode):
return IO.Schema(
node_id="WavespeedImageUpscaleNode",
display_name="WaveSpeed Image Upscale",
category="image/partner/WaveSpeed",
category="partner/image/WaveSpeed",
description="Boost image resolution and quality, upscaling photos to 4K or 8K for sharp, detailed results.",
inputs=[
IO.Combo.Input("model", options=["SeedVR2", "Ultimate"]),

View File

@ -469,6 +469,11 @@ def _apply_video_scale(video: Input.Video, scale_dims: tuple[int, int]) -> Input
input_container = None
output_container = None
# get_stream_source() is untrimmed, so apply the trim window in this same pass.
# start_time is normalized (>= 0); duration == 0 means "until the end".
start_time, duration = video.get_active_trim_window()
trimming = bool(start_time or duration)
try:
input_source = video.get_stream_source()
input_container = av.open(input_source, mode="r")
@ -487,16 +492,45 @@ def _apply_video_scale(video: Input.Video, scale_dims: tuple[int, int]) -> Input
audio_stream.layout = stream.layout
break
in_video = input_container.streams.video[0]
start_pts = int(start_time / in_video.time_base) if trimming else 0
end_pts = int((start_time + duration) / in_video.time_base) if duration else None
if start_pts:
input_container.seek(start_pts, stream=in_video)
encoded = 0
for frame in input_container.decode(video=0):
if trimming:
if frame.pts is None or frame.pts < start_pts:
continue
if end_pts is not None and frame.pts >= end_pts:
break
frame = frame.reformat(width=out_w, height=out_h, format="yuv420p")
# Re-wrap as a fresh frame: dropping irregular source timestamps (VFR/AVI/GIF/...)
# lets the encoder assign clean ones and avoids mp4 muxer errors.
frame = av.VideoFrame.from_ndarray(frame.to_ndarray(format="yuv420p"), format="yuv420p")
for packet in video_stream.encode(frame):
output_container.mux(packet)
encoded += 1
for packet in video_stream.encode():
output_container.mux(packet)
if encoded == 0:
raise ValueError(
f"resize produced no frames (start_time={start_time}, duration={duration} "
"selected nothing from the source)"
)
if audio_stream is not None:
input_container.seek(0)
for audio_frame in input_container.decode(audio=0):
if trimming:
if audio_frame.time is None or audio_frame.time < start_time:
continue
if duration and audio_frame.time > start_time + duration:
break
# Carry odd audio time bases the mp4 muxer rejects; reset pts, encoder assigns clean ones (MP3-in-AVI)
audio_frame.pts = None
for packet in audio_stream.encode(audio_frame):
output_container.mux(packet)
for packet in audio_stream.encode():

View File

@ -65,6 +65,12 @@ class ChromaRadianceOptions(io.ComfyNode):
tooltip="Allows overriding the default NeRF tile size. -1 means use the default (32). 0 means use non-tiling mode (may require a lot of VRAM).",
advanced=True,
),
io.Boolean.Input(
id="force_sequential_txt_ids",
default=False,
tooltip="Force usage of sequential text token IDs instead of zeroes. Should be used for checkpoints from 2026-05-22 to 2026-06-01 that are trained in this way but do not contain the __sequential__ key in the state dict.",
advanced=True,
),
],
outputs=[io.Model.Output()],
)
@ -78,11 +84,15 @@ class ChromaRadianceOptions(io.ComfyNode):
start_sigma: float,
end_sigma: float,
nerf_tile_size: int,
force_sequential_txt_ids: bool,
) -> io.NodeOutput:
radiance_options = {}
if nerf_tile_size >= 0:
radiance_options["nerf_tile_size"] = nerf_tile_size
if force_sequential_txt_ids:
radiance_options["use_sequential_txt_ids"] = True
if not radiance_options:
return io.NodeOutput(model)

File diff suppressed because it is too large Load Diff

View File

@ -47,6 +47,7 @@ class Load3D(IO.ComfyNode):
IO.Load3DCamera.Output(display_name="camera_info"),
IO.Video.Output(display_name="recording_video"),
IO.File3DAny.Output(display_name="model_3d"),
IO.Load3DModelInfo.Output(display_name="model_3d_info"),
],
)
@ -73,7 +74,8 @@ class Load3D(IO.ComfyNode):
if model_file and model_file != "none":
file_3d = Types.File3D(folder_paths.get_annotated_filepath(model_file))
mesh_path = model_file
return IO.NodeOutput(output_image, output_mask, mesh_path, normal_image, image['camera_info'], video, file_3d)
model_3d_info = image.get('model_3d_info', [])
return IO.NodeOutput(output_image, output_mask, mesh_path, normal_image, image['camera_info'], video, file_3d, model_3d_info)
process = execute # TODO: remove
@ -122,12 +124,71 @@ class Preview3D(IO.ComfyNode):
process = execute # TODO: remove
class Preview3DAdvanced(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="Preview3DAdvanced",
display_name="Preview 3D (Advanced)",
search_aliases=["preview 3d", "3d viewer", "view mesh", "frame 3d", "3d camera output"],
category="3d",
is_experimental=True,
is_output_node=True,
inputs=[
IO.MultiType.Input(
"model_file",
types=[
IO.File3DGLB,
IO.File3DGLTF,
IO.File3DFBX,
IO.File3DOBJ,
IO.File3DSTL,
IO.File3DUSDZ,
IO.File3DAny,
],
tooltip="3D model file from an upstream 3D node.",
),
IO.Load3D.Input("image"),
IO.Load3DCamera.Input("camera_info", optional=True, advanced=True),
IO.Load3DModelInfo.Input("model_3d_info", optional=True, advanced=True),
IO.Int.Input("width", default=1024, min=1, max=4096, step=1),
IO.Int.Input("height", default=1024, min=1, max=4096, step=1),
],
outputs=[
IO.File3DAny.Output(display_name="model_file"),
IO.Load3DCamera.Output(display_name="camera_info"),
IO.Load3DModelInfo.Output(display_name="model_3d_info"),
IO.Int.Output(display_name="width"),
IO.Int.Output(display_name="height"),
],
)
@classmethod
def execute(cls, model_file: Types.File3D, image, width: int, height: int, **kwargs) -> IO.NodeOutput:
filename = f"preview3d_advanced_{uuid.uuid4().hex}.{model_file.format}"
model_file.save_to(os.path.join(folder_paths.get_output_directory(), filename))
camera_info_input = kwargs.get("camera_info", None)
camera_info = camera_info_input if camera_info_input is not None else image['camera_info']
model_3d_info_input = kwargs.get("model_3d_info", None)
model_3d_info = model_3d_info_input if model_3d_info_input is not None else image.get('model_3d_info', [])
return IO.NodeOutput(
model_file,
camera_info,
model_3d_info,
width,
height,
ui=UI.PreviewUI3DAdvanced(filename, camera_info, model_3d_info),
)
class Load3DExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
Load3D,
Preview3D,
Preview3DAdvanced,
]

View File

@ -102,11 +102,18 @@ class MathExpressionNode(io.ComfyNode):
f"Math Expression '{expression}' must evaluate to a numeric result, "
f"got {type(result).__name__}: {result!r}"
)
if not math.isfinite(result):
try:
float_result = float(result)
except OverflowError:
raise ValueError(
f"Math Expression '{expression}' produced a result too large to "
f"represent as a float: {result}"
) from None
if not math.isfinite(float_result):
raise ValueError(
f"Math Expression '{expression}' produced a non-finite result: {result}"
)
return io.NodeOutput(float(result), int(result), bool(result))
return io.NodeOutput(float_result, int(result), bool(result))
class MathExtension(ComfyExtension):

View File

@ -21,8 +21,8 @@ class PiDConditioning(io.ComfyNode):
inputs=[
io.Conditioning.Input("positive"),
io.Latent.Input("latent", tooltip="latent (from VAEEncode or a KSampler)."),
io.Combo.Input("latent_format", options=["flux", "sd3"], default="flux",
tooltip="Flux1 and Flux2 latents auto-detected from channel dim, sd3 has to be selected manually."),
io.Combo.Input("latent_format", options=["flux", "sd3", "sdxl", "qwenimage"], default="flux",
tooltip="Flux1 (16-ch) and Flux2 (128-ch) latents are auto-detected from channel dim under 'flux'. For SD3 (16-ch), SDXL (4-ch), or QwenImage (16-ch), select manually."),
io.Float.Input(
"degrade_sigma", default=0.0, min=0.0, max=1.0, step=0.01,
tooltip="0 = clean latent. Increase to denoise corrupted latent outputs.",
@ -36,9 +36,17 @@ class PiDConditioning(io.ComfyNode):
samples = latent["samples"]
if latent_format == "flux":
fmt_cls = comfy.latent_formats.Flux2 if samples.shape[1] == 128 else comfy.latent_formats.Flux
else:
elif latent_format == "sd3":
fmt_cls = comfy.latent_formats.SD3
elif latent_format == "sdxl":
fmt_cls = comfy.latent_formats.SDXL
elif latent_format == "qwenimage":
fmt_cls = comfy.latent_formats.Wan21
else:
raise ValueError(f"Unknown latent_format: {latent_format}")
lq_latent = fmt_cls().process_in(samples)
if lq_latent.ndim == 5:
lq_latent = lq_latent[:, :, 0]
sigma_t = torch.tensor([float(degrade_sigma)], dtype=torch.float32)
return io.NodeOutput(node_helpers.conditioning_set_values(
positive, {"lq_latent": lq_latent, "degrade_sigma": sigma_t},

View File

@ -16,7 +16,7 @@ from comfy.cli_args import args
from comfy_api.latest import ComfyExtension, IO, Types
def pack_variable_mesh_batch(vertices, faces, colors=None, uvs=None, texture=None):
def pack_variable_mesh_batch(vertices, faces, colors=None, uvs=None, texture=None, unlit=False):
# Pack lists of (Nᵢ, *) vertex/face/color/uv tensors into padded batched tensors,
# stashing per-item lengths as runtime attrs so consumers can recover the real slice.
# colors and uvs are 1:1 with vertices, so they're padded to max_vertices and read with vertex_counts.
@ -54,7 +54,7 @@ def pack_variable_mesh_batch(vertices, faces, colors=None, uvs=None, texture=Non
return Types.MESH(packed_vertices, packed_faces,
uvs=packed_uvs, vertex_colors=packed_colors, texture=texture,
vertex_counts=vertex_counts, face_counts=face_counts)
vertex_counts=vertex_counts, face_counts=face_counts, unlit=unlit)
def get_mesh_batch_item(mesh, index):
@ -77,7 +77,7 @@ def get_mesh_batch_item(mesh, index):
def save_glb(vertices, faces, filepath, metadata=None,
uvs=None, vertex_colors=None, texture_image=None):
uvs=None, vertex_colors=None, texture_image=None, unlit=False):
"""
Save PyTorch tensor vertices and faces as a GLB file without external dependencies.
@ -234,6 +234,17 @@ def save_glb(vertices, faces, filepath, metadata=None,
textures = []
samplers = []
materials = []
extensions_used = []
if unlit and texture_png_bytes is None:
# Flat, light-independent shading (KHR_materials_unlit): COLOR_0 is shown as-is, matching how a
# gaussian splat renders (emissive). Without this the viewer lights the mesh and washes the colours.
materials.append({
"pbrMetallicRoughness": {"baseColorFactor": [1.0, 1.0, 1.0, 1.0], "metallicFactor": 0.0, "roughnessFactor": 1.0},
"extensions": {"KHR_materials_unlit": {}},
"doubleSided": True,
})
extensions_used.append("KHR_materials_unlit")
primitive["material"] = 0
if texture_png_bytes is not None and "TEXCOORD_0" in primitive_attributes:
buffer_views.append({
"buffer": 0,
@ -271,6 +282,8 @@ def save_glb(vertices, faces, filepath, metadata=None,
gltf["textures"] = textures
if materials:
gltf["materials"] = materials
if extensions_used:
gltf["extensionsUsed"] = extensions_used
if metadata:
gltf["asset"]["extras"] = metadata
@ -376,7 +389,8 @@ class SaveGLB(IO.ComfyNode):
save_glb(vertices_i, faces_i, os.path.join(full_output_folder, f), metadata,
uvs=uvs_i,
vertex_colors=v_colors,
texture_image=tex_img)
texture_image=tex_img,
unlit=getattr(mesh, "unlit", False))
results.append({
"filename": f,
"subfolder": subfolder,

View File

@ -0,0 +1,270 @@
# TripoSplat nodes: image -> 3D gaussian splat
import logging
import torch
import torch.nn.functional as F
from typing_extensions import override
import comfy.model_management
import comfy.nested_tensor
import comfy.patcher_extension
import comfy.utils
from comfy_api.latest import ComfyExtension, IO, Types
_Q_TOKEN_LENGTH = 8192
_LATENT_CHANNELS = 16
_CAM_CHANNELS = 5
_DINOV3_MEAN = [0.485, 0.456, 0.406]
_DINOV3_STD = [0.229, 0.224, 0.225]
_NUM_GAUSSIANS_MIN = 32768
_NUM_GAUSSIANS_MAX = 1048576
def _preprocess(image: torch.Tensor, mask: torch.Tensor, erode_radius: int, size: int) -> torch.Tensor:
# Match original preprocessing:
# resize min side to `size` -> erode alpha -> alpha bbox -> 1.2x square crop -> resize -> composite on black.
rgb = image[..., :3].clamp(0, 1).movedim(-1, 0) # (3, H, W)
alpha = mask.clamp(0, 1)[None] # (1, H, W)
rgba = torch.cat([rgb, alpha], 0)[None] # (1, 4, H, W)
h, w = rgba.shape[-2:]
s = size / min(w, h)
rgba = comfy.utils.common_upscale(rgba, max(1, round(w * s)), max(1, round(h * s)), "lanczos", "disabled").clamp(0, 1)
a = rgba[:, 3:4]
if erode_radius > 0:
# min filter over a (2r+1) window == morphological erosion of the alpha matte.
a = -F.max_pool2d(-a, 2 * erode_radius + 1, stride=1, padding=erode_radius)
rgba = torch.cat([rgba[:, :3], a], 1)
ys, xs = torch.nonzero(a[0, 0] > 0, as_tuple=True)
if xs.numel() == 0:
raise ValueError("TripoSplatPreprocessImage: mask is empty (no foreground pixels).")
x0, x1 = int(xs.min()), int(xs.max())
y0, y1 = int(ys.min()), int(ys.max())
cx, cy = (x0 + x1) / 2, (y0 + y1) / 2
half = max(x1 - x0, y1 - y0) / 2 * 1.2
left, upper, right, lower = int(cx - half), int(cy - half), int(cx + half), int(cy + half)
H, W = rgba.shape[-2:]
crop = rgba.new_zeros((1, 4, lower - upper, right - left)) # out-of-bounds stays 0, matching PIL.crop
sx0, sy0, sx1, sy1 = max(left, 0), max(upper, 0), min(right, W), min(lower, H)
if sx1 > sx0 and sy1 > sy0:
crop[:, :, sy0 - upper:sy1 - upper, sx0 - left:sx1 - left] = rgba[:, :, sy0:sy1, sx0:sx1]
crop = comfy.utils.common_upscale(crop, size, size, "lanczos", "disabled").clamp(0, 1)
out = (crop[:, :3] * crop[:, 3:4])[0].movedim(0, -1) # composite over black == rgb * alpha
return out.unsqueeze(0) # (1, 1024, 1024, 3)
class TripoSplatPreprocessImage(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="TripoSplatPreprocessImage",
display_name="TripoSplat Preprocess Image",
category="3d/conditioning",
description="Crop center each image to a square canvas on a black background and add padding.",
inputs=[
IO.Image.Input("image"),
IO.Mask.Input("mask"),
IO.Int.Input("erode_radius", default=1, min=0, max=16,
tooltip="Erode the alpha matte by this pixel radius before cropping (avoids border bleed)."),
IO.Int.Input("size", default=1024, min=256, max=4096, step=16,
tooltip="Square image size. The model is trained at 1024; other sizes run but are off-distribution."),
],
outputs=[IO.Image.Output(display_name="image")],
)
@classmethod
def execute(cls, image, mask, erode_radius, size) -> IO.NodeOutput:
size = max(16, (int(size) // 16) * 16) # DINOv3 patch / Flux2 VAE stride is 16
if mask.shape[0] != image.shape[0]:
mask = comfy.utils.repeat_to_batch_size(mask, image.shape[0])
if tuple(mask.shape[1:]) != tuple(image.shape[1:3]):
mask = F.interpolate(mask[:, None].float(), size=tuple(image.shape[1:3]), mode="bilinear", align_corners=False)[:, 0]
prepared = torch.cat([_preprocess(image[i], mask[i], erode_radius, size) for i in range(image.shape[0])], dim=0)
return IO.NodeOutput(prepared)
class TripoSplatConditioning(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="TripoSplatConditioning",
display_name="TripoSplat Conditioning",
category="3d/conditioning",
description="Encode the image with DINOv3 and the Flux2 VAE into TripoSplat positive/negative "
"conditioning, and create the fixed size noise target (latent + camera) for the KSampler",
inputs=[
IO.ClipVision.Input("clip_vision", tooltip="DINOv3 ViT-H/16+ image encoder"),
IO.Vae.Input("vae", tooltip="Flux2 VAE"),
IO.Image.Input("image"),
],
outputs=[
IO.Conditioning.Output(display_name="positive"),
IO.Conditioning.Output(display_name="negative"),
IO.Latent.Output(display_name="latent", tooltip="The fixed size noise target (latent +camera)."),
],
)
@classmethod
def execute(cls, clip_vision, vae, image) -> IO.NodeOutput:
# feature1: DINOv3 token sequence (cls + registers + patches), ImageNet-normalized, with a final non-affine layer norm on top
comfy.model_management.load_model_gpu(clip_vision.patcher)
device = clip_vision.load_device
img = image.movedim(-1, 1).to(device) # (B,3,H,W) in [0,1]
mean = torch.tensor(_DINOV3_MEAN, device=device).view(1, 3, 1, 1)
std = torch.tensor(_DINOV3_STD, device=device).view(1, 3, 1, 1)
img = (img - mean) / std
seq = clip_vision.model(pixel_values=img.float())[0]
feature1 = F.layer_norm(seq.float(), seq.shape[-1:]).to(comfy.model_management.intermediate_device())
# Second conditioning: the Flux2 VAE latent of the image, carried as a standard reference_latents entry
ref = vae.encode(image).to(comfy.model_management.intermediate_device()) # (B, 128, H, W)
b = ref.shape[0]
positive = [[feature1, {"reference_latents": [ref]}]]
negative = [[torch.zeros_like(feature1), {"reference_latents": [torch.zeros_like(ref)]}]]
# Fixed noise target: the latent is a constant-shape (8192, 16) shape-code + a (1, 5) camera token
dev = comfy.model_management.intermediate_device()
latent_seq = torch.zeros([b, _Q_TOKEN_LENGTH, _LATENT_CHANNELS], device=dev)
camera = torch.zeros([b, 1, _CAM_CHANNELS], device=dev)
samples = comfy.nested_tensor.NestedTensor((latent_seq, camera))
return IO.NodeOutput(positive, negative, {"samples": samples})
class VAEDecodeTripoSplat(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="VAEDecodeTripoSplat",
display_name="TripoSplat Decode",
category="3d/latent",
description="Decode the sampled TripoSplat latent into a 3D gaussian splat. "
"Modify the number of gaussians to vary the density.",
inputs=[
IO.Latent.Input("samples"),
IO.Vae.Input("vae", tooltip="TripoSplat VAE decoder"),
IO.Int.Input("num_gaussians", default=262144, min=_NUM_GAUSSIANS_MIN, max=_NUM_GAUSSIANS_MAX, step=32,
tooltip="Number of gaussians to produce (rounded to a multiple of 32). "
"262144 matches the octree's point density; higher oversamples the same points "
"(denser, but no new detail) and costs proportionally more VRAM/time."),
IO.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff,
tooltip="Seeds the octree point sampler (global RNG) for deterministic decodes."),
],
outputs=[IO.Splat.Output(display_name="splat")],
)
@classmethod
def execute(cls, samples, vae, num_gaussians, seed) -> IO.NodeOutput:
s = samples["samples"]
latent = s.unbind()[0] if getattr(s, "is_nested", False) else s # take the latent stream, drop camera
decoder = vae.first_stage_model
gpp = decoder.gaussians_per_point
n = max(_NUM_GAUSSIANS_MIN, min(_NUM_GAUSSIANS_MAX, int(num_gaussians)))
if n % gpp != 0:
n = round(n / gpp) * gpp
dtype_size = comfy.model_management.dtype_size(vae.vae_dtype)
hidden = decoder.gs.model_channels
cond_tokens = latent.shape[1]
memory_required = (cond_tokens * 4 + (n // gpp) * 10) * hidden * dtype_size
comfy.model_management.load_models_gpu([vae.patcher], memory_required=memory_required)
latent = latent.to(device=vae.device, dtype=vae.vae_dtype)
generator = torch.Generator(device="cpu").manual_seed(seed)
parts = [g.render_tensors() for g in decoder.decode(latent, num_gaussians=n, generator=generator)]
positions, scales, rotations, opacities, sh = (torch.stack(t) for t in zip(*parts))
return IO.NodeOutput(Types.SPLAT(positions, scales, rotations, opacities, sh))
class TripoSplatSamplingPreview(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="TripoSplatSamplingPreview",
display_name="TripoSplat Sampling Preview",
category="3d/latent",
description="Patch the TripoSplat model for the standard Ksampler node to show a live decoded "
"gaussian splat preview at each step.",
inputs=[
IO.Model.Input("model"),
IO.Vae.Input("vae", tooltip="TripoSplat VAE decoder"),
IO.Int.Input("octree_level", default=5, min=2, max=8, advanced=True,
tooltip="Octree depth for the preview decode (lower = cheaper/coarser)."),
IO.Int.Input("num_gaussians", default=16384, min=1024, max=262144, step=32,
tooltip="Number of gaussians to produce for the preview (rounded to a multiple of 32)."),
IO.Float.Input("yaw", default=90.0, min=-360.0, max=360.0, step=1.0, tooltip="Preview camera yaw in degrees.", advanced=True,),
IO.Float.Input("pitch", default=15.0, min=-89.0, max=89.0, step=1.0, tooltip="Preview camera pitch in degrees.", advanced=True,),
IO.Int.Input("point_size", default=3, min=1, max=16,
tooltip="Maximum splat radius in pixels. Each gaussian is sized from its scale and capped here; "
"lower = finer/pointier, higher = chunkier."),
],
outputs=[IO.Model.Output()],
)
@classmethod
def execute(cls, model, vae, octree_level, num_gaussians, yaw, pitch, point_size) -> IO.NodeOutput:
from comfy.ldm.triposplat.preview import decode_x0_to_image
cfg = {"gaussians": num_gaussians, "level": octree_level, "yaw": yaw, "pitch": pitch,
"point_size": point_size}
fsm = vae.first_stage_model
cond_tokens = model.model.diffusion_model.q_token_length
memory_required = (cond_tokens * 4 + (num_gaussians // fsm.gaussians_per_point) * 10) * fsm.gs.model_channels * comfy.model_management.dtype_size(vae.vae_dtype)
# Live preview via WrappersMP.OUTER_SAMPLE + ProgressBar
# The wrapper augments the sampler's own callback to decode x0 -> gaussian splat -> preview image each step
def outer_sample_wrapper(executor, *args, **kwargs):
args = list(args)
cb_idx = 5 # outer_sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
orig_cb = args[cb_idx] if len(args) > cb_idx else kwargs.get("callback")
state = {"ok": True, "pbar": None, "loaded": False}
def callback(step, x0, x, total_steps):
if orig_cb is not None:
orig_cb(step, x0, x, total_steps)
if not state["ok"]:
return
try:
if not state["loaded"]:
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
loaded_models.append(vae.patcher)
comfy.model_management.load_models_gpu(loaded_models, memory_required=memory_required)
state["loaded"] = True
img = decode_x0_to_image(vae, x0, cfg)
if state["pbar"] is None:
state["pbar"] = comfy.utils.ProgressBar(total_steps)
state["pbar"].update_absolute(step + 1, total_steps, ("JPEG", img, 512))
except Exception as e:
logging.warning("TripoSplatSamplingPreview: preview failed, disabling ({})".format(e))
state["ok"] = False
if len(args) > cb_idx:
args[cb_idx] = callback
else:
kwargs["callback"] = callback
return executor(*args, **kwargs)
m = model.clone()
m.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "triposplat_sampling_preview", outer_sample_wrapper)
return IO.NodeOutput(m)
class TripoSplatExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
TripoSplatPreprocessImage,
TripoSplatConditioning,
VAEDecodeTripoSplat,
TripoSplatSamplingPreview,
]
async def comfy_entrypoint() -> TripoSplatExtension:
return TripoSplatExtension()

View File

@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.22.0"
__version__ = "0.23.0"

View File

@ -464,13 +464,6 @@ def start_comfyui(asyncio_loop=None):
folder_paths.set_temp_directory(temp_dir)
cleanup_temp()
if args.windows_standalone_build:
try:
import new_updater
new_updater.update_windows_updater()
except:
pass
if not asyncio_loop:
asyncio_loop = asyncio.new_event_loop()
asyncio.set_event_loop(asyncio_loop)

View File

@ -1,35 +0,0 @@
import os
import shutil
base_path = os.path.dirname(os.path.realpath(__file__))
def update_windows_updater():
top_path = os.path.dirname(base_path)
updater_path = os.path.join(base_path, ".ci/update_windows/update.py")
bat_path = os.path.join(base_path, ".ci/update_windows/update_comfyui.bat")
dest_updater_path = os.path.join(top_path, "update/update.py")
dest_bat_path = os.path.join(top_path, "update/update_comfyui.bat")
dest_bat_deps_path = os.path.join(top_path, "update/update_comfyui_and_python_dependencies.bat")
try:
with open(dest_bat_path, 'rb') as f:
contents = f.read()
except:
return
if not contents.startswith(b"..\\python_embeded\\python.exe .\\update.py"):
return
shutil.copy(updater_path, dest_updater_path)
try:
with open(dest_bat_deps_path, 'rb') as f:
contents = f.read()
contents = contents.replace(b'..\\python_embeded\\python.exe .\\update.py ..\\ComfyUI\\', b'call update_comfyui.bat nopause')
with open(dest_bat_deps_path, 'wb') as f:
f.write(contents)
except:
pass
shutil.copy(bat_path, dest_bat_path)
print("Updated the windows standalone package updater.") # noqa: T201

View File

@ -2477,6 +2477,8 @@ async def init_builtin_extra_nodes():
"nodes_save_3d.py",
"nodes_moge.py",
"nodes_mediapipe.py",
"nodes_gaussian_splat.py",
"nodes_triposplat.py"
]
import_failed = []

View File

@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.22.0"
version = "0.23.0"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.10"

View File

@ -1,6 +1,6 @@
comfyui-frontend-package==1.44.19
comfyui-workflow-templates==0.9.85
comfyui-embedded-docs==0.5.1
comfyui-workflow-templates==0.9.92
comfyui-embedded-docs==0.5.2
torch
torchsde
torchvision
@ -22,8 +22,8 @@ alembic
SQLAlchemy>=2.0.0
filelock
av>=16.0.0
comfy-kitchen==0.2.9
comfy-aimdo==0.4.5
comfy-kitchen==0.2.10
comfy-aimdo==0.4.8
requests
simpleeval>=1.0.0
blake3

View File

@ -197,3 +197,10 @@ class TestMathExpressionExecute:
def test_pow_huge_exponent_raises(self):
with pytest.raises(ValueError, match="Exponent .* exceeds maximum"):
self._exec("pow(a, b)", a=10, b=10000000)
def test_huge_int_result_raises_value_error(self):
# Exponent is within the allowed MAX_EXPONENT range, so the result is a
# finite Python int that is nonetheless too large to convert to float.
# This must raise a clean ValueError, not an uncaught OverflowError.
with pytest.raises(ValueError, match="too large to represent as a float"):
self._exec("2 ** 3999")