mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-24 00:39:30 +08:00
Merge branch 'Comfy-Org:master' into master
This commit is contained in:
commit
54e1802b0e
@ -105,7 +105,7 @@ class WindowAttention(nn.Module):
|
||||
|
||||
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.long().view(-1)].view(
|
||||
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
|
||||
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
||||
relative_position_bias = comfy.ops.cast_to_input(relative_position_bias.permute(2, 0, 1).contiguous(), attn) # nH, Wh*Ww, Wh*Ww
|
||||
attn = attn + relative_position_bias.unsqueeze(0)
|
||||
|
||||
if mask is not None:
|
||||
|
||||
@ -9,6 +9,7 @@ import comfy.model_management
|
||||
import comfy.utils
|
||||
import comfy.clip_model
|
||||
import comfy.image_encoders.dino2
|
||||
import comfy.image_encoders.dino3
|
||||
|
||||
class Output:
|
||||
def __getitem__(self, key):
|
||||
@ -23,12 +24,16 @@ IMAGE_ENCODERS = {
|
||||
"siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
||||
"siglip2_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
||||
"dinov2": comfy.image_encoders.dino2.Dinov2Model,
|
||||
"dinov3": comfy.image_encoders.dino3.DINOv3ViTModel,
|
||||
}
|
||||
|
||||
class ClipVisionModel():
|
||||
def __init__(self, json_config):
|
||||
with open(json_config) as f:
|
||||
config = json.load(f)
|
||||
if isinstance(json_config, dict):
|
||||
config = json_config
|
||||
else:
|
||||
with open(json_config) as f:
|
||||
config = json.load(f)
|
||||
|
||||
self.image_size = config.get("image_size", 224)
|
||||
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
|
||||
@ -134,6 +139,8 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
||||
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
|
||||
elif 'encoder.layer.23.layer_scale2.lambda1' in sd:
|
||||
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_large.json")
|
||||
elif 'layer.0.mlp.gate_proj.weight' in sd and 'layer.31.norm1.weight' in sd: # Dinov3 ViT-H/16+ (SwiGLU gated MLP, 32 layers)
|
||||
json_config = comfy.image_encoders.dino3.DINOV3_VITH_CONFIG
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
259
comfy/image_encoders/dino3.py
Normal file
259
comfy/image_encoders/dino3.py
Normal file
@ -0,0 +1,259 @@
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.ops
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
from comfy.image_encoders.dino2 import LayerScale as DINOv3ViTLayerScale
|
||||
|
||||
|
||||
# DINOv3 ViT-H/16+ (SwiGLU)
|
||||
DINOV3_VITH_CONFIG = {
|
||||
"model_type": "dinov3",
|
||||
"num_hidden_layers": 32,
|
||||
"hidden_size": 1280,
|
||||
"num_attention_heads": 20,
|
||||
"num_register_tokens": 4,
|
||||
"intermediate_size": 5120,
|
||||
"layer_norm_eps": 1e-5,
|
||||
"num_channels": 3,
|
||||
"patch_size": 16,
|
||||
"rope_theta": 100.0,
|
||||
"use_gated_mlp": True,
|
||||
"gated_mlp_act": "silu",
|
||||
"image_size": 1024,
|
||||
"image_mean": [0.485, 0.456, 0.406],
|
||||
"image_std": [0.229, 0.224, 0.225],
|
||||
}
|
||||
|
||||
|
||||
class DINOv3ViTMLP(nn.Module):
|
||||
def __init__(self, hidden_size, intermediate_size, mlp_bias, device, dtype, operations):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype)
|
||||
self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=mlp_bias, device=device, dtype=dtype)
|
||||
self.act_fn = torch.nn.GELU()
|
||||
|
||||
def forward(self, x):
|
||||
return self.down_proj(self.act_fn(self.up_proj(x)))
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, **kwargs):
|
||||
num_tokens = q.shape[-2]
|
||||
num_patches = sin.shape[-2]
|
||||
num_prefix_tokens = num_tokens - num_patches
|
||||
|
||||
q_prefix_tokens, q_patches = q.split((num_prefix_tokens, num_patches), dim=-2)
|
||||
k_prefix_tokens, k_patches = k.split((num_prefix_tokens, num_patches), dim=-2)
|
||||
|
||||
q_patches = (q_patches * cos) + (rotate_half(q_patches) * sin)
|
||||
k_patches = (k_patches * cos) + (rotate_half(k_patches) * sin)
|
||||
|
||||
q = torch.cat((q_prefix_tokens, q_patches), dim=-2)
|
||||
k = torch.cat((k_prefix_tokens, k_patches), dim=-2)
|
||||
|
||||
return q, k
|
||||
|
||||
|
||||
class DINOv3ViTAttention(nn.Module):
|
||||
def __init__(self, hidden_size, num_attention_heads, device, dtype, operations):
|
||||
super().__init__()
|
||||
self.embed_dim = hidden_size
|
||||
self.num_heads = num_attention_heads
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
|
||||
self.k_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=False, device=device, dtype=dtype) # key_bias = False
|
||||
self.v_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype)
|
||||
self.q_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype)
|
||||
self.o_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, hidden_states, attention_mask=None, position_embeddings=None, **kwargs):
|
||||
batch_size, patches, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
if position_embeddings is not None:
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
attn = optimized_attention_for_device(query_states.device, mask=False)
|
||||
attn_output = attn(
|
||||
query_states, key_states, value_states, self.num_heads, attention_mask,
|
||||
skip_reshape=True, skip_output_reshape=True, low_precision_attention=False,
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = attn_output.reshape(batch_size, patches, -1).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
return attn_output
|
||||
|
||||
|
||||
class DINOv3ViTGatedMLP(nn.Module):
|
||||
def __init__(self, hidden_size, intermediate_size, mlp_bias, device, dtype, operations, act="silu"):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.gate_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype)
|
||||
self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype)
|
||||
self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=mlp_bias, device=device, dtype=dtype)
|
||||
self.act_fn = torch.nn.SiLU() if act == "silu" else torch.nn.GELU()
|
||||
|
||||
def forward(self, x):
|
||||
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
def get_patches_center_coordinates(num_patches_h, num_patches_w, dtype, device):
|
||||
coords_h = torch.arange(0.5, num_patches_h, dtype=dtype, device=device)
|
||||
coords_w = torch.arange(0.5, num_patches_w, dtype=dtype, device=device)
|
||||
coords_h = coords_h / num_patches_h
|
||||
coords_w = coords_w / num_patches_w
|
||||
coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1)
|
||||
coords = coords.flatten(0, 1)
|
||||
coords = 2.0 * coords - 1.0
|
||||
return coords
|
||||
|
||||
|
||||
class DINOv3ViTRopePositionEmbedding(nn.Module):
|
||||
inv_freq: torch.Tensor
|
||||
|
||||
def __init__(self, rope_theta, hidden_size, num_attention_heads, patch_size, device, dtype):
|
||||
super().__init__()
|
||||
self.base = rope_theta
|
||||
self.head_dim = hidden_size // num_attention_heads
|
||||
self.patch_size = patch_size
|
||||
|
||||
inv_freq = 1 / self.base ** torch.arange(0, 1, 4 / self.head_dim, dtype=torch.float32, device=device)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
def forward(self, pixel_values):
|
||||
_, _, height, width = pixel_values.shape
|
||||
num_patches_h = height // self.patch_size
|
||||
num_patches_w = width // self.patch_size
|
||||
|
||||
patch_coords = get_patches_center_coordinates(num_patches_h, num_patches_w, dtype=torch.float32, device=pixel_values.device)
|
||||
self.inv_freq = self.inv_freq.to(pixel_values.device)
|
||||
angles = 2 * math.pi * patch_coords[:, :, None] * self.inv_freq[None, None, :]
|
||||
angles = angles.flatten(1, 2)
|
||||
angles = angles.tile(2)
|
||||
cos = torch.cos(angles).to(dtype=pixel_values.dtype)
|
||||
sin = torch.sin(angles).to(dtype=pixel_values.dtype)
|
||||
return cos, sin
|
||||
|
||||
|
||||
class DINOv3ViTEmbeddings(nn.Module):
|
||||
def __init__(self, hidden_size, num_register_tokens, num_channels, patch_size, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.cls_token = nn.Parameter(torch.empty(1, 1, hidden_size, device=device, dtype=dtype))
|
||||
self.mask_token = nn.Parameter(torch.empty(1, 1, hidden_size, device=device, dtype=dtype))
|
||||
self.register_tokens = nn.Parameter(torch.empty(1, num_register_tokens, hidden_size, device=device, dtype=dtype))
|
||||
self.patch_embeddings = operations.Conv2d(
|
||||
num_channels, hidden_size, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
def forward(self, pixel_values, bool_masked_pos=None):
|
||||
batch_size = pixel_values.shape[0]
|
||||
|
||||
patch_embeddings = self.patch_embeddings(pixel_values)
|
||||
patch_embeddings = patch_embeddings.flatten(2).transpose(1, 2)
|
||||
|
||||
if bool_masked_pos is not None:
|
||||
mask_token = comfy.ops.cast_to_input(self.mask_token, patch_embeddings)
|
||||
patch_embeddings = torch.where(bool_masked_pos.unsqueeze(-1), mask_token, patch_embeddings)
|
||||
|
||||
cls_token = comfy.ops.cast_to_input(self.cls_token.expand(batch_size, -1, -1), patch_embeddings)
|
||||
register_tokens = comfy.ops.cast_to_input(self.register_tokens.expand(batch_size, -1, -1), patch_embeddings)
|
||||
embeddings = torch.cat([cls_token, register_tokens, patch_embeddings], dim=1)
|
||||
return embeddings
|
||||
|
||||
|
||||
class DINOv3ViTLayer(nn.Module):
|
||||
def __init__(self, hidden_size, layer_norm_eps, use_gated_mlp, mlp_bias, intermediate_size,
|
||||
num_attention_heads, device, dtype, operations, gated_mlp_act="silu"):
|
||||
super().__init__()
|
||||
self.norm1 = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype)
|
||||
self.attention = DINOv3ViTAttention(hidden_size, num_attention_heads, device=device, dtype=dtype, operations=operations)
|
||||
self.layer_scale1 = DINOv3ViTLayerScale(hidden_size, device=device, dtype=dtype, operations=None)
|
||||
|
||||
self.norm2 = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype)
|
||||
if use_gated_mlp:
|
||||
self.mlp = DINOv3ViTGatedMLP(hidden_size, intermediate_size, mlp_bias, device=device, dtype=dtype, operations=operations, act=gated_mlp_act)
|
||||
else:
|
||||
self.mlp = DINOv3ViTMLP(hidden_size, intermediate_size=intermediate_size, mlp_bias=mlp_bias, device=device, dtype=dtype, operations=operations)
|
||||
self.layer_scale2 = DINOv3ViTLayerScale(hidden_size, device=device, dtype=dtype, operations=None)
|
||||
|
||||
def forward(self, hidden_states, attention_mask=None, position_embeddings=None):
|
||||
residual = hidden_states
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
hidden_states = self.attention(hidden_states, attention_mask=attention_mask, position_embeddings=position_embeddings)
|
||||
hidden_states = self.layer_scale1(hidden_states)
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.norm2(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = self.layer_scale2(hidden_states)
|
||||
hidden_states = hidden_states + residual
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DINOv3ViTModel(nn.Module):
|
||||
def __init__(self, config, dtype, device, operations):
|
||||
super().__init__()
|
||||
num_hidden_layers = config["num_hidden_layers"]
|
||||
hidden_size = config["hidden_size"]
|
||||
num_attention_heads = config["num_attention_heads"]
|
||||
num_register_tokens = config["num_register_tokens"]
|
||||
intermediate_size = config["intermediate_size"]
|
||||
layer_norm_eps = config["layer_norm_eps"]
|
||||
num_channels = config["num_channels"]
|
||||
patch_size = config["patch_size"]
|
||||
rope_theta = config["rope_theta"]
|
||||
use_gated_mlp = config.get("use_gated_mlp", False)
|
||||
gated_mlp_act = config.get("gated_mlp_act", "silu")
|
||||
|
||||
self.embeddings = DINOv3ViTEmbeddings(
|
||||
hidden_size, num_register_tokens, num_channels=num_channels, patch_size=patch_size,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
self.rope_embeddings = DINOv3ViTRopePositionEmbedding(
|
||||
rope_theta, hidden_size, num_attention_heads, patch_size=patch_size, dtype=dtype, device=device
|
||||
)
|
||||
self.layer = nn.ModuleList([
|
||||
DINOv3ViTLayer(hidden_size, layer_norm_eps, use_gated_mlp=use_gated_mlp, mlp_bias=True,
|
||||
intermediate_size=intermediate_size, num_attention_heads=num_attention_heads,
|
||||
dtype=dtype, device=device, operations=operations, gated_mlp_act=gated_mlp_act)
|
||||
for _ in range(num_hidden_layers)])
|
||||
self.norm = operations.LayerNorm(hidden_size, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embeddings.patch_embeddings
|
||||
|
||||
def forward(self, pixel_values, bool_masked_pos=None, **kwargs):
|
||||
hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
|
||||
position_embeddings = self.rope_embeddings(pixel_values)
|
||||
|
||||
for layer_module in self.layer:
|
||||
hidden_states = layer_module(hidden_states, position_embeddings=position_embeddings)
|
||||
|
||||
if kwargs.get("skip_norm_elementwise", False):
|
||||
sequence_output = F.layer_norm(hidden_states, hidden_states.shape[-1:])
|
||||
else:
|
||||
norm = self.norm.to(hidden_states.device)
|
||||
sequence_output = norm(hidden_states)
|
||||
pooled_output = sequence_output[:, 0, :]
|
||||
return sequence_output, None, pooled_output, None
|
||||
@ -239,6 +239,16 @@ class Flux2(LatentFormat):
|
||||
def process_out(self, latent):
|
||||
return latent
|
||||
|
||||
class TripoSplat(LatentFormat):
|
||||
# Sequence latent (B, 8192, 16) the camera token rides alongside as a second nested latent
|
||||
latent_channels = 16
|
||||
|
||||
def process_in(self, latent):
|
||||
return latent
|
||||
|
||||
def process_out(self, latent):
|
||||
return latent
|
||||
|
||||
class Mochi(LatentFormat):
|
||||
latent_channels = 12
|
||||
latent_dimensions = 3
|
||||
|
||||
@ -38,6 +38,8 @@ class ChromaRadianceParams(ChromaParams):
|
||||
# None means use the same dtype as the model.
|
||||
nerf_embedder_dtype: Optional[torch.dtype]
|
||||
use_x0: bool
|
||||
# Use sequential txt_ids instead of zeros
|
||||
use_sequential_txt_ids: bool
|
||||
|
||||
class ChromaRadiance(Chroma):
|
||||
"""
|
||||
@ -162,6 +164,9 @@ class ChromaRadiance(Chroma):
|
||||
if params.use_x0:
|
||||
self.register_buffer("__x0__", torch.tensor([]))
|
||||
|
||||
if params.use_sequential_txt_ids:
|
||||
self.register_buffer("__sequential__", torch.tensor([]))
|
||||
|
||||
@property
|
||||
def _nerf_final_layer(self) -> nn.Module:
|
||||
if self.params.nerf_final_head_type == "linear":
|
||||
@ -313,6 +318,9 @@ class ChromaRadiance(Chroma):
|
||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
# Radiance after 2026-05-22 uses sequential txt_ids instead of zeros
|
||||
if params.use_sequential_txt_ids:
|
||||
txt_ids[:, :, 0] = torch.arange(context.shape[1], device=x.device, dtype=x.dtype).unsqueeze(0).expand(bs, -1)
|
||||
|
||||
img_out = self.forward_orig(
|
||||
img,
|
||||
|
||||
199
comfy/ldm/triposplat/gaussian.py
Normal file
199
comfy/ldm/triposplat/gaussian.py
Normal file
@ -0,0 +1,199 @@
|
||||
# TripoSplat 3D gaussian container. Operates on already-decoded
|
||||
# tensors and exposes them as render-ready tensors (render_tensors) for the generic SPLAT type.
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.model_management
|
||||
|
||||
|
||||
class GaussianModel:
|
||||
def __init__(self, aabb: list, sh_degree: int = 0, mininum_kernel_size: float = 0.0,
|
||||
scaling_bias: float = 0.01, opacity_bias: float = 0.1,
|
||||
scaling_activation: str = "exp", device=None):
|
||||
self.sh_degree = sh_degree
|
||||
self.mininum_kernel_size = mininum_kernel_size
|
||||
self.scaling_bias = scaling_bias
|
||||
self.opacity_bias = opacity_bias
|
||||
self.device = device
|
||||
self.aabb = torch.tensor(aabb, dtype=torch.float32, device=device)
|
||||
|
||||
if scaling_activation == "exp":
|
||||
self._scaling_activation = torch.exp
|
||||
self._inverse_scaling_activation = torch.log
|
||||
elif scaling_activation == "softplus":
|
||||
self._scaling_activation = F.softplus
|
||||
self._inverse_scaling_activation = lambda x: x + torch.log(-torch.expm1(-x))
|
||||
|
||||
self._opacity_activation = torch.sigmoid
|
||||
self._inverse_opacity_activation = lambda x: torch.log(x / (1 - x))
|
||||
|
||||
self.scale_bias = self._inverse_scaling_activation(torch.tensor(self.scaling_bias)).to(self.device)
|
||||
self.rots_bias = torch.zeros(4, device=self.device)
|
||||
self.rots_bias[0] = 1
|
||||
self.opacity_bias_val = self._inverse_opacity_activation(torch.tensor(self.opacity_bias)).to(self.device)
|
||||
|
||||
self._storage = {}
|
||||
|
||||
def _get_store(self, name):
|
||||
return self._storage.get(name)
|
||||
|
||||
def _set_store(self, name, value):
|
||||
self._storage[name] = value
|
||||
|
||||
@property
|
||||
def _xyz(self):
|
||||
return self._get_store("_xyz")
|
||||
@_xyz.setter
|
||||
def _xyz(self, value):
|
||||
if value is None:
|
||||
self._set_store("_xyz", None)
|
||||
self._set_store("xyz", None)
|
||||
return
|
||||
self._set_store("_xyz", value)
|
||||
self._set_store("xyz", value * self.aabb[None, 3:] + self.aabb[None, :3])
|
||||
|
||||
@property
|
||||
def get_xyz(self):
|
||||
return self._get_store("xyz")
|
||||
|
||||
@property
|
||||
def _features_dc(self):
|
||||
return self._get_store("_features_dc")
|
||||
@_features_dc.setter
|
||||
def _features_dc(self, value):
|
||||
self._set_store("_features_dc", value)
|
||||
|
||||
@property
|
||||
def _opacity(self):
|
||||
return self._get_store("_opacity")
|
||||
@_opacity.setter
|
||||
def _opacity(self, value):
|
||||
if value is None:
|
||||
self._set_store("_opacity", None)
|
||||
self._set_store("opacity", None)
|
||||
return
|
||||
self._set_store("_opacity", value)
|
||||
self._set_store("opacity", self._opacity_activation(value + self.opacity_bias_val))
|
||||
|
||||
@property
|
||||
def get_opacity(self):
|
||||
return self._get_store("opacity")
|
||||
|
||||
@property
|
||||
def _scaling(self):
|
||||
return self._get_store("_scaling")
|
||||
@_scaling.setter
|
||||
def _scaling(self, value):
|
||||
if value is None:
|
||||
self._set_store("_scaling", None)
|
||||
self._set_store("scaling", None)
|
||||
return
|
||||
self._set_store("_scaling", value)
|
||||
s = self._scaling_activation(value + self.scale_bias)
|
||||
s = torch.square(s) + self.mininum_kernel_size ** 2
|
||||
self._set_store("scaling", torch.sqrt(s))
|
||||
|
||||
@property
|
||||
def get_scaling(self):
|
||||
return self._get_store("scaling")
|
||||
|
||||
@property
|
||||
def _rotation(self):
|
||||
return self._get_store("_rotation")
|
||||
@_rotation.setter
|
||||
def _rotation(self, value):
|
||||
self._set_store("_rotation", value)
|
||||
|
||||
_DEFAULT_TRANSFORM = [[1, 0, 0], [0, 0, -1], [0, 1, 0]]
|
||||
|
||||
def render_tensors(self):
|
||||
# Render-ready (activated, world-space) tensors for the generic SPLAT type. The axis transform
|
||||
# (a 3x3 rotation, object frame -> viewer Y-up) is baked into positions and rotations.
|
||||
# Returns float tensors on the intermediate device: positions (N,3), scales (N,3) linear,
|
||||
# rotations (N,4) wxyz, opacities (N,1) in [0,1], sh (N,K,3) coefficients.
|
||||
xyz = self.get_xyz.float()
|
||||
scaling = self.get_scaling.float()
|
||||
opacity = self.get_opacity.float()
|
||||
rotation = (self._rotation + self.rots_bias[None, :]).float()
|
||||
sh = self._features_dc.float() # (N, K, 3)
|
||||
T = torch.as_tensor(self._DEFAULT_TRANSFORM, dtype=torch.float32, device=xyz.device)
|
||||
xyz = xyz @ T.T
|
||||
rotation = _matrix_to_quat(torch.matmul(T, _quat_to_matrix(rotation)))
|
||||
rotation = rotation / torch.linalg.norm(rotation, dim=-1, keepdim=True)
|
||||
out_device = comfy.model_management.intermediate_device()
|
||||
return (
|
||||
xyz.to(out_device).contiguous(), scaling.to(out_device).contiguous(),
|
||||
rotation.to(out_device).contiguous(), opacity.to(out_device).contiguous(),
|
||||
sh.to(out_device).contiguous(),
|
||||
)
|
||||
|
||||
|
||||
def _quat_to_matrix(q):
|
||||
q = q / torch.linalg.norm(q, dim=-1, keepdim=True)
|
||||
w, x, y, z = q[:, 0], q[:, 1], q[:, 2], q[:, 3]
|
||||
R = torch.stack([
|
||||
1 - 2*(y*y + z*z), 2*(x*y - w*z), 2*(x*z + w*y),
|
||||
2*(x*y + w*z), 1 - 2*(x*x + z*z), 2*(y*z - w*x),
|
||||
2*(x*z - w*y), 2*(y*z + w*x), 1 - 2*(x*x + y*y),
|
||||
], dim=-1).reshape(-1, 3, 3)
|
||||
return R
|
||||
|
||||
|
||||
def _matrix_to_quat(R):
|
||||
trace = R[:, 0, 0] + R[:, 1, 1] + R[:, 2, 2]
|
||||
q = torch.zeros((R.shape[0], 4), dtype=R.dtype, device=R.device)
|
||||
s = torch.sqrt(torch.clamp(trace + 1, min=0)) * 2
|
||||
q[:, 0] = 0.25 * s
|
||||
denom = torch.where(s != 0, s, torch.ones_like(s))
|
||||
q[:, 1] = (R[:, 2, 1] - R[:, 1, 2]) / denom
|
||||
q[:, 2] = (R[:, 0, 2] - R[:, 2, 0]) / denom
|
||||
q[:, 3] = (R[:, 1, 0] - R[:, 0, 1]) / denom
|
||||
m01 = (R[:, 0, 0] >= R[:, 1, 1]) & (R[:, 0, 0] >= R[:, 2, 2]) & (s == 0)
|
||||
s1 = torch.sqrt(torch.clamp(1 + R[:, 0, 0] - R[:, 1, 1] - R[:, 2, 2], min=0)) * 2
|
||||
q[m01, 0] = (R[m01, 2, 1] - R[m01, 1, 2]) / s1[m01]
|
||||
q[m01, 1] = 0.25 * s1[m01]
|
||||
q[m01, 2] = (R[m01, 0, 1] + R[m01, 1, 0]) / s1[m01]
|
||||
q[m01, 3] = (R[m01, 0, 2] + R[m01, 2, 0]) / s1[m01]
|
||||
m11 = (R[:, 1, 1] > R[:, 0, 0]) & (R[:, 1, 1] >= R[:, 2, 2]) & (s == 0)
|
||||
s2 = torch.sqrt(torch.clamp(1 + R[:, 1, 1] - R[:, 0, 0] - R[:, 2, 2], min=0)) * 2
|
||||
q[m11, 0] = (R[m11, 0, 2] - R[m11, 2, 0]) / s2[m11]
|
||||
q[m11, 1] = (R[m11, 0, 1] + R[m11, 1, 0]) / s2[m11]
|
||||
q[m11, 2] = 0.25 * s2[m11]
|
||||
q[m11, 3] = (R[m11, 1, 2] + R[m11, 2, 1]) / s2[m11]
|
||||
m21 = (R[:, 2, 2] > R[:, 0, 0]) & (R[:, 2, 2] > R[:, 1, 1]) & (s == 0)
|
||||
s3 = torch.sqrt(torch.clamp(1 + R[:, 2, 2] - R[:, 0, 0] - R[:, 1, 1], min=0)) * 2
|
||||
q[m21, 0] = (R[m21, 1, 0] - R[m21, 0, 1]) / s3[m21]
|
||||
q[m21, 1] = (R[m21, 0, 2] + R[m21, 2, 0]) / s3[m21]
|
||||
q[m21, 2] = (R[m21, 1, 2] + R[m21, 2, 1]) / s3[m21]
|
||||
q[m21, 3] = 0.25 * s3[m21]
|
||||
return q / torch.linalg.norm(q, dim=-1, keepdim=True)
|
||||
|
||||
|
||||
def build_gaussian_models(decoder, points_pred: dict, pred: dict):
|
||||
# Assemble GaussianModels from the elastic decoder layout. decoder is the ElasticGaussianFixedlenDecoder
|
||||
# (carries layout / rep_config / _get_offset)
|
||||
x = points_pred
|
||||
offset = decoder._get_offset(pred['features'])
|
||||
h = pred["features"]
|
||||
ret = []
|
||||
for i in range(h.shape[0]):
|
||||
g = GaussianModel(
|
||||
sh_degree=0,
|
||||
aabb=[-0.5, -0.5, -0.5, 1.0, 1.0, 1.0],
|
||||
mininum_kernel_size=decoder.rep_config['filter_kernel_size_3d'],
|
||||
scaling_bias=decoder.rep_config['scaling_bias'],
|
||||
opacity_bias=decoder.rep_config['opacity_bias'],
|
||||
scaling_activation=decoder.rep_config['scaling_activation'],
|
||||
device=h.device,
|
||||
)
|
||||
_x = x["points"][i, :, None, :]
|
||||
for k, v in decoder.layout.items():
|
||||
if k == '_xyz':
|
||||
setattr(g, k, (offset[i] + _x).flatten(0, 1))
|
||||
elif k in ('_xyz_center', '_offset_scale'):
|
||||
continue
|
||||
else:
|
||||
feats = h[i][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape']).flatten(0, 1)
|
||||
setattr(g, k, feats * decoder.rep_config['lr'][k])
|
||||
ret.append(g)
|
||||
return ret
|
||||
326
comfy/ldm/triposplat/model.py
Normal file
326
comfy/ldm/triposplat/model.py
Normal file
@ -0,0 +1,326 @@
|
||||
# TripoSplat flow-matching denoiser (LatentSeqMMFlowModel). Registered as a ModelType.FLOW arch and
|
||||
# driven by the standard KSampler; jointly denoises the (B, 8192, 16) latent and a (B, 1, 5) camera token
|
||||
# carried as a 2-element nested latent.
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
import comfy.rmsnorm
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from comfy.ldm.flux.math import apply_rope
|
||||
|
||||
|
||||
class MultiHeadRMSNorm(nn.Module):
|
||||
def __init__(self, dim, heads, dtype=None, device=None):
|
||||
super().__init__()
|
||||
self.gamma = nn.Parameter(torch.empty(heads, dim, dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x):
|
||||
x = comfy.rmsnorm.rms_norm(x)
|
||||
return x * comfy.model_management.cast_to(self.gamma, x.dtype, x.device)
|
||||
|
||||
|
||||
# Positional embeddings
|
||||
|
||||
class RePo3DRotaryEmbedding(nn.Module):
|
||||
def __init__(self, model_channels, num_heads, head_dim, repo_hidden_ratio=0.125, max_freq=16.0,
|
||||
dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
repo_hidden_size = int(model_channels * repo_hidden_ratio)
|
||||
self.norm = operations.LayerNorm(model_channels, dtype=dtype, device=device)
|
||||
self.gate_map = operations.Linear(model_channels, repo_hidden_size, bias=False, dtype=dtype, device=device)
|
||||
self.content_map = operations.Linear(model_channels, repo_hidden_size, bias=False, dtype=dtype, device=device)
|
||||
self.act = nn.SiLU()
|
||||
self.final_map = operations.Linear(repo_hidden_size, 3 * num_heads, bias=False, dtype=dtype, device=device)
|
||||
self.dim_0 = 2 * (head_dim // 6)
|
||||
self.dim_1 = 2 * (head_dim // 6)
|
||||
self.dim_2 = head_dim - self.dim_0 - self.dim_1
|
||||
dims = [self.dim_0, self.dim_1, self.dim_2]
|
||||
freqs_list = []
|
||||
for d in dims:
|
||||
freq_dim = d // 2
|
||||
freqs_list.append(torch.linspace(1.0, float(max_freq), steps=freq_dim, dtype=torch.float32))
|
||||
self.freqs_0 = nn.Parameter(freqs_list[0])
|
||||
self.freqs_1 = nn.Parameter(freqs_list[1])
|
||||
self.freqs_2 = nn.Parameter(freqs_list[2])
|
||||
|
||||
def forward(self, hidden_states):
|
||||
h = self.norm(hidden_states)
|
||||
feat = self.act(self.gate_map(h)) * self.content_map(h)
|
||||
out = self.final_map(feat)
|
||||
B, L, _ = out.shape
|
||||
delta_pos = out.reshape(B, L, self.num_heads, 3)
|
||||
f0 = comfy.model_management.cast_to(self.freqs_0, torch.float32, out.device)
|
||||
f1 = comfy.model_management.cast_to(self.freqs_1, torch.float32, out.device)
|
||||
f2 = comfy.model_management.cast_to(self.freqs_2, torch.float32, out.device)
|
||||
ang_0 = delta_pos[..., 0].unsqueeze(-1) * f0 * torch.pi
|
||||
ang_1 = delta_pos[..., 1].unsqueeze(-1) * f1 * torch.pi
|
||||
ang_2 = delta_pos[..., 2].unsqueeze(-1) * f2 * torch.pi
|
||||
ang = torch.cat([ang_0, ang_1, ang_2], dim=-1).float() # (B, L, heads, head_dim/2)
|
||||
cos, sin = ang.cos(), ang.sin()
|
||||
return torch.stack([cos, -sin, sin, cos], dim=-1).reshape(*ang.shape, 2, 2)
|
||||
|
||||
|
||||
class PcdAbsolutePositionEmbedder(nn.Module):
|
||||
# Sinusoidal absolute position embedding. Two fixed schedules are used in TripoSplat:
|
||||
# "pow2" (flow-model latent anchors) and "log2" (octree / gaussian decoders).
|
||||
def __init__(self, channels: int, in_channels: int = 3, max_res: int = 16, schedule: str = "pow2"):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.in_channels = in_channels
|
||||
self.max_res = max_res
|
||||
self.schedule = schedule
|
||||
self.freq_dim = channels // in_channels // 2
|
||||
|
||||
def _freqs(self, device):
|
||||
if self.schedule == "pow2":
|
||||
freqs_2exp = torch.arange(self.max_res, dtype=torch.float32, device=device)
|
||||
res_dim = max(0, self.freq_dim - self.max_res)
|
||||
freqs_res = (torch.arange(res_dim, dtype=torch.float32, device=device) / max(res_dim, 1) * self.max_res
|
||||
if res_dim > 0 else torch.empty(0, device=device))
|
||||
freqs = torch.cat([freqs_2exp, freqs_res], dim=0)[:self.freq_dim]
|
||||
return torch.pow(2.0, freqs) * 2.0 # *2 folds this schedule's 2*pi into the shared *pi below
|
||||
logs = torch.linspace(0.0, float(self.max_res), steps=self.freq_dim, dtype=torch.float32, device=device)
|
||||
return torch.pow(2.0, logs)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
orig_dtype = x.dtype
|
||||
x = x.float()
|
||||
*dims, D = x.shape
|
||||
out = torch.outer(x.reshape(-1), self._freqs(x.device)) * torch.pi
|
||||
out = torch.cat([out.sin(), out.cos()], dim=-1).reshape(*dims, -1)
|
||||
if out.shape[-1] < self.channels:
|
||||
out = torch.cat([out, torch.zeros(*dims, self.channels - out.shape[-1],
|
||||
device=out.device, dtype=out.dtype)], dim=-1)
|
||||
return out.to(orig_dtype)
|
||||
|
||||
|
||||
def attention(q, k, v, transformer_options=None):
|
||||
# q, k, v: (B, L, heads, dim) -> (B, L, heads, dim). Shared optimized_attention call convention.
|
||||
out = optimized_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), heads=q.shape[2],
|
||||
skip_reshape=True, skip_output_reshape=True, low_precision_attention=False,
|
||||
transformer_options=transformer_options)
|
||||
return out.transpose(1, 2)
|
||||
|
||||
|
||||
# Transformer building blocks
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, in_channels, hidden_channels, out_channels, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
operations.Linear(in_channels, hidden_channels, dtype=dtype, device=device),
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(hidden_channels, out_channels, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.mlp(x)
|
||||
|
||||
|
||||
class RopeMultiHeadAttention(nn.Module):
|
||||
def __init__(self, channels, num_heads, qkv_bias=True, qk_rms_norm=False, use_rope=False,
|
||||
dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = channels // num_heads
|
||||
self.qk_rms_norm = qk_rms_norm
|
||||
self.use_rope = use_rope
|
||||
self.qkv = operations.Linear(channels, channels * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||
if self.qk_rms_norm:
|
||||
self.q_norm = MultiHeadRMSNorm(self.head_dim, num_heads, dtype=dtype, device=device)
|
||||
self.k_norm = MultiHeadRMSNorm(self.head_dim, num_heads, dtype=dtype, device=device)
|
||||
self.out = operations.Linear(channels, channels, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, rope_emb=None, transformer_options=None):
|
||||
B, L, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, L, 3, self.num_heads, self.head_dim)
|
||||
q, k, v = qkv.unbind(2)
|
||||
if self.use_rope:
|
||||
q, k = apply_rope(q, k, rope_emb)
|
||||
if self.qk_rms_norm:
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
h = attention(q, k, v, transformer_options) # (B, L, heads, dim)
|
||||
return self.out(h.reshape(B, L, C))
|
||||
|
||||
|
||||
class UnifiedTransformerBlock(nn.Module):
|
||||
def __init__(self, channels, num_heads, mlp_ratio=4.0,
|
||||
use_rope=False, qk_rms_norm=False, qkv_bias=True,
|
||||
modulation=True, share_mod=False,
|
||||
dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.modulation = modulation
|
||||
self.share_mod = share_mod
|
||||
self.norm1 = operations.LayerNorm(channels, elementwise_affine=not modulation, eps=1e-6, dtype=dtype, device=device)
|
||||
self.norm2 = operations.LayerNorm(channels, elementwise_affine=not modulation, eps=1e-6, dtype=dtype, device=device)
|
||||
self.attn = RopeMultiHeadAttention(channels, num_heads=num_heads,
|
||||
qkv_bias=qkv_bias, use_rope=use_rope, qk_rms_norm=qk_rms_norm,
|
||||
dtype=dtype, device=device, operations=operations)
|
||||
self.mlp = MLP(channels, int(channels * mlp_ratio), channels, dtype=dtype, device=device, operations=operations)
|
||||
if modulation:
|
||||
if not share_mod:
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(), operations.Linear(channels, 6 * channels, bias=True, dtype=dtype, device=device))
|
||||
self.shift_table = nn.Parameter(torch.empty(1, 6 * channels, dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x, mod=None, rotary_emb=None, transformer_options=None):
|
||||
if self.modulation:
|
||||
if not self.share_mod:
|
||||
mod = self.adaLN_modulation(mod)
|
||||
mod = mod + comfy.model_management.cast_to(self.shift_table, mod.dtype, mod.device)
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
|
||||
h = torch.addcmul(shift_msa.unsqueeze(1), self.norm1(x), 1 + scale_msa.unsqueeze(1))
|
||||
x = torch.addcmul(x, self.attn(h, rope_emb=rotary_emb, transformer_options=transformer_options), gate_msa.unsqueeze(1))
|
||||
h = torch.addcmul(shift_mlp.unsqueeze(1), self.norm2(x), 1 + scale_mlp.unsqueeze(1))
|
||||
x = torch.addcmul(x, self.mlp(h), gate_mlp.unsqueeze(1))
|
||||
else:
|
||||
x = x + self.attn(self.norm1(x), rope_emb=rotary_emb, transformer_options=transformer_options)
|
||||
x = x + self.mlp(self.norm2(x))
|
||||
return x
|
||||
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t, dim, max_period=10000):
|
||||
half = dim // 2
|
||||
freqs = torch.exp(-np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
|
||||
|
||||
def forward(self, t):
|
||||
emb = self.timestep_embedding(t, self.frequency_embedding_size)
|
||||
return self.mlp(emb.to(self.mlp[0].weight.dtype))
|
||||
|
||||
|
||||
class LatentSeqMMFlowModel(nn.Module):
|
||||
def __init__(self, image_model=None, q_token_length=8192, in_channels=16, model_channels=1024,
|
||||
cond_channels=1280, out_channels=16, num_blocks=24, num_refiner_blocks=2,
|
||||
num_heads=None, num_head_channels=64, cam_channels=5, cond2_channels=128,
|
||||
mlp_ratio=4, share_mod=True, qk_rms_norm=True,
|
||||
dtype=None, device=None, operations=None, **kwargs):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
self.q_token_length = q_token_length
|
||||
self.in_channels = in_channels
|
||||
self.cam_channels = cam_channels
|
||||
self.model_channels = model_channels
|
||||
self.cond_channels = cond_channels
|
||||
self.cond2_channels = cond2_channels
|
||||
self.out_channels = out_channels
|
||||
self.num_blocks = num_blocks
|
||||
self.num_refiner_blocks = num_refiner_blocks
|
||||
self.num_heads = num_heads or model_channels // num_head_channels
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.share_mod = share_mod
|
||||
self.qk_rms_norm = qk_rms_norm
|
||||
|
||||
factory_kwargs = dict(dtype=dtype, device=device)
|
||||
op_kwargs = dict(operations=operations, **factory_kwargs)
|
||||
|
||||
self.t_embedder = TimestepEmbedder(model_channels, **op_kwargs)
|
||||
if share_mod:
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(model_channels, 6 * model_channels, bias=True, **factory_kwargs))
|
||||
|
||||
self.input_layer = operations.Linear(in_channels, model_channels, **factory_kwargs)
|
||||
self.cond_embedder = operations.Linear(cond_channels, model_channels, **factory_kwargs)
|
||||
self.cond_embedder2 = operations.Linear(cond2_channels, model_channels, **factory_kwargs) if cond2_channels is not None else None
|
||||
|
||||
# Fixed Sobol (low-discrepancy) 3D anchor positions for the latent tokens, used as positional encoding.
|
||||
# The embedder is parameter-free and the anchors are fixed, precompute once.
|
||||
sobol_seq = torch.quasirandom.SobolEngine(dimension=3, scramble=True, seed=123).draw(q_token_length)
|
||||
pos_emb = PcdAbsolutePositionEmbedder(model_channels)(sobol_seq.unsqueeze(0))
|
||||
self.register_buffer("pos_emb", pos_emb, persistent=False)
|
||||
|
||||
# RePo3DRotaryEmbedding layers for the refiner and main blocks
|
||||
repo_kwargs = dict(num_heads=self.num_heads, head_dim=num_head_channels, **op_kwargs)
|
||||
self.noise_repo_layers = nn.ModuleList(
|
||||
[RePo3DRotaryEmbedding(model_channels, **repo_kwargs) for _ in range(num_refiner_blocks)])
|
||||
self.context_repo_layers = nn.ModuleList(
|
||||
[RePo3DRotaryEmbedding(model_channels, **repo_kwargs) for _ in range(num_refiner_blocks)])
|
||||
self.repo_layers = nn.ModuleList(
|
||||
[RePo3DRotaryEmbedding(model_channels, **repo_kwargs) for _ in range(num_blocks)])
|
||||
|
||||
# Refiner blocks
|
||||
block_kwargs = dict(num_heads=self.num_heads, mlp_ratio=self.mlp_ratio, use_rope=True, qk_rms_norm=self.qk_rms_norm, **op_kwargs)
|
||||
self.noise_refiner = nn.ModuleList(
|
||||
[UnifiedTransformerBlock(model_channels, modulation=True, share_mod=self.share_mod, **block_kwargs) for _ in range(num_refiner_blocks)])
|
||||
self.context_refiner = nn.ModuleList(
|
||||
[UnifiedTransformerBlock(model_channels, modulation=False, **block_kwargs) for _ in range(num_refiner_blocks)])
|
||||
|
||||
self.cam_refiner = MLP(self.cam_channels, model_channels, model_channels, **op_kwargs)
|
||||
|
||||
self.blocks = nn.ModuleList(
|
||||
[UnifiedTransformerBlock(model_channels, modulation=True, share_mod=self.share_mod, **block_kwargs) for _ in range(num_blocks)])
|
||||
|
||||
self.shift_table = nn.Parameter(torch.empty(1, 2, model_channels, **factory_kwargs))
|
||||
self.out_layer = operations.Linear(model_channels, out_channels, **factory_kwargs)
|
||||
self.cam_out_layer = operations.Linear(model_channels, cam_channels, **factory_kwargs)
|
||||
|
||||
def forward(self, x, t, context=None, ref_latents=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, t, context, ref_latents, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, t, context=None, ref_latents=None, transformer_options={}, **kwargs):
|
||||
# x is the unpacked nested latent: [latent (B,8192,in_channels), camera (B,1,cam_channels)].
|
||||
# context == feature1.
|
||||
z, camera = x[0], x[1]
|
||||
feat1 = context
|
||||
|
||||
h_x = self.input_layer(z)
|
||||
h_cond = self.cond_embedder(feat1)
|
||||
if ref_latents is not None and self.cond_embedder2 is not None:
|
||||
# Flatten the Flux2 VAE latent (B,128,h,w) to a token sequence and front-pad to feat1's length
|
||||
# (the pad count = feat1's prefix tokens: DINOv3 cls + registers), then add to the context.
|
||||
feat2 = ref_latents[0].flatten(2).transpose(1, 2)
|
||||
feat2 = F.pad(feat2, (0, 0, feat1.shape[1] - feat2.shape[1], 0))
|
||||
h_cond = h_cond + self.cond_embedder2(feat2.to(h_cond.dtype))
|
||||
t_emb = self.t_embedder(t)
|
||||
t_mod = self.adaLN_modulation(t_emb) if self.share_mod else t_emb
|
||||
|
||||
h_x = h_x + self.pos_emb.to(z)
|
||||
|
||||
for i, block in enumerate(self.noise_refiner):
|
||||
h_x = block(h_x, mod=t_mod, rotary_emb=self.noise_repo_layers[i](h_x), transformer_options=transformer_options)
|
||||
|
||||
for i, block in enumerate(self.context_refiner):
|
||||
h_cond = block(h_cond, mod=None, rotary_emb=self.context_repo_layers[i](h_cond), transformer_options=transformer_options)
|
||||
|
||||
cam = camera.to(z)
|
||||
h_cam = self.cam_refiner(cam)
|
||||
h = torch.cat([h_x, h_cond, h_cam], dim=1)
|
||||
|
||||
for i, block in enumerate(self.blocks):
|
||||
h = block(h, mod=t_mod, rotary_emb=self.repo_layers[i](h), transformer_options=transformer_options)
|
||||
|
||||
h_x = F.layer_norm(h[:, :z.shape[1]].float(), h.shape[-1:]).to(z)
|
||||
h_cam = F.layer_norm(h[:, -cam.shape[1]:].float(), h.shape[-1:]).to(z)
|
||||
|
||||
shift, scale = (comfy.model_management.cast_to(self.shift_table, t_emb.dtype, t_emb.device) + t_emb.unsqueeze(1)).chunk(2, dim=1)
|
||||
scale = 1 + scale
|
||||
h_x = torch.addcmul(shift, h_x, scale)
|
||||
h_cam = torch.addcmul(shift, h_cam, scale)
|
||||
|
||||
return self.out_layer(h_x), self.cam_out_layer(h_cam)
|
||||
91
comfy/ldm/triposplat/preview.py
Normal file
91
comfy/ldm/triposplat/preview.py
Normal file
@ -0,0 +1,91 @@
|
||||
# Live preview for TripoSplat: decode an x0 estimate into a coarse gaussian splat and render it with a perspective orbit camera.
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
_C0 = 0.28209479177387814
|
||||
_LATENT_TOKENS = 8192 # q_token_length
|
||||
_LATENT_CH = 16 # in_channels
|
||||
_OBJECT_TO_VIEWER = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]], np.float32) # object frame -> viewer Y-up frame
|
||||
|
||||
|
||||
def _view_matrix(yaw_deg, pitch_deg):
|
||||
y, p = np.radians(yaw_deg), np.radians(pitch_deg)
|
||||
Ry = np.array([[np.cos(y), 0, np.sin(y)], [0, 1, 0], [-np.sin(y), 0, np.cos(y)]], np.float32)
|
||||
Rx = np.array([[1, 0, 0], [0, np.cos(p), -np.sin(p)], [0, np.sin(p), np.cos(p)]], np.float32)
|
||||
return Rx @ Ry
|
||||
|
||||
|
||||
def render_splat(xyz, rgb, scale, opacity=None, yaw=35.0, pitch=30.0, size=320, min_px=2, gain=1.0,
|
||||
max_px=9, min_opacity=0.0, fov=35.0, dist=2.2):
|
||||
# Project gaussian centers with a perspective camera and paint each as a filled disk whose screen
|
||||
# radius follows the gaussian's world-space scale, composited with a nearest-wins z-buffer.
|
||||
# gain scales the footprint (≈ std spanned), `min_px`/`max_px` clamp the on-screen radius.
|
||||
|
||||
pts = xyz.astype(np.float32) @ _OBJECT_TO_VIEWER.T
|
||||
v = pts @ _view_matrix(yaw, pitch).T
|
||||
zc = v[:, 2] + dist
|
||||
keep = zc > 1e-2
|
||||
if opacity is not None and min_opacity > 0.0: # culls gaussians with very low opacity
|
||||
keep = keep & (opacity > min_opacity)
|
||||
v, zc, scale = v[keep], zc[keep], scale[keep]
|
||||
col = (np.clip(rgb, 0, 1)[:, :3] * 255).astype(np.uint8)[keep]
|
||||
if v.shape[0] == 0:
|
||||
return Image.fromarray(np.zeros((size, size, 3), np.uint8))
|
||||
f = (size / 2) / np.tan(np.radians(fov) / 2)
|
||||
cx = size / 2 + f * v[:, 0] / zc
|
||||
cy = size / 2 + f * v[:, 1] / zc
|
||||
radius = np.clip(np.round(f * scale / zc * gain), min_px, max_px).astype(np.int32)
|
||||
|
||||
# Expand each splat to its disk pixels, bucketed by integer radius so it stays vectorized.
|
||||
px, py, pz, pc = [], [], [], []
|
||||
for r in range(int(radius.min()), int(radius.max()) + 1):
|
||||
m = radius == r
|
||||
if not m.any():
|
||||
continue
|
||||
dy, dx = np.mgrid[-r:r + 1, -r:r + 1]
|
||||
disk = (dx * dx + dy * dy) <= r * r
|
||||
ox, oy = dx[disk], dy[disk]
|
||||
px.append((cx[m, None] + ox).ravel())
|
||||
py.append((cy[m, None] + oy).ravel())
|
||||
pz.append(np.repeat(zc[m], ox.size))
|
||||
pc.append(np.repeat(col[m], ox.size, axis=0))
|
||||
px, py = np.concatenate(px), np.concatenate(py)
|
||||
pz, pc = np.concatenate(pz), np.concatenate(pc)
|
||||
xi = np.clip(px, 0, size - 1).astype(np.int64)
|
||||
yi = np.clip(py, 0, size - 1).astype(np.int64)
|
||||
|
||||
# Nearest-wins z-buffer: pack (quantized depth, source index), per-pixel min picks the closest
|
||||
# splat, then decode the winning index back to its color.
|
||||
pid = yi * size + xi
|
||||
q = np.clip((pz * 1024.0).astype(np.int64), 0, (1 << 20) - 1) # near = small
|
||||
key = (q << 32) | np.arange(pid.size, dtype=np.int64)
|
||||
buf = np.full(size * size, 1 << 62, np.int64)
|
||||
np.minimum.at(buf, pid, key)
|
||||
img = np.zeros((size * size, 3), np.uint8)
|
||||
hit = buf < (1 << 62)
|
||||
img[hit] = pc[buf[hit] & 0xFFFFFFFF]
|
||||
return Image.fromarray(img.reshape(size, size, 3))
|
||||
|
||||
|
||||
def _extract_latent(x0):
|
||||
# x0 from the sampler callback is the nested latent packed to (B, 1, TOKENS*CH + 1*5);
|
||||
# the plain single-latent case is (B, TOKENS, CH). Return the (B, TOKENS, CH) latent stream.
|
||||
if x0.ndim == 3 and x0.shape[1] == _LATENT_TOKENS and x0.shape[2] == _LATENT_CH:
|
||||
return x0
|
||||
flat = x0.reshape(x0.shape[0], -1)
|
||||
return flat[:, :_LATENT_TOKENS * _LATENT_CH].reshape(x0.shape[0], _LATENT_TOKENS, _LATENT_CH)
|
||||
|
||||
|
||||
def decode_x0_to_image(decoder, x0, cfg):
|
||||
# Decode x0 at a coarse octree level / few gaussians and render a preview image.
|
||||
latent = _extract_latent(x0)
|
||||
fsm = decoder.first_stage_model
|
||||
gaussian = fsm.decode(latent.to(decoder.device, decoder.vae_dtype),
|
||||
num_gaussians=cfg.get("gaussians", 16384), level=cfg.get("level", 5))[0]
|
||||
xyz = gaussian.get_xyz.float().cpu().numpy()
|
||||
rgb = gaussian._features_dc.float().cpu().numpy()[:, 0, :] * _C0 + 0.5
|
||||
scale = gaussian.get_scaling.float().cpu().numpy().max(axis=1) # per-splat world radius (largest axis)
|
||||
opacity = gaussian.get_opacity.float().cpu().numpy()[:, 0]
|
||||
return render_splat(xyz, rgb, scale, opacity=opacity, yaw=cfg.get("yaw", 35.0), pitch=cfg.get("pitch", 30.0),
|
||||
size=cfg.get("size", 320), min_px=1, gain=1.0, max_px=cfg.get("point_size", 3),
|
||||
min_opacity=0.01)
|
||||
382
comfy/ldm/triposplat/vae.py
Normal file
382
comfy/ldm/triposplat/vae.py
Normal file
@ -0,0 +1,382 @@
|
||||
# TripoSplat gaussian decoder ("VAE"): an octree probability decoder picks point coords, then an
|
||||
# elastic-gaussian decoder predicts per-point gaussian params. OctreeGaussianDecoder.decode() returns
|
||||
# a Gaussian. The octree sampler uses the global torch RNG (no generator) like upstream, so seed it for repeatable decodes.
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.ops
|
||||
from .gaussian import build_gaussian_models
|
||||
from .model import MultiHeadRMSNorm, MLP, PcdAbsolutePositionEmbedder, attention
|
||||
|
||||
|
||||
# Quasi-random sampling utilities (pure functions, dtype/device-agnostic)
|
||||
|
||||
PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53]
|
||||
|
||||
|
||||
def radical_inverse(base, n):
|
||||
val = 0
|
||||
inv_base = 1.0 / base
|
||||
inv_base_n = inv_base
|
||||
while n > 0:
|
||||
digit = n % base
|
||||
val += digit * inv_base_n
|
||||
n //= base
|
||||
inv_base_n *= inv_base
|
||||
return val
|
||||
|
||||
|
||||
def halton_sequence(dim, n):
|
||||
return [radical_inverse(PRIMES[i], n) for i in range(dim)]
|
||||
|
||||
|
||||
def hammersley_sequence(dim, n, num_samples):
|
||||
return [n / num_samples] + halton_sequence(dim - 1, n)
|
||||
|
||||
|
||||
def sample_probs(probs, counts, generator=None):
|
||||
# Systematic resampling: distribute counts[r] draws across the P bins of row r
|
||||
batch_shape = counts.shape
|
||||
R = counts.numel()
|
||||
P = probs.size(-1)
|
||||
device = probs.device
|
||||
probs = probs.reshape(R, P).to(torch.float32).clamp_min(0)
|
||||
counts = counts.reshape(R).to(device=device, dtype=torch.long)
|
||||
|
||||
row_sums = probs.sum(1, keepdim=True)
|
||||
probs = torch.where(row_sums == 0, probs.new_tensor(1.0 / P), probs / row_sums.clamp_min(1))
|
||||
cdf = probs.cumsum(dim=1).clamp(max=1.0 - 1e-12)
|
||||
|
||||
Nmax = int(counts.max())
|
||||
if Nmax == 0:
|
||||
return counts.new_zeros(*batch_shape, P)
|
||||
cnt = counts.clamp_min(1).float().unsqueeze(1) # (R, 1)
|
||||
grid = torch.arange(Nmax, device=device, dtype=torch.float32).unsqueeze(0) # (1, Nmax)
|
||||
u = (torch.rand(R, 1, generator=generator).to(device) + grid) / cnt # (R, Nmax) systematic samples (CPU-seeded)
|
||||
idx = torch.searchsorted(cdf, u.clamp(max=1.0 - 1e-12)).clamp_max(P - 1)
|
||||
weight = (grid < counts.unsqueeze(1)).to(cdf.dtype) # mask out j >= counts[r]
|
||||
out = torch.zeros(R, P, dtype=torch.float32, device=device)
|
||||
out.scatter_add_(1, idx, weight)
|
||||
return out.to(torch.long).view(*batch_shape, P)
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
def __init__(self, channels, num_heads, ctx_channels=None, type="self", qkv_bias=True, qk_rms_norm=False,
|
||||
dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
assert channels % num_heads == 0
|
||||
self.channels = channels
|
||||
self.head_dim = channels // num_heads
|
||||
self.ctx_channels = ctx_channels if ctx_channels is not None else channels
|
||||
self.num_heads = num_heads
|
||||
self._type = type
|
||||
self.qk_rms_norm = qk_rms_norm
|
||||
if self._type == "self":
|
||||
self.to_qkv = operations.Linear(channels, channels * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||
else:
|
||||
self.to_q = operations.Linear(channels, channels, bias=qkv_bias, dtype=dtype, device=device)
|
||||
self.to_kv = operations.Linear(self.ctx_channels, channels * 2, bias=qkv_bias, dtype=dtype, device=device)
|
||||
if self.qk_rms_norm:
|
||||
self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, dtype=dtype, device=device)
|
||||
self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, dtype=dtype, device=device)
|
||||
self.to_out = operations.Linear(channels, channels, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, context=None):
|
||||
B, L, C = x.shape
|
||||
if self._type == "self":
|
||||
q, k, v = self.to_qkv(x).reshape(B, L, 3, self.num_heads, -1).unbind(dim=2)
|
||||
else:
|
||||
Lkv = context.shape[1]
|
||||
q = self.to_q(x).reshape(B, L, self.num_heads, -1)
|
||||
k, v = self.to_kv(context).reshape(B, Lkv, 2, self.num_heads, -1).unbind(dim=2)
|
||||
if self.qk_rms_norm:
|
||||
q = self.q_rms_norm(q)
|
||||
k = self.k_rms_norm(k)
|
||||
h = attention(q, k, v)
|
||||
return self.to_out(h.reshape(B, L, -1))
|
||||
|
||||
|
||||
# Octree probability decoder
|
||||
|
||||
class LevelEmbedder(nn.Module):
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256, max_period=1024,
|
||||
dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
self.max_period = max_period
|
||||
|
||||
@staticmethod
|
||||
def level_embedding(t, dim, max_period=1024):
|
||||
half = dim // 2
|
||||
freqs = torch.exp(-np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
|
||||
args = t[:, None].float() * freqs[None] * 2 * torch.pi
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
|
||||
|
||||
def forward(self, t):
|
||||
emb = self.level_embedding(t, self.frequency_embedding_size, self.max_period)
|
||||
return self.mlp(emb.to(self.mlp[0].weight.dtype))
|
||||
|
||||
|
||||
class ModulatedTransformerCrossOnlyBlock(nn.Module):
|
||||
def __init__(self, channels, ctx_channels, num_heads, mlp_ratio=4.0, share_mod=False,
|
||||
qk_rms_norm_cross=True, qkv_bias=True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.share_mod = share_mod
|
||||
self.norm1 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.norm2 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.cross_attn = MultiHeadAttention(channels, ctx_channels=ctx_channels, num_heads=num_heads,
|
||||
type="cross", qkv_bias=qkv_bias,
|
||||
qk_rms_norm=qk_rms_norm_cross, dtype=dtype, device=device, operations=operations)
|
||||
self.mlp = MLP(channels, int(channels * mlp_ratio), channels, dtype=dtype, device=device, operations=operations)
|
||||
if not share_mod:
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(), operations.Linear(channels, 6 * channels, bias=True, dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x, mod, context):
|
||||
if self.share_mod:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
|
||||
else:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
|
||||
h = torch.addcmul(shift_msa.unsqueeze(1), self.norm1(x), 1 + scale_msa.unsqueeze(1))
|
||||
x = torch.addcmul(x, self.cross_attn(h, context), gate_msa.unsqueeze(1))
|
||||
h = torch.addcmul(shift_mlp.unsqueeze(1), self.norm2(x), 1 + scale_mlp.unsqueeze(1))
|
||||
x = torch.addcmul(x, self.mlp(h), gate_mlp.unsqueeze(1))
|
||||
return x
|
||||
|
||||
|
||||
class OctreeProbabilityFixedlenDecoder(nn.Module):
|
||||
# Cross-attention transformer over octree coords -> per-node 8-way child occupancy logits.
|
||||
def __init__(self, model_channels=1024, cond_channels=16, num_blocks=4, num_heads=16,
|
||||
num_head_channels=64, mlp_ratio=4.0, share_mod=True,
|
||||
qk_rms_norm_cross=True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.model_channels = model_channels
|
||||
self.cond_channels = cond_channels
|
||||
self.num_blocks = num_blocks
|
||||
self.num_heads = num_heads or model_channels // num_head_channels
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.share_mod = share_mod
|
||||
self.qk_rms_norm_cross = qk_rms_norm_cross
|
||||
self.input_layer = operations.Linear(model_channels, model_channels, dtype=dtype, device=device)
|
||||
self.l_embedder = LevelEmbedder(model_channels, dtype=dtype, device=device, operations=operations)
|
||||
if share_mod:
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(), operations.Linear(model_channels, 6 * model_channels, bias=True, dtype=dtype, device=device))
|
||||
if cond_channels is not None:
|
||||
self.blocks = nn.ModuleList([
|
||||
ModulatedTransformerCrossOnlyBlock(
|
||||
model_channels, ctx_channels=cond_channels, num_heads=self.num_heads,
|
||||
mlp_ratio=self.mlp_ratio, qk_rms_norm_cross=self.qk_rms_norm_cross,
|
||||
share_mod=self.share_mod, dtype=dtype, device=device, operations=operations)
|
||||
for _ in range(num_blocks)
|
||||
])
|
||||
self.out_proj = operations.Linear(model_channels, 8, dtype=dtype, device=device)
|
||||
self.in_proj = operations.Linear(3, model_channels, dtype=dtype, device=device)
|
||||
self.pos_embedder = PcdAbsolutePositionEmbedder(channels=model_channels, in_channels=3, max_res=10, schedule="log2")
|
||||
|
||||
def forward(self, x, l, cond):
|
||||
d = next(self.parameters()).dtype
|
||||
B, L, _ = x.shape
|
||||
h = self.in_proj(x.to(d)) + self.pos_embedder(x.reshape(-1, 3)).reshape(B, L, -1).to(d)
|
||||
h = self.input_layer(h)
|
||||
l_emb = self.l_embedder(l)
|
||||
if self.share_mod:
|
||||
l_emb = self.adaLN_modulation(l_emb)
|
||||
cond = cond.to(d)
|
||||
for block in self.blocks:
|
||||
h = block(h, l_emb, cond)
|
||||
h = F.layer_norm(h.float(), h.shape[-1:]).to(d)
|
||||
logits = self.out_proj(h)
|
||||
return {"logits": logits, "probs": torch.softmax(logits, dim=-1)}
|
||||
|
||||
@staticmethod
|
||||
def sample(model, cond, num_points, level, temperature=1.0, generator=None):
|
||||
B = cond.shape[0]
|
||||
device = cond.device
|
||||
child_offset = torch.tensor([[i, j, k] for k in [0, 1] for j in [0, 1] for i in [0, 1]],
|
||||
dtype=torch.long, device=device)
|
||||
prev_coords_int = torch.zeros(B, 1, 3, dtype=torch.long, device=device)
|
||||
prev_counts = torch.full((B, 1), num_points, dtype=torch.long, device=device)
|
||||
prev_log_probs = torch.zeros(B, 1, dtype=torch.float32, device=device)
|
||||
batch_indices_range = torch.arange(B, device=device).unsqueeze(1)
|
||||
|
||||
for lv in range(1, level + 1):
|
||||
res_p = 1 << (lv - 1)
|
||||
res = 1 << lv
|
||||
parent_coords_norm = (prev_coords_int.to(torch.float32) + 0.5) / res_p
|
||||
res_tensor = torch.full((B,), res, dtype=torch.long, device=device)
|
||||
pred_logits = model(parent_coords_norm, res_tensor, cond)["logits"] / temperature
|
||||
pred_probs = torch.softmax(pred_logits, dim=-1)
|
||||
pred_log_probs = torch.log_softmax(pred_logits, dim=-1)
|
||||
sampled = sample_probs(pred_probs, prev_counts, generator=generator).flatten(1, 2)
|
||||
pred_log_probs = pred_log_probs.flatten(1, 2)
|
||||
prev_log_probs_expanded = prev_log_probs.repeat_interleave(8, dim=1)
|
||||
child_coords_int = (prev_coords_int[:, :, None, :] * 2 + child_offset[None, None, :, :]).flatten(1, 2)
|
||||
mask = sampled > 0
|
||||
max_valid = mask.sum(dim=1).max().item()
|
||||
scatter_indices = mask.cumsum(dim=1) - 1
|
||||
valid_scatter_indices = scatter_indices[mask]
|
||||
valid_batch_indices = batch_indices_range.expand_as(mask)[mask]
|
||||
next_prev_coords_int = torch.zeros(B, max_valid, 3, dtype=child_coords_int.dtype, device=device)
|
||||
next_prev_coords_int[valid_batch_indices, valid_scatter_indices] = child_coords_int[mask]
|
||||
next_prev_counts = torch.zeros(B, max_valid, dtype=sampled.dtype, device=device)
|
||||
next_prev_counts[valid_batch_indices, valid_scatter_indices] = sampled[mask]
|
||||
next_prev_log_probs = torch.zeros(B, max_valid, dtype=prev_log_probs.dtype, device=device)
|
||||
next_prev_log_probs[valid_batch_indices, valid_scatter_indices] = (prev_log_probs_expanded + pred_log_probs)[mask]
|
||||
prev_coords_int = next_prev_coords_int
|
||||
prev_counts = next_prev_counts
|
||||
prev_log_probs = next_prev_log_probs
|
||||
|
||||
res = 1 << level
|
||||
prev_log_probs = torch.repeat_interleave(prev_log_probs.flatten(0, 1), prev_counts.flatten(0, 1), dim=0).reshape(B, num_points)
|
||||
coords_int = torch.repeat_interleave(prev_coords_int.flatten(0, 1), prev_counts.flatten(0, 1), dim=0).reshape(B, num_points, -1)
|
||||
rand = torch.rand(coords_int.shape, dtype=torch.float32, generator=generator).to(device)
|
||||
coords_norm = (coords_int.to(torch.float32) + rand) / res
|
||||
return {"points": coords_norm, "log_probs": prev_log_probs}
|
||||
|
||||
|
||||
# Elastic gaussian decoder
|
||||
|
||||
class TransformerCrossBlock(nn.Module):
|
||||
def __init__(self, channels, ctx_channels, num_heads, mlp_ratio=4.0,
|
||||
qk_rms_norm=True, qk_rms_norm_cross=True, qkv_bias=True,
|
||||
dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm1 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.norm2 = operations.LayerNorm(channels, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)
|
||||
self.norm3 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.self_attn = MultiHeadAttention(channels, num_heads=num_heads, type="self", qkv_bias=qkv_bias,
|
||||
qk_rms_norm=qk_rms_norm, dtype=dtype, device=device, operations=operations)
|
||||
self.cross_attn = MultiHeadAttention(channels, ctx_channels=ctx_channels, num_heads=num_heads, type="cross",
|
||||
qkv_bias=qkv_bias, qk_rms_norm=qk_rms_norm_cross, dtype=dtype, device=device, operations=operations)
|
||||
self.mlp = MLP(channels, int(channels * mlp_ratio), channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, x, context):
|
||||
x = x + self.self_attn(self.norm1(x))
|
||||
x = x + self.cross_attn(self.norm2(x), context)
|
||||
x = x + self.mlp(self.norm3(x))
|
||||
return x
|
||||
|
||||
|
||||
class ElasticGaussianFixedlenDecoder(nn.Module):
|
||||
# Cross-attention transformer over sampled octree points -> per-point gaussian params.
|
||||
def __init__(self, in_channels=3, model_channels=1024, cond_channels=16, num_blocks=16, num_heads=16,
|
||||
num_head_channels=64, mlp_ratio=4.0, *, representation_config=None,
|
||||
qk_rms_norm=True, qk_rms_norm_cross=True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.rep_config = representation_config or dict(
|
||||
lr=dict(_xyz=1.0, _features_dc=1.0, _opacity=1.0, _scaling=1.0, _rotation=0.1),
|
||||
perturb_offset=True, perturbe_size=1.5, offset_scale=0.05, num_gaussians=32,
|
||||
filter_kernel_size_3d=0.0009, scaling_bias=0.004, opacity_bias=0.1,
|
||||
scaling_activation="softplus",
|
||||
)
|
||||
self.out_channels = self._calc_layout()
|
||||
self.model_channels = model_channels
|
||||
self.cond_channels = cond_channels
|
||||
self.num_blocks = num_blocks
|
||||
self.num_heads = num_heads or model_channels // num_head_channels
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.input_layer = operations.Linear(model_channels, model_channels, dtype=dtype, device=device)
|
||||
if cond_channels is not None:
|
||||
self.blocks = nn.ModuleList([
|
||||
TransformerCrossBlock(model_channels, ctx_channels=cond_channels,
|
||||
num_heads=self.num_heads, mlp_ratio=self.mlp_ratio,
|
||||
qk_rms_norm=qk_rms_norm, qk_rms_norm_cross=qk_rms_norm_cross,
|
||||
dtype=dtype, device=device, operations=operations)
|
||||
for _ in range(num_blocks)
|
||||
])
|
||||
self.in_proj = operations.Linear(in_channels, model_channels, dtype=dtype, device=device)
|
||||
self.pos_embedder = PcdAbsolutePositionEmbedder(channels=model_channels, in_channels=3, max_res=10, schedule="log2")
|
||||
self.out_proj = operations.Linear(model_channels, self.out_channels, dtype=dtype, device=device)
|
||||
self._build_perturbation()
|
||||
|
||||
def _calc_layout(self):
|
||||
ng = self.rep_config['num_gaussians']
|
||||
self.layout = {
|
||||
'_xyz': {'shape': (ng, 3), 'size': ng * 3},
|
||||
'_features_dc': {'shape': (ng, 1, 3), 'size': ng * 3},
|
||||
'_scaling': {'shape': (ng, 3), 'size': ng * 3},
|
||||
'_rotation': {'shape': (ng, 4), 'size': ng * 4},
|
||||
'_opacity': {'shape': (ng, 1), 'size': ng},
|
||||
}
|
||||
self.layout['_offset_scale'] = {'shape': (ng, 1), 'size': ng}
|
||||
start = 0
|
||||
for k, v in self.layout.items():
|
||||
v['range'] = (start, start + v['size'])
|
||||
start += v['size']
|
||||
return start
|
||||
|
||||
def _build_perturbation(self):
|
||||
ng = self.rep_config['num_gaussians']
|
||||
perturbation = torch.tensor([hammersley_sequence(3, i, ng) for i in range(ng)]).float()
|
||||
perturbation = torch.atanh((perturbation * 2 - 1) / self.rep_config['perturbe_size'])
|
||||
self.register_buffer('points_offset_perturbation', perturbation)
|
||||
base = torch.tensor(self.rep_config['offset_scale'])
|
||||
self.register_buffer('base_offset_scale', torch.log(torch.exp(base) - 1.0))
|
||||
|
||||
def _get_offset(self, h):
|
||||
B = h.shape[0]
|
||||
r = self.layout['_offset_scale']['range']
|
||||
_offset_scale = F.softplus(
|
||||
h[:, :, r[0]:r[1]].reshape(B, -1, *self.layout['_offset_scale']['shape'])
|
||||
+ comfy.model_management.cast_to(self.base_offset_scale, h.dtype, h.device))
|
||||
|
||||
r = self.layout['_xyz']['range']
|
||||
offset = h[:, :, r[0]:r[1]].reshape(B, -1, *self.layout['_xyz']['shape'])
|
||||
offset = offset * self.rep_config['lr']['_xyz']
|
||||
if self.rep_config['perturb_offset']:
|
||||
offset = offset + comfy.model_management.cast_to(self.points_offset_perturbation, offset.dtype, offset.device)
|
||||
offset = torch.tanh(offset) * 0.5 * self.rep_config['perturbe_size']
|
||||
offset = offset * _offset_scale
|
||||
return offset
|
||||
|
||||
def forward(self, x=None, cond=None):
|
||||
pcd = x["points"]
|
||||
d = next(self.parameters()).dtype
|
||||
B, L, _ = pcd.shape
|
||||
h = self.in_proj(pcd.to(d)) + self.pos_embedder(pcd.reshape(-1, 3)).reshape(B, L, -1).to(d)
|
||||
h = self.input_layer(h)
|
||||
cond = cond.to(d)
|
||||
for block in self.blocks:
|
||||
h = block(h, cond)
|
||||
h = F.layer_norm(h.float(), h.shape[-1:]).to(h.dtype)
|
||||
return {"features": self.out_proj(h)}
|
||||
|
||||
|
||||
# Combined octree gaussian decoder (comfy first-stage model)
|
||||
|
||||
class OctreeGaussianDecoder(nn.Module):
|
||||
_MAX_VOXEL_LEVEL = 8
|
||||
|
||||
def __init__(self, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
if operations is None:
|
||||
operations = comfy.ops.disable_weight_init
|
||||
self.octree = OctreeProbabilityFixedlenDecoder(dtype=dtype, device=device, operations=operations)
|
||||
self.gs = ElasticGaussianFixedlenDecoder(dtype=dtype, device=device, operations=operations)
|
||||
|
||||
@property
|
||||
def gaussians_per_point(self) -> int:
|
||||
return self.gs.rep_config['num_gaussians']
|
||||
|
||||
def decode(self, latent: torch.Tensor, num_gaussians: int, level: int = None, generator=None):
|
||||
# level defaults to the full octree depth, a lower level is cheaper (coarser) for live previews.
|
||||
# generator (a CPU torch.Generator) makes the octree sampling reproducible without touching global RNG.
|
||||
level = self._MAX_VOXEL_LEVEL if level is None else level
|
||||
num_decoder_tokens = max(1, num_gaussians // self.gaussians_per_point)
|
||||
points_pred = OctreeProbabilityFixedlenDecoder.sample(
|
||||
self.octree, latent, num_points=num_decoder_tokens, level=level, temperature=1.0, generator=generator,
|
||||
)
|
||||
pred = self.gs(x=points_pred, cond=latent)
|
||||
return build_gaussian_models(self.gs, points_pred, pred) # one GaussianModel per batch item
|
||||
@ -47,6 +47,7 @@ import comfy.ldm.wan.model_animate
|
||||
import comfy.ldm.wan.ar_model
|
||||
import comfy.ldm.wan.model_wandancer
|
||||
import comfy.ldm.hunyuan3d.model
|
||||
import comfy.ldm.triposplat.model
|
||||
import comfy.ldm.hidream.model
|
||||
import comfy.ldm.chroma.model
|
||||
import comfy.ldm.chroma_radiance.model
|
||||
@ -1812,6 +1813,24 @@ class Hunyuan3Dv2_1(BaseModel):
|
||||
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
|
||||
return out
|
||||
|
||||
class TripoSplat(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.triposplat.model.LatentSeqMMFlowModel)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
cross_attn = kwargs.get("cross_attn", None) # DINOv3 token sequence -> cross-attention context.
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
ref_latents = kwargs.get("reference_latents", None) # Flux2 VAE image latent -> additive second conditioning.
|
||||
if ref_latents is not None:
|
||||
out['ref_latents'] = comfy.conds.CONDList(list(ref_latents))
|
||||
latent_shapes = kwargs.get("latent_shapes", None) # {latent, camera} nested latent
|
||||
if latent_shapes is not None:
|
||||
out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes)
|
||||
return out
|
||||
|
||||
|
||||
class HiDream(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hidream.model.HiDreamImageTransformer2DModel)
|
||||
|
||||
@ -355,6 +355,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["use_x0"] = True
|
||||
else:
|
||||
dit_config["use_x0"] = False
|
||||
if "{}__sequential__".format(key_prefix) in state_dict_keys: # sequential txt_ids
|
||||
dit_config["use_sequential_txt_ids"] = True
|
||||
else:
|
||||
dit_config["use_sequential_txt_ids"] = False
|
||||
else:
|
||||
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
|
||||
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
|
||||
@ -718,6 +722,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["guidance_cond_proj_dim"] = None#f"{key_prefix}t_embedder.cond_proj.weight" in state_dict_keys
|
||||
return dit_config
|
||||
|
||||
if '{}cam_out_layer.weight'.format(key_prefix) in state_dict_keys and '{}repo_layers.0.final_map.weight'.format(key_prefix) in state_dict_keys: # TripoSplat
|
||||
return {"image_model": "triposplat"}
|
||||
|
||||
if '{}t_embedder1.mlp.0.weight'.format(key_prefix) in state_dict_keys and '{}x_embedder.proj1.weight'.format(key_prefix) in state_dict_keys: # HiDream-O1
|
||||
return {"image_model": "hidream_o1"}
|
||||
|
||||
|
||||
11
comfy/sd.py
11
comfy/sd.py
@ -16,6 +16,7 @@ import comfy.ldm.cosmos.vae
|
||||
import comfy.ldm.wan.vae
|
||||
import comfy.ldm.wan.vae2_2
|
||||
import comfy.ldm.hunyuan3d.vae
|
||||
import comfy.ldm.triposplat.vae
|
||||
import comfy.ldm.ace.vae.music_dcae_pipeline
|
||||
import comfy.ldm.cogvideo.vae
|
||||
import comfy.ldm.hunyuan_video.vae
|
||||
@ -908,6 +909,16 @@ class VAE:
|
||||
#Force cast it for --disable-dynamic-vram users until there is a true core fix.
|
||||
if not comfy.memory_management.aimdo_enabled:
|
||||
self.disable_offload = True
|
||||
elif "gs.base_offset_scale" in sd and "octree.out_proj.weight" in sd: # TripoSplat octree gaussian decoder
|
||||
self.first_stage_model = comfy.ldm.triposplat.vae.OctreeGaussianDecoder()
|
||||
self.latent_channels = 16
|
||||
self.latent_dim = 1
|
||||
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
# The generic VAE.encode/decode path isn't used: VAEDecodeTripoSplat calls the gaussian
|
||||
# decoder directly (structured GaussianSplat objects, not a tensor and reserves VRAM itself from num_gaussians.
|
||||
def _no_generic_io(*args, **kwargs):
|
||||
raise RuntimeError("TripoSplat gaussian decoder: use the 'TripoSplat Decode' (VAEDecodeTripoSplat)")
|
||||
self.memory_used_encode = self.memory_used_decode = _no_generic_io
|
||||
else:
|
||||
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
|
||||
self.first_stage_model = None
|
||||
|
||||
@ -1547,6 +1547,30 @@ class Hunyuan3Dv2mini(Hunyuan3Dv2):
|
||||
|
||||
latent_format = latent_formats.Hunyuan3Dv2mini
|
||||
|
||||
class TripoSplat(supported_models_base.BASE):
|
||||
# Image -> 3D gaussian splat flow denoiser
|
||||
unet_config = {
|
||||
"image_model": "triposplat",
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 3.0,
|
||||
}
|
||||
|
||||
memory_usage_factor = 0.6
|
||||
|
||||
latent_format = latent_formats.TripoSplat
|
||||
|
||||
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.TripoSplat(self, device=device)
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return None
|
||||
|
||||
class HiDream(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "hidream",
|
||||
@ -2210,6 +2234,7 @@ models = [
|
||||
Hunyuan3Dv2mini,
|
||||
Hunyuan3Dv2,
|
||||
Hunyuan3Dv2_1,
|
||||
TripoSplat,
|
||||
HiDream,
|
||||
HiDreamO1,
|
||||
Chroma,
|
||||
|
||||
@ -5,7 +5,7 @@ from comfy_api.internal.singleton import ProxiedSingleton
|
||||
from comfy_api.internal.async_to_sync import create_sync_class
|
||||
from ._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
|
||||
from ._input_impl import VideoFromFile, VideoFromComponents
|
||||
from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL, File3D
|
||||
from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL, SPLAT, File3D
|
||||
from . import _io_public as io
|
||||
from . import _ui_public as ui
|
||||
from comfy_execution.utils import get_executing_context
|
||||
@ -143,6 +143,7 @@ class Types:
|
||||
VideoComponents = VideoComponents
|
||||
MESH = MESH
|
||||
VOXEL = VOXEL
|
||||
SPLAT = SPLAT
|
||||
File3D = File3D
|
||||
|
||||
|
||||
|
||||
@ -65,6 +65,12 @@ class VideoInput(ABC):
|
||||
buffer.seek(0)
|
||||
return buffer
|
||||
|
||||
def get_active_trim_window(self) -> tuple[float, float]:
|
||||
"""Return the active trim as ``(start_time, duration)`` in seconds (start_time normalized
|
||||
to ``>= 0``; ``duration == 0`` means "until the end"). Default: no trim; trimmable subclasses override.
|
||||
"""
|
||||
return 0.0, 0.0
|
||||
|
||||
# Provide a default implementation, but subclasses can provide optimized versions
|
||||
# if possible.
|
||||
def get_dimensions(self) -> tuple[int, int]:
|
||||
|
||||
@ -75,6 +75,12 @@ class VideoFromFile(VideoInput):
|
||||
self.__file.seek(0)
|
||||
return self.__file
|
||||
|
||||
def get_active_trim_window(self) -> tuple[float, float]:
|
||||
start_time = self.__start_time
|
||||
if start_time < 0:
|
||||
start_time = max(self._get_raw_duration() + start_time, 0.0)
|
||||
return float(start_time), float(self.__duration)
|
||||
|
||||
def get_dimensions(self) -> tuple[int, int]:
|
||||
"""
|
||||
Returns the dimensions of the video input.
|
||||
|
||||
@ -28,7 +28,7 @@ if TYPE_CHECKING:
|
||||
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
|
||||
prune_dict, shallow_clone_class)
|
||||
from comfy_execution.graph_utils import ExecutionBlocker
|
||||
from ._util import MESH, VOXEL, SVG as _SVG, File3D
|
||||
from ._util import MESH, VOXEL, SPLAT, SVG as _SVG, File3D
|
||||
|
||||
|
||||
class FolderType(str, Enum):
|
||||
@ -684,6 +684,10 @@ class Voxel(ComfyTypeIO):
|
||||
class Mesh(ComfyTypeIO):
|
||||
Type = MESH
|
||||
|
||||
@comfytype(io_type="SPLAT")
|
||||
class Splat(ComfyTypeIO):
|
||||
Type = SPLAT
|
||||
|
||||
|
||||
@comfytype(io_type="FILE_3D")
|
||||
class File3DAny(ComfyTypeIO):
|
||||
@ -2320,6 +2324,7 @@ __all__ = [
|
||||
"LossMap",
|
||||
"Voxel",
|
||||
"Mesh",
|
||||
"Splat",
|
||||
"File3DAny",
|
||||
"File3DGLB",
|
||||
"File3DGLTF",
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from .video_types import VideoContainer, VideoCodec, VideoComponents
|
||||
from .geometry_types import VOXEL, MESH, File3D
|
||||
from .geometry_types import VOXEL, MESH, SPLAT, File3D
|
||||
from .image_types import SVG
|
||||
|
||||
__all__ = [
|
||||
@ -9,6 +9,7 @@ __all__ = [
|
||||
"VideoComponents",
|
||||
"VOXEL",
|
||||
"MESH",
|
||||
"SPLAT",
|
||||
"File3D",
|
||||
"SVG",
|
||||
]
|
||||
|
||||
@ -11,13 +11,32 @@ class VOXEL:
|
||||
self.data = data
|
||||
|
||||
|
||||
class SPLAT:
|
||||
"""A batch of 3D Gaussian splats in render-ready (activated, world-space) form.
|
||||
|
||||
Tensors are (B, N, ...) and zero-padded to a common N across the batch; `counts` (B,) holds the
|
||||
real per-item lengths (None when rows are uniform and no slicing is needed). SH coefficients are
|
||||
stored as (B, N, K, 3) with K = (sh_degree + 1)**2; the DC (diffuse) term is sh[..., 0, :].
|
||||
"""
|
||||
|
||||
def __init__(self, positions: torch.Tensor, scales: torch.Tensor, rotations: torch.Tensor,
|
||||
opacities: torch.Tensor, sh: torch.Tensor, counts: torch.Tensor | None = None):
|
||||
self.positions = positions # (B, N, 3) world-space centers
|
||||
self.scales = scales # (B, N, 3) linear (positive) per-axis std
|
||||
self.rotations = rotations # (B, N, 4) quaternion wxyz (normalized)
|
||||
self.opacities = opacities # (B, N, 1) in [0, 1]
|
||||
self.sh = sh # (B, N, K, 3) spherical-harmonic color coefficients
|
||||
self.counts = counts # (B,) real lengths, or None
|
||||
|
||||
|
||||
class MESH:
|
||||
def __init__(self, vertices: torch.Tensor, faces: torch.Tensor,
|
||||
uvs: torch.Tensor | None = None,
|
||||
vertex_colors: torch.Tensor | None = None,
|
||||
texture: torch.Tensor | None = None,
|
||||
vertex_counts: torch.Tensor | None = None,
|
||||
face_counts: torch.Tensor | None = None):
|
||||
face_counts: torch.Tensor | None = None,
|
||||
unlit: bool = False):
|
||||
|
||||
assert (vertex_counts is None) == (face_counts is None), \
|
||||
"vertex_counts and face_counts must be provided together (both or neither)"
|
||||
@ -30,6 +49,8 @@ class MESH:
|
||||
# these hold the real per-item lengths (B,). None means rows are uniform and no slicing is needed.
|
||||
self.vertex_counts = vertex_counts
|
||||
self.face_counts = face_counts
|
||||
# Render flat / emissive (no scene lighting) when saved, e.g. for gaussian-splat-derived meshes.
|
||||
self.unlit = unlit
|
||||
|
||||
|
||||
class File3D:
|
||||
|
||||
@ -1,71 +1,71 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, confloat, conint
|
||||
|
||||
|
||||
class BFLOutputFormat(str, Enum):
|
||||
png = 'png'
|
||||
jpeg = 'jpeg'
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class BFLFluxExpandImageRequest(BaseModel):
|
||||
prompt: str = Field(..., description='The description of the changes you want to make. This text guides the expansion process, allowing you to specify features, styles, or modifications for the expanded areas.')
|
||||
prompt_upsampling: Optional[bool] = Field(
|
||||
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
|
||||
)
|
||||
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
|
||||
top: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the top of the image')
|
||||
bottom: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the bottom of the image')
|
||||
left: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the left side of the image')
|
||||
right: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the right side of the image')
|
||||
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
|
||||
guidance: confloat(ge=1.5, le=100) = Field(..., description='Guidance strength for the image generation process')
|
||||
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
|
||||
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
|
||||
)
|
||||
output_format: Optional[BFLOutputFormat] = Field(
|
||||
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
|
||||
)
|
||||
image: str = Field(None, description='A Base64-encoded string representing the image you wish to expand')
|
||||
prompt: str = Field(...)
|
||||
prompt_upsampling: bool | None = Field(None)
|
||||
seed: int | None = Field(None)
|
||||
top: int = Field(...)
|
||||
bottom: int = Field(...)
|
||||
left: int = Field(...)
|
||||
right: int = Field(...)
|
||||
steps: int = Field(...)
|
||||
guidance: float = Field(...)
|
||||
safety_tolerance: int = Field(6)
|
||||
output_format: str = Field("png")
|
||||
image: str = Field(None, description="A Base64-encoded string representing the image you wish to expand")
|
||||
|
||||
|
||||
class BFLFluxFillImageRequest(BaseModel):
|
||||
prompt: str = Field(..., description='The description of the changes you want to make. This text guides the expansion process, allowing you to specify features, styles, or modifications for the expanded areas.')
|
||||
prompt_upsampling: Optional[bool] = Field(
|
||||
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
|
||||
prompt: str = Field(...)
|
||||
prompt_upsampling: bool | None = Field(None)
|
||||
seed: int | None = Field(None)
|
||||
steps: int = Field(...)
|
||||
guidance: float = Field(...)
|
||||
safety_tolerance: int = Field(6)
|
||||
output_format: str = Field("png")
|
||||
image: str = Field(
|
||||
None, description="Base64-encoded string representing the image to modify. Can contain alpha mask if desired.",
|
||||
)
|
||||
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
|
||||
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
|
||||
guidance: confloat(ge=1.5, le=100) = Field(..., description='Guidance strength for the image generation process')
|
||||
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
|
||||
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
|
||||
mask: str = Field(
|
||||
None, description="Base64-encoded string representing the mask of the areas you wish to modify."
|
||||
)
|
||||
output_format: Optional[BFLOutputFormat] = Field(
|
||||
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
|
||||
|
||||
|
||||
class BFLFluxEraseRequest(BaseModel):
|
||||
image: str = Field(..., description="A Base64-encoded string representing the image to erase from.")
|
||||
mask: str = Field(
|
||||
...,
|
||||
description="A Base64-encoded black/white mask matching the input dimensions; "
|
||||
"white (255) marks areas to remove, black (0) marks areas to preserve.",
|
||||
)
|
||||
image: str = Field(None, description='A Base64-encoded string representing the image you wish to modify. Can contain alpha mask if desired.')
|
||||
mask: str = Field(None, description='A Base64-encoded string representing the mask of the areas you with to modify.')
|
||||
dilate_pixels: int = Field(10)
|
||||
output_format: str = Field("png")
|
||||
|
||||
|
||||
class BFLFluxVTORequest(BaseModel):
|
||||
prompt: str = Field(
|
||||
..., description="Natural-language styling instruction. Required field, but may be an empty string."
|
||||
)
|
||||
person: str = Field(..., description="A Base64-encoded string representing the person image.")
|
||||
garment: str = Field(..., description="A Base64-encoded string representing the garment reference image.")
|
||||
seed: int | None = Field(None)
|
||||
safety_tolerance: int = Field(5)
|
||||
output_format: str = Field("png")
|
||||
|
||||
|
||||
class BFLFluxProGenerateRequest(BaseModel):
|
||||
prompt: str = Field(..., description='The text prompt for image generation.')
|
||||
prompt_upsampling: Optional[bool] = Field(
|
||||
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
|
||||
)
|
||||
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
|
||||
width: conint(ge=256, le=1440) = Field(1024, description='Width of the generated image in pixels. Must be a multiple of 32.')
|
||||
height: conint(ge=256, le=1440) = Field(768, description='Height of the generated image in pixels. Must be a multiple of 32.')
|
||||
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
|
||||
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
|
||||
)
|
||||
output_format: Optional[BFLOutputFormat] = Field(
|
||||
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
|
||||
)
|
||||
image_prompt: Optional[str] = Field(None, description='Optional image to remix in base64 format')
|
||||
# image_prompt_strength: Optional[confloat(ge=0.0, le=1.0)] = Field(
|
||||
# None, description='Blend between the prompt and the image prompt.'
|
||||
# )
|
||||
prompt: str = Field(...)
|
||||
prompt_upsampling: bool | None = Field(None)
|
||||
seed: int | None = Field(None)
|
||||
width: int = Field(1024, description="Must be a multiple of 32.")
|
||||
height: int = Field(768, description="Must be a multiple of 32.")
|
||||
safety_tolerance: int = Field(6)
|
||||
output_format: str = Field("png")
|
||||
image_prompt: str | None = Field(None, description="Optional image to remix in base64 format")
|
||||
|
||||
|
||||
class Flux2ProGenerateRequest(BaseModel):
|
||||
@ -83,55 +83,37 @@ class Flux2ProGenerateRequest(BaseModel):
|
||||
input_image_7: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
|
||||
input_image_8: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
|
||||
input_image_9: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
|
||||
safety_tolerance: int | None = Field(
|
||||
5, description="Tolerance level for input and output moderation. Value 0 being most strict.", ge=0, le=5
|
||||
)
|
||||
output_format: str | None = Field(
|
||||
"png", description="Output format for the generated image. Can be 'jpeg' or 'png'."
|
||||
)
|
||||
safety_tolerance: int = Field(5)
|
||||
output_format: str = Field("png")
|
||||
|
||||
|
||||
class BFLFluxKontextProGenerateRequest(BaseModel):
|
||||
prompt: str = Field(..., description='The text prompt for what you wannt to edit.')
|
||||
input_image: Optional[str] = Field(None, description='Image to edit in base64 format')
|
||||
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
|
||||
guidance: confloat(ge=0.1, le=99.0) = Field(..., description='Guidance strength for the image generation process')
|
||||
steps: conint(ge=1, le=150) = Field(..., description='Number of steps for the image generation process')
|
||||
safety_tolerance: Optional[conint(ge=0, le=2)] = Field(
|
||||
2, description='Tolerance level for input and output moderation. Between 0 and 2, 0 being most strict, 6 being least strict. Defaults to 2.'
|
||||
)
|
||||
output_format: Optional[BFLOutputFormat] = Field(
|
||||
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
|
||||
)
|
||||
aspect_ratio: Optional[str] = Field(None, description='Aspect ratio of the image between 21:9 and 9:21.')
|
||||
prompt_upsampling: Optional[bool] = Field(
|
||||
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
|
||||
)
|
||||
prompt: str = Field(...)
|
||||
input_image: str | None = Field(None, description="Image to edit in base64 format")
|
||||
seed: int | None = Field(None)
|
||||
guidance: float = Field(...)
|
||||
steps: int = Field(...)
|
||||
safety_tolerance: int = Field(2)
|
||||
output_format: str = Field("png")
|
||||
aspect_ratio: str | None = Field(None)
|
||||
prompt_upsampling: bool | None = Field(None)
|
||||
|
||||
|
||||
class BFLFluxProUltraGenerateRequest(BaseModel):
|
||||
prompt: str = Field(..., description='The text prompt for image generation.')
|
||||
prompt_upsampling: Optional[bool] = Field(
|
||||
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
|
||||
)
|
||||
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
|
||||
aspect_ratio: Optional[str] = Field(None, description='Aspect ratio of the image between 21:9 and 9:21.')
|
||||
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
|
||||
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
|
||||
)
|
||||
output_format: Optional[BFLOutputFormat] = Field(
|
||||
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
|
||||
)
|
||||
raw: Optional[bool] = Field(None, description='Generate less processed, more natural-looking images.')
|
||||
image_prompt: Optional[str] = Field(None, description='Optional image to remix in base64 format')
|
||||
image_prompt_strength: Optional[confloat(ge=0.0, le=1.0)] = Field(
|
||||
None, description='Blend between the prompt and the image prompt.'
|
||||
)
|
||||
prompt: str = Field(...)
|
||||
prompt_upsampling: bool | None = Field(None)
|
||||
seed: int | None = Field(None)
|
||||
aspect_ratio: str | None = Field(None)
|
||||
safety_tolerance: int = Field(6)
|
||||
output_format: str = Field("png")
|
||||
raw: bool | None = Field(None)
|
||||
image_prompt: str | None = Field(None, description="Optional image to remix in base64 format")
|
||||
image_prompt_strength: float | None = Field(None)
|
||||
|
||||
|
||||
class BFLFluxProGenerateResponse(BaseModel):
|
||||
id: str = Field(..., description="The unique identifier for the generation task.")
|
||||
polling_url: str = Field(..., description="URL to poll for the generation result.")
|
||||
id: str = Field(...)
|
||||
polling_url: str = Field(...)
|
||||
cost: float | None = Field(None, description="Price in cents")
|
||||
|
||||
|
||||
@ -145,7 +127,7 @@ class BFLStatus(str, Enum):
|
||||
|
||||
|
||||
class BFLFluxStatusResponse(BaseModel):
|
||||
id: str = Field(..., description="The unique identifier for the generation task.")
|
||||
status: BFLStatus = Field(..., description="The status of the task.")
|
||||
result: Optional[Dict[str, Any]] = Field(None, description="The result of the task (null if not completed).")
|
||||
progress: Optional[float] = Field(None, description="The progress of the task (0.0 to 1.0).", ge=0.0, le=1.0)
|
||||
id: str = Field(...)
|
||||
status: BFLStatus = Field(...)
|
||||
result: dict[str, Any] | None = Field(None)
|
||||
progress: float | None = Field(None, ge=0.0, le=1.0)
|
||||
|
||||
@ -155,7 +155,7 @@ class ClaudeNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ClaudeNode",
|
||||
display_name="Anthropic Claude",
|
||||
category="text/partner/Anthropic",
|
||||
category="partner/text/Anthropic",
|
||||
essentials_category="Text Generation",
|
||||
description="Generate text responses with Anthropic's Claude models. "
|
||||
"Provide a text prompt and optionally one or more images for multimodal context.",
|
||||
|
||||
@ -206,7 +206,7 @@ class BeebleSwitchXVideoEdit(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="BeebleSwitchXVideoEdit",
|
||||
display_name="Beeble SwitchX Video Edit",
|
||||
category="video/partner/Beeble",
|
||||
category="partner/video/Beeble",
|
||||
description=(
|
||||
"Edit a video with Beeble SwitchX. Switches anything in the scene (background, "
|
||||
"lighting, costume) while preserving the original subject's pixels and motion. "
|
||||
@ -302,7 +302,7 @@ class BeebleSwitchXImageEdit(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="BeebleSwitchXImageEdit",
|
||||
display_name="Beeble SwitchX Image Edit",
|
||||
category="image/partner/Beeble",
|
||||
category="partner/image/Beeble",
|
||||
description=(
|
||||
"Edit a single image with Beeble SwitchX. Switches anything in the scene "
|
||||
"(background, lighting, costume) while preserving the original subject's pixels. "
|
||||
|
||||
@ -4,17 +4,20 @@ from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.bfl import (
|
||||
BFLFluxEraseRequest,
|
||||
BFLFluxExpandImageRequest,
|
||||
BFLFluxFillImageRequest,
|
||||
BFLFluxKontextProGenerateRequest,
|
||||
BFLFluxProGenerateResponse,
|
||||
BFLFluxProUltraGenerateRequest,
|
||||
BFLFluxStatusResponse,
|
||||
BFLFluxVTORequest,
|
||||
BFLStatus,
|
||||
Flux2ProGenerateRequest,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
convert_mask_to_image,
|
||||
download_url_to_image_tensor,
|
||||
get_number_of_images,
|
||||
poll_op,
|
||||
@ -22,19 +25,11 @@ from comfy_api_nodes.util import (
|
||||
sync_op,
|
||||
tensor_to_base64_string,
|
||||
validate_aspect_ratio_string,
|
||||
validate_image_dimensions,
|
||||
validate_string,
|
||||
)
|
||||
|
||||
|
||||
def convert_mask_to_image(mask: Input.Image):
|
||||
"""
|
||||
Make mask have the expected amount of dims (4) and channels (3) to be recognized as an image.
|
||||
"""
|
||||
mask = mask.unsqueeze(-1)
|
||||
mask = torch.cat([mask] * 3, dim=-1)
|
||||
return mask
|
||||
|
||||
|
||||
class FluxProUltraImageNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
@ -42,7 +37,7 @@ class FluxProUltraImageNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="FluxProUltraImageNode",
|
||||
display_name="Flux 1.1 [pro] Ultra Image",
|
||||
category="image/partner/BFL",
|
||||
category="partner/image/BFL",
|
||||
description="Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -160,7 +155,7 @@ class FluxKontextProImageNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id=cls.NODE_ID,
|
||||
display_name=cls.DISPLAY_NAME,
|
||||
category="image/partner/BFL",
|
||||
category="partner/image/BFL",
|
||||
description="Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -282,7 +277,7 @@ class FluxProExpandNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="FluxProExpandNode",
|
||||
display_name="Flux.1 Expand Image",
|
||||
category="image/partner/BFL",
|
||||
category="partner/image/BFL",
|
||||
description="Outpaints image based on prompt.",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
@ -419,7 +414,7 @@ class FluxProFillNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="FluxProFillNode",
|
||||
display_name="Flux.1 Fill Image",
|
||||
category="image/partner/BFL",
|
||||
category="partner/image/BFL",
|
||||
description="Inpaints image based on mask and prompt.",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
@ -519,6 +514,163 @@ class FluxProFillNode(IO.ComfyNode):
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
|
||||
|
||||
|
||||
class FluxEraseNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="FluxEraseNode",
|
||||
display_name="Flux Erase Image",
|
||||
category="partner/image/BFL",
|
||||
description="Removes the masked object from an image and reconstructs the background. "
|
||||
"Paint the mask over what you want to erase.",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.Mask.Input("mask", tooltip="White areas are removed; black areas are preserved."),
|
||||
IO.Int.Input(
|
||||
"dilate_pixels",
|
||||
default=10,
|
||||
min=0,
|
||||
max=25,
|
||||
tooltip="Expands the mask boundaries to ensure clean coverage of the object's edges.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"range_usd","min_usd":0.03,"max_usd":0.06,"format":{"approximate":true}}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: Input.Image,
|
||||
mask: Input.Image,
|
||||
dilate_pixels: int = 10,
|
||||
) -> IO.NodeOutput:
|
||||
validate_image_dimensions(image, min_width=256, min_height=256)
|
||||
mask = resize_mask_to_image(mask, image)
|
||||
mask = tensor_to_base64_string(convert_mask_to_image(mask))
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/bfl/v1/flux-tools/erase-v1", method="POST"),
|
||||
response_model=BFLFluxProGenerateResponse,
|
||||
data=BFLFluxEraseRequest(
|
||||
image=tensor_to_base64_string(image[:, :, :, :3]), # make sure image will have alpha channel removed
|
||||
mask=mask,
|
||||
dilate_pixels=dilate_pixels,
|
||||
),
|
||||
)
|
||||
|
||||
def price_extractor(_r: BaseModel) -> float | None:
|
||||
return None if initial_response.cost is None else initial_response.cost / 100
|
||||
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(initial_response.polling_url),
|
||||
response_model=BFLFluxStatusResponse,
|
||||
status_extractor=lambda r: r.status,
|
||||
progress_extractor=lambda r: r.progress,
|
||||
price_extractor=price_extractor,
|
||||
completed_statuses=[BFLStatus.ready],
|
||||
failed_statuses=[
|
||||
BFLStatus.request_moderated,
|
||||
BFLStatus.content_moderated,
|
||||
BFLStatus.error,
|
||||
BFLStatus.task_not_found,
|
||||
],
|
||||
queued_statuses=[],
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
|
||||
|
||||
|
||||
class FluxVTONode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="FluxVTONode",
|
||||
display_name="Flux Virtual Try-On",
|
||||
category="partner/image/BFL",
|
||||
description="Virtual try-on: dresses the person in the provided garment.",
|
||||
inputs=[
|
||||
IO.Image.Input("person", tooltip="Image of the person to dress."),
|
||||
IO.Image.Input("garment", tooltip="Image of the garment to apply."),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Optional natural-language styling instruction (e.g. how the garment should fit).",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=0xFFFFFFFFFFFFFFFF,
|
||||
control_after_generate=True,
|
||||
tooltip="The random seed used for creating the noise.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"range_usd","min_usd":0.0375,"max_usd":0.075,"format":{"approximate":true}}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
person: Input.Image,
|
||||
garment: Input.Image,
|
||||
prompt: str = "",
|
||||
seed: int = 0,
|
||||
) -> IO.NodeOutput:
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/bfl/v1/flux-tools/vto-v1", method="POST"),
|
||||
response_model=BFLFluxProGenerateResponse,
|
||||
data=BFLFluxVTORequest(
|
||||
prompt=prompt,
|
||||
person=tensor_to_base64_string(person[:, :, :, :3]),
|
||||
garment=tensor_to_base64_string(garment[:, :, :, :3]),
|
||||
seed=seed,
|
||||
),
|
||||
)
|
||||
|
||||
def price_extractor(_r: BaseModel) -> float | None:
|
||||
return None if initial_response.cost is None else initial_response.cost / 100
|
||||
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(initial_response.polling_url),
|
||||
response_model=BFLFluxStatusResponse,
|
||||
status_extractor=lambda r: r.status,
|
||||
progress_extractor=lambda r: r.progress,
|
||||
price_extractor=price_extractor,
|
||||
completed_statuses=[BFLStatus.ready],
|
||||
failed_statuses=[
|
||||
BFLStatus.request_moderated,
|
||||
BFLStatus.content_moderated,
|
||||
BFLStatus.error,
|
||||
BFLStatus.task_not_found,
|
||||
],
|
||||
queued_statuses=[],
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
|
||||
|
||||
|
||||
class Flux2ProImageNode(IO.ComfyNode):
|
||||
|
||||
NODE_ID = "Flux2ProImageNode"
|
||||
@ -545,7 +697,7 @@ class Flux2ProImageNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id=cls.NODE_ID,
|
||||
display_name=cls.DISPLAY_NAME,
|
||||
category="image/partner/BFL",
|
||||
category="partner/image/BFL",
|
||||
description="Generates images synchronously based on prompt and resolution.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -716,7 +868,7 @@ class Flux2ImageNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Flux2ImageNode",
|
||||
display_name="Flux.2 Image",
|
||||
category="image/partner/BFL",
|
||||
category="partner/image/BFL",
|
||||
description="Generate images via Flux.2 [pro] or Flux.2 [max] from a prompt and optional reference images.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -853,6 +1005,8 @@ class BFLExtension(ComfyExtension):
|
||||
FluxKontextMaxImageNode,
|
||||
FluxProExpandNode,
|
||||
FluxProFillNode,
|
||||
FluxEraseNode,
|
||||
FluxVTONode,
|
||||
Flux2ProImageNode,
|
||||
Flux2MaxImageNode,
|
||||
Flux2ImageNode,
|
||||
|
||||
@ -31,7 +31,7 @@ class BriaImageEditNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="BriaImageEditNode",
|
||||
display_name="Bria FIBO Image Edit",
|
||||
category="image/partner/Bria",
|
||||
category="partner/image/Bria",
|
||||
description="Edit images using Bria latest model",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["FIBO"]),
|
||||
@ -169,7 +169,7 @@ class BriaRemoveImageBackground(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="BriaRemoveImageBackground",
|
||||
display_name="Bria Remove Image Background",
|
||||
category="image/partner/Bria",
|
||||
category="partner/image/Bria",
|
||||
description="Remove the background from an image using Bria RMBG 2.0.",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
@ -245,7 +245,7 @@ class BriaRemoveVideoBackground(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="BriaRemoveVideoBackground",
|
||||
display_name="Bria Remove Video Background",
|
||||
category="video/partner/Bria",
|
||||
category="partner/video/Bria",
|
||||
description="Remove the background from a video using Bria. ",
|
||||
inputs=[
|
||||
IO.Video.Input("video"),
|
||||
|
||||
@ -368,7 +368,7 @@ class ByteDanceImageNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ByteDanceImageNode",
|
||||
display_name="ByteDance Image",
|
||||
category="image/partner/ByteDance",
|
||||
category="partner/image/ByteDance",
|
||||
description="Generate images using ByteDance models via api based on prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["seedream-3-0-t2i-250415"]),
|
||||
@ -492,7 +492,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ByteDanceSeedreamNode",
|
||||
display_name="ByteDance Seedream 4.5 & 5.0",
|
||||
category="image/partner/ByteDance",
|
||||
category="partner/image/ByteDance",
|
||||
description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
@ -754,7 +754,7 @@ class ByteDanceSeedreamNodeV2(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ByteDanceSeedreamNodeV2",
|
||||
display_name="ByteDance Seedream 4.5 & 5.0",
|
||||
category="image/partner/ByteDance",
|
||||
category="partner/image/ByteDance",
|
||||
description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -920,7 +920,7 @@ class ByteDanceTextToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ByteDanceTextToVideoNode",
|
||||
display_name="ByteDance Text to Video",
|
||||
category="video/partner/ByteDance",
|
||||
category="partner/video/ByteDance",
|
||||
description="Generate video using ByteDance models via api based on prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
@ -1048,7 +1048,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ByteDanceImageToVideoNode",
|
||||
display_name="ByteDance Image to Video",
|
||||
category="video/partner/ByteDance",
|
||||
category="partner/video/ByteDance",
|
||||
description="Generate video using ByteDance models via api based on image and prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
@ -1185,7 +1185,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ByteDanceFirstLastFrameNode",
|
||||
display_name="ByteDance First-Last-Frame to Video",
|
||||
category="video/partner/ByteDance",
|
||||
category="partner/video/ByteDance",
|
||||
description="Generate video using prompt and first and last frames.",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
@ -1333,7 +1333,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ByteDanceImageReferenceNode",
|
||||
display_name="ByteDance Reference Images to Video",
|
||||
category="video/partner/ByteDance",
|
||||
category="partner/video/ByteDance",
|
||||
description="Generate video using prompt and reference images.",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
@ -1576,7 +1576,7 @@ class ByteDance2TextToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ByteDance2TextToVideoNode",
|
||||
display_name="ByteDance Seedance 2.0 Text to Video",
|
||||
category="video/partner/ByteDance",
|
||||
category="partner/video/ByteDance",
|
||||
description="Generate video using Seedance 2.0 models based on a text prompt.",
|
||||
inputs=[
|
||||
IO.DynamicCombo.Input(
|
||||
@ -1677,7 +1677,7 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ByteDance2FirstLastFrameNode",
|
||||
display_name="ByteDance Seedance 2.0 First-Last-Frame to Video",
|
||||
category="video/partner/ByteDance",
|
||||
category="partner/video/ByteDance",
|
||||
description="Generate video using Seedance 2.0 from a first frame image and optional last frame image.",
|
||||
inputs=[
|
||||
IO.DynamicCombo.Input(
|
||||
@ -1944,7 +1944,7 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ByteDance2ReferenceNode",
|
||||
display_name="ByteDance Seedance 2.0 Reference to Video",
|
||||
category="video/partner/ByteDance",
|
||||
category="partner/video/ByteDance",
|
||||
description="Generate, edit, or extend video using Seedance 2.0 with reference images, "
|
||||
"videos, and audio. Supports multimodal reference, video editing, and video extension.",
|
||||
inputs=[
|
||||
@ -2241,7 +2241,7 @@ class ByteDanceCreateImageAsset(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ByteDanceCreateImageAsset",
|
||||
display_name="ByteDance Create Image Asset",
|
||||
category="image/partner/ByteDance",
|
||||
category="partner/image/ByteDance",
|
||||
description=(
|
||||
"Create a Seedance 2.0 personal image asset. Uploads the input image and "
|
||||
"registers it in the given asset group. If group_id is empty, runs a real-person "
|
||||
@ -2308,7 +2308,7 @@ class ByteDanceCreateVideoAsset(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ByteDanceCreateVideoAsset",
|
||||
display_name="ByteDance Create Video Asset",
|
||||
category="video/partner/ByteDance",
|
||||
category="partner/video/ByteDance",
|
||||
description=(
|
||||
"Create a Seedance 2.0 personal video asset. Uploads the input video and "
|
||||
"registers it in the given asset group. If group_id is empty, runs a real-person "
|
||||
|
||||
@ -144,7 +144,7 @@ class ByteDanceSeedNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ByteDanceSeedNode",
|
||||
display_name="ByteDance Seed",
|
||||
category="text/partner/ByteDance",
|
||||
category="partner/text/ByteDance",
|
||||
essentials_category="Text Generation",
|
||||
description="Generate text responses with ByteDance's Seed 2.0 models. "
|
||||
"Provide a text prompt and optionally one or more images or videos for multimodal context.",
|
||||
|
||||
@ -69,7 +69,7 @@ class ElevenLabsSpeechToText(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ElevenLabsSpeechToText",
|
||||
display_name="ElevenLabs Speech to Text",
|
||||
category="audio/partner/ElevenLabs",
|
||||
category="partner/audio/ElevenLabs",
|
||||
description="Transcribe audio to text. "
|
||||
"Supports automatic language detection, speaker diarization, and audio event tagging.",
|
||||
inputs=[
|
||||
@ -210,7 +210,7 @@ class ElevenLabsVoiceSelector(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ElevenLabsVoiceSelector",
|
||||
display_name="ElevenLabs Voice Selector",
|
||||
category="audio/partner/ElevenLabs",
|
||||
category="partner/audio/ElevenLabs",
|
||||
description="Select a predefined ElevenLabs voice for text-to-speech generation.",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
@ -239,7 +239,7 @@ class ElevenLabsTextToSpeech(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ElevenLabsTextToSpeech",
|
||||
display_name="ElevenLabs Text to Speech",
|
||||
category="audio/partner/ElevenLabs",
|
||||
category="partner/audio/ElevenLabs",
|
||||
description="Convert text to speech.",
|
||||
inputs=[
|
||||
IO.Custom(ELEVENLABS_VOICE).Input(
|
||||
@ -414,7 +414,7 @@ class ElevenLabsAudioIsolation(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ElevenLabsAudioIsolation",
|
||||
display_name="ElevenLabs Voice Isolation",
|
||||
category="audio/partner/ElevenLabs",
|
||||
category="partner/audio/ElevenLabs",
|
||||
description="Remove background noise from audio, isolating vocals or speech.",
|
||||
inputs=[
|
||||
IO.Audio.Input(
|
||||
@ -459,7 +459,7 @@ class ElevenLabsTextToSoundEffects(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ElevenLabsTextToSoundEffects",
|
||||
display_name="ElevenLabs Text to Sound Effects",
|
||||
category="audio/partner/ElevenLabs",
|
||||
category="partner/audio/ElevenLabs",
|
||||
description="Generate sound effects from text descriptions.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -555,7 +555,7 @@ class ElevenLabsInstantVoiceClone(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ElevenLabsInstantVoiceClone",
|
||||
display_name="ElevenLabs Instant Voice Clone",
|
||||
category="audio/partner/ElevenLabs",
|
||||
category="partner/audio/ElevenLabs",
|
||||
description="Create a cloned voice from audio samples. "
|
||||
"Provide 1-8 audio recordings of the voice to clone.",
|
||||
inputs=[
|
||||
@ -658,7 +658,7 @@ class ElevenLabsSpeechToSpeech(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ElevenLabsSpeechToSpeech",
|
||||
display_name="ElevenLabs Speech to Speech",
|
||||
category="audio/partner/ElevenLabs",
|
||||
category="partner/audio/ElevenLabs",
|
||||
description="Transform speech from one voice to another while preserving the original content and emotion.",
|
||||
inputs=[
|
||||
IO.Custom(ELEVENLABS_VOICE).Input(
|
||||
@ -793,7 +793,7 @@ class ElevenLabsTextToDialogue(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ElevenLabsTextToDialogue",
|
||||
display_name="ElevenLabs Text to Dialogue",
|
||||
category="audio/partner/ElevenLabs",
|
||||
category="partner/audio/ElevenLabs",
|
||||
description="Generate multi-speaker dialogue from text. Each dialogue entry has its own text and voice.",
|
||||
inputs=[
|
||||
IO.Float.Input(
|
||||
|
||||
@ -300,7 +300,7 @@ class GeminiNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="GeminiNode",
|
||||
display_name="Google Gemini",
|
||||
category="text/partner/Gemini",
|
||||
category="partner/text/Gemini",
|
||||
description="Generate text responses with Google's Gemini AI model. "
|
||||
"You can provide multiple types of inputs (text, images, audio, video) "
|
||||
"as context for generating more relevant and meaningful responses.",
|
||||
@ -541,7 +541,7 @@ class GeminiInputFiles(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="GeminiInputFiles",
|
||||
display_name="Gemini Input Files",
|
||||
category="text/partner/Gemini",
|
||||
category="partner/text/Gemini",
|
||||
description="Loads and prepares input files to include as inputs for Gemini LLM nodes. "
|
||||
"The files will be read by the Gemini model when generating a response. "
|
||||
"The contents of the text file count toward the token limit. "
|
||||
@ -598,7 +598,7 @@ class GeminiImage(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="GeminiImageNode",
|
||||
display_name="Nano Banana (Google Gemini Image)",
|
||||
category="image/partner/Gemini",
|
||||
category="partner/image/Gemini",
|
||||
description="Edit images synchronously via Google API.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -731,7 +731,7 @@ class GeminiImage2(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="GeminiImage2Node",
|
||||
display_name="Nano Banana Pro (Google Gemini Image)",
|
||||
category="image/partner/Gemini",
|
||||
category="partner/image/Gemini",
|
||||
description="Generate or edit images synchronously via Google Vertex API.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -869,7 +869,7 @@ class GeminiNanoBanana2(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="GeminiNanoBanana2",
|
||||
display_name="Nano Banana 2",
|
||||
category="image/partner/Gemini",
|
||||
category="partner/image/Gemini",
|
||||
description="Generate or edit images synchronously via Google Vertex API.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -1085,7 +1085,7 @@ class GeminiNanoBanana2V2(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="GeminiNanoBanana2V2",
|
||||
display_name="Nano Banana 2",
|
||||
category="image/partner/Gemini",
|
||||
category="partner/image/Gemini",
|
||||
description="Generate or edit images synchronously via Google Vertex API.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
|
||||
@ -29,6 +29,11 @@ from comfy_api_nodes.util import (
|
||||
)
|
||||
|
||||
|
||||
_GROK_VIDEO_MODEL_API_IDS = {
|
||||
"grok-imagine-video-1.5": "grok-imagine-video-1.5-preview",
|
||||
}
|
||||
|
||||
|
||||
def _extract_grok_price(response) -> float | None:
|
||||
if response.usage and response.usage.cost_in_usd_ticks is not None:
|
||||
return response.usage.cost_in_usd_ticks / 10_000_000_000
|
||||
@ -49,7 +54,7 @@ class GrokImageNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="GrokImageNode",
|
||||
display_name="Grok Image",
|
||||
category="image/partner/Grok",
|
||||
category="partner/image/Grok",
|
||||
description="Generate images using Grok based on a text prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
@ -223,7 +228,7 @@ class GrokImageEditNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="GrokImageEditNode",
|
||||
display_name="Grok Image Edit",
|
||||
category="image/partner/Grok",
|
||||
category="partner/image/Grok",
|
||||
description="Modify an existing image based on a text prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
@ -364,7 +369,7 @@ class GrokImageEditNodeV2(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="GrokImageEditNodeV2",
|
||||
display_name="Grok Image Edit",
|
||||
category="image/partner/Grok",
|
||||
category="partner/image/Grok",
|
||||
description="Modify an existing image based on a text prompt",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -501,10 +506,14 @@ class GrokVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="GrokVideoNode",
|
||||
display_name="Grok Video",
|
||||
category="video/partner/Grok",
|
||||
category="partner/video/Grok",
|
||||
description="Generate video from a prompt or an image",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["grok-imagine-video"]),
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["grok-imagine-video", "grok-imagine-video-1.5"],
|
||||
tooltip="grok-imagine-video-1.5 currently always requires an input image.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
@ -540,7 +549,11 @@ class GrokVideoNode(IO.ComfyNode):
|
||||
tooltip="Seed to determine if node should re-run; "
|
||||
"actual results are nondeterministic regardless of seed.",
|
||||
),
|
||||
IO.Image.Input("image", optional=True),
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
optional=True,
|
||||
tooltip="Optional starting image for grok-imagine-video. Required for grok-imagine-video-1.5.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
@ -552,12 +565,16 @@ class GrokVideoNode(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"], inputs=["image"]),
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model", "duration", "resolution"], inputs=["image"]),
|
||||
expr="""
|
||||
(
|
||||
$rate := widgets.resolution = "720p" ? 0.07 : 0.05;
|
||||
$is15 := $contains(widgets.model, "1.5");
|
||||
$rate := $is15
|
||||
? (widgets.resolution = "720p" ? 0.2002 : 0.1144)
|
||||
: (widgets.resolution = "720p" ? 0.07 : 0.05);
|
||||
$imgCost := $is15 ? 0.0143 : 0.002;
|
||||
$base := $rate * widgets.duration;
|
||||
{"type":"usd","usd": inputs.image.connected ? $base + 0.002 : $base}
|
||||
{"type":"usd","usd": inputs.image.connected ? $base + $imgCost : $base}
|
||||
)
|
||||
""",
|
||||
),
|
||||
@ -574,6 +591,8 @@ class GrokVideoNode(IO.ComfyNode):
|
||||
seed: int,
|
||||
image: Input.Image | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
if image is None and model == "grok-imagine-video-1.5":
|
||||
raise ValueError(f"The '{model}' model requires an input image; connect one to the 'image' input.")
|
||||
image_url = None
|
||||
if image is not None:
|
||||
if get_number_of_images(image) != 1:
|
||||
@ -584,7 +603,7 @@ class GrokVideoNode(IO.ComfyNode):
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/xai/v1/videos/generations", method="POST"),
|
||||
data=VideoGenerationRequest(
|
||||
model=model,
|
||||
model=_GROK_VIDEO_MODEL_API_IDS.get(model, model),
|
||||
image=image_url,
|
||||
prompt=prompt,
|
||||
resolution=resolution,
|
||||
@ -599,7 +618,7 @@ class GrokVideoNode(IO.ComfyNode):
|
||||
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
|
||||
status_extractor=lambda r: r.status if r.status is not None else "complete",
|
||||
response_model=VideoStatusResponse,
|
||||
price_extractor=_extract_grok_price,
|
||||
price_extractor=_extract_grok_video_price if model == "grok-imagine-video-1.5" else _extract_grok_price,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
|
||||
|
||||
@ -611,7 +630,7 @@ class GrokVideoEditNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="GrokVideoEditNode",
|
||||
display_name="Grok Video Edit",
|
||||
category="video/partner/Grok",
|
||||
category="partner/video/Grok",
|
||||
description="Edit an existing video based on a text prompt.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["grok-imagine-video"]),
|
||||
@ -689,7 +708,7 @@ class GrokVideoReferenceNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="GrokVideoReferenceNode",
|
||||
display_name="Grok Reference-to-Video",
|
||||
category="video/partner/Grok",
|
||||
category="partner/video/Grok",
|
||||
description="Generate video guided by reference images as style and content references.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -822,7 +841,7 @@ class GrokVideoExtendNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="GrokVideoExtendNode",
|
||||
display_name="Grok Video Extend",
|
||||
category="video/partner/Grok",
|
||||
category="partner/video/Grok",
|
||||
description="Extend an existing video with a seamless continuation based on a text prompt.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
|
||||
@ -71,7 +71,7 @@ class HitPawGeneralImageEnhance(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="HitPawGeneralImageEnhance",
|
||||
display_name="HitPaw General Image Enhance",
|
||||
category="image/partner/HitPaw",
|
||||
category="partner/image/HitPaw",
|
||||
description="Upscale low-resolution images to super-resolution, eliminate artifacts and noise. "
|
||||
f"Maximum output: {MAX_MP_GENERATIVE} megapixels.",
|
||||
inputs=[
|
||||
@ -201,7 +201,7 @@ class HitPawVideoEnhance(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="HitPawVideoEnhance",
|
||||
display_name="HitPaw Video Enhance",
|
||||
category="video/partner/HitPaw",
|
||||
category="partner/video/HitPaw",
|
||||
description="Upscale low-resolution videos to high resolution, eliminate artifacts and noise. "
|
||||
"Prices shown are per second of video.",
|
||||
inputs=[
|
||||
|
||||
@ -123,7 +123,7 @@ class TencentTextToModelNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="TencentTextToModelNode",
|
||||
display_name="Hunyuan3D: Text to Model",
|
||||
category="3d/partner/Tencent",
|
||||
category="partner/3d/Tencent",
|
||||
essentials_category="3D",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
@ -242,7 +242,7 @@ class TencentImageToModelNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="TencentImageToModelNode",
|
||||
display_name="Hunyuan3D: Image(s) to Model",
|
||||
category="3d/partner/Tencent",
|
||||
category="partner/3d/Tencent",
|
||||
essentials_category="3D",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
@ -415,7 +415,7 @@ class TencentModelTo3DUVNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="TencentModelTo3DUVNode",
|
||||
display_name="Hunyuan3D: Model to UV",
|
||||
category="3d/partner/Tencent",
|
||||
category="partner/3d/Tencent",
|
||||
description="Perform UV unfolding on a 3D model to generate UV texture. "
|
||||
"Input model must have less than 30000 faces.",
|
||||
inputs=[
|
||||
@ -505,7 +505,7 @@ class Tencent3DTextureEditNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Tencent3DTextureEditNode",
|
||||
display_name="Hunyuan3D: 3D Texture Edit",
|
||||
category="3d/partner/Tencent",
|
||||
category="partner/3d/Tencent",
|
||||
description="After inputting the 3D model, perform 3D model texture redrawing.",
|
||||
inputs=[
|
||||
IO.MultiType.Input(
|
||||
@ -594,7 +594,7 @@ class Tencent3DPartNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Tencent3DPartNode",
|
||||
display_name="Hunyuan3D: 3D Part",
|
||||
category="3d/partner/Tencent",
|
||||
category="partner/3d/Tencent",
|
||||
description="Automatically perform component identification and generation based on the model structure.",
|
||||
inputs=[
|
||||
IO.MultiType.Input(
|
||||
@ -666,7 +666,7 @@ class TencentSmartTopologyNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="TencentSmartTopologyNode",
|
||||
display_name="Hunyuan3D: Smart Topology",
|
||||
category="3d/partner/Tencent",
|
||||
category="partner/3d/Tencent",
|
||||
description="Perform smart retopology on a 3D model. "
|
||||
"Supports GLB/OBJ formats; max 200MB; recommended for high-poly models.",
|
||||
inputs=[
|
||||
|
||||
@ -234,7 +234,7 @@ class IdeogramV1(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="IdeogramV1",
|
||||
display_name="Ideogram V1",
|
||||
category="image/partner/Ideogram",
|
||||
category="partner/image/Ideogram",
|
||||
description="Generates images using the Ideogram V1 model.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -360,7 +360,7 @@ class IdeogramV2(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="IdeogramV2",
|
||||
display_name="Ideogram V2",
|
||||
category="image/partner/Ideogram",
|
||||
category="partner/image/Ideogram",
|
||||
description="Generates images using the Ideogram V2 model.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -526,7 +526,7 @@ class IdeogramV3(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="IdeogramV3",
|
||||
display_name="Ideogram V3",
|
||||
category="image/partner/Ideogram",
|
||||
category="partner/image/Ideogram",
|
||||
description="Generates images using the Ideogram V3 model. "
|
||||
"Supports both regular image generation from text prompts and image editing with mask.",
|
||||
inputs=[
|
||||
|
||||
@ -642,7 +642,7 @@ class KlingCameraControls(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingCameraControls",
|
||||
display_name="Kling Camera Controls",
|
||||
category="video/partner/Kling",
|
||||
category="partner/video/Kling",
|
||||
description="Allows specifying configuration options for Kling Camera Controls and motion control effects.",
|
||||
inputs=[
|
||||
IO.Combo.Input("camera_control_type", options=KlingCameraControlType),
|
||||
@ -762,7 +762,7 @@ class KlingTextToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingTextToVideoNode",
|
||||
display_name="Kling Text to Video",
|
||||
category="video/partner/Kling",
|
||||
category="partner/video/Kling",
|
||||
description="Kling Text to Video Node",
|
||||
inputs=[
|
||||
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"),
|
||||
@ -849,7 +849,7 @@ class OmniProTextToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingOmniProTextToVideoNode",
|
||||
display_name="Kling 3.0 Omni Text to Video",
|
||||
category="video/partner/Kling",
|
||||
category="partner/video/Kling",
|
||||
description="Use text prompts to generate videos with the latest Kling model.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),
|
||||
@ -998,7 +998,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingOmniProFirstLastFrameNode",
|
||||
display_name="Kling 3.0 Omni First-Last-Frame to Video",
|
||||
category="video/partner/Kling",
|
||||
category="partner/video/Kling",
|
||||
description="Use a start frame, an optional end frame, or reference images with the latest Kling model.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),
|
||||
@ -1205,7 +1205,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingOmniProImageToVideoNode",
|
||||
display_name="Kling 3.0 Omni Image to Video",
|
||||
category="video/partner/Kling",
|
||||
category="partner/video/Kling",
|
||||
description="Use up to 7 reference images to generate a video with the latest Kling model.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),
|
||||
@ -1374,7 +1374,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingOmniProVideoToVideoNode",
|
||||
display_name="Kling 3.0 Omni Video to Video",
|
||||
category="video/partner/Kling",
|
||||
category="partner/video/Kling",
|
||||
description="Use a video and up to 4 reference images to generate a video with the latest Kling model.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),
|
||||
@ -1485,7 +1485,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingOmniProEditVideoNode",
|
||||
display_name="Kling 3.0 Omni Edit Video",
|
||||
category="video/partner/Kling",
|
||||
category="partner/video/Kling",
|
||||
essentials_category="Video Generation",
|
||||
description="Edit an existing video with the latest model from Kling.",
|
||||
inputs=[
|
||||
@ -1593,7 +1593,7 @@ class OmniProImageNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingOmniProImageNode",
|
||||
display_name="Kling 3.0 Omni Image",
|
||||
category="image/partner/Kling",
|
||||
category="partner/image/Kling",
|
||||
description="Create or edit images with the latest model from Kling.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-image-o1"]),
|
||||
@ -1721,7 +1721,7 @@ class KlingCameraControlT2VNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingCameraControlT2VNode",
|
||||
display_name="Kling Text to Video (Camera Control)",
|
||||
category="video/partner/Kling",
|
||||
category="partner/video/Kling",
|
||||
description="Transform text into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original text.",
|
||||
inputs=[
|
||||
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"),
|
||||
@ -1783,7 +1783,7 @@ class KlingImage2VideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingImage2VideoNode",
|
||||
display_name="Kling Image(First Frame) to Video",
|
||||
category="video/partner/Kling",
|
||||
category="partner/video/Kling",
|
||||
inputs=[
|
||||
IO.Image.Input("start_frame", tooltip="The reference image used to generate the video."),
|
||||
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"),
|
||||
@ -1882,7 +1882,7 @@ class KlingCameraControlI2VNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingCameraControlI2VNode",
|
||||
display_name="Kling Image to Video (Camera Control)",
|
||||
category="video/partner/Kling",
|
||||
category="partner/video/Kling",
|
||||
description="Transform still images into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original image.",
|
||||
inputs=[
|
||||
IO.Image.Input(
|
||||
@ -1953,7 +1953,7 @@ class KlingStartEndFrameNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingStartEndFrameNode",
|
||||
display_name="Kling Start-End Frame to Video",
|
||||
category="video/partner/Kling",
|
||||
category="partner/video/Kling",
|
||||
description="Generate a video sequence that transitions between your provided start and end images. The node creates all frames in between, producing a smooth transformation from the first frame to the last.",
|
||||
inputs=[
|
||||
IO.Image.Input(
|
||||
@ -2047,7 +2047,7 @@ class KlingVideoExtendNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingVideoExtendNode",
|
||||
display_name="Kling Video Extend",
|
||||
category="video/partner/Kling",
|
||||
category="partner/video/Kling",
|
||||
description="Kling Video Extend Node. Extend videos made by other Kling nodes. The video_id is created by using other Kling Nodes.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -2128,7 +2128,7 @@ class KlingDualCharacterVideoEffectNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingDualCharacterVideoEffectNode",
|
||||
display_name="Kling Dual Character Video Effects",
|
||||
category="video/partner/Kling",
|
||||
category="partner/video/Kling",
|
||||
description="Achieve different special effects when generating a video based on the effect_scene. First image will be positioned on left side, second on right side of the composite.",
|
||||
inputs=[
|
||||
IO.Image.Input("image_left", tooltip="Left side image"),
|
||||
@ -2218,7 +2218,7 @@ class KlingSingleImageVideoEffectNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingSingleImageVideoEffectNode",
|
||||
display_name="Kling Video Effects",
|
||||
category="video/partner/Kling",
|
||||
category="partner/video/Kling",
|
||||
description="Achieve different special effects when generating a video based on the effect_scene.",
|
||||
inputs=[
|
||||
IO.Image.Input(
|
||||
@ -2291,7 +2291,7 @@ class KlingLipSyncAudioToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingLipSyncAudioToVideoNode",
|
||||
display_name="Kling Lip Sync Video with Audio",
|
||||
category="video/partner/Kling",
|
||||
category="partner/video/Kling",
|
||||
essentials_category="Video Generation",
|
||||
description="Kling Lip Sync Audio to Video Node. Syncs mouth movements in a video file to the audio content of an audio file. When using, ensure that the audio contains clearly distinguishable vocals and that the video contains a distinct face. The audio file should not be larger than 5MB. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length.",
|
||||
inputs=[
|
||||
@ -2343,7 +2343,7 @@ class KlingLipSyncTextToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingLipSyncTextToVideoNode",
|
||||
display_name="Kling Lip Sync Video with Text",
|
||||
category="video/partner/Kling",
|
||||
category="partner/video/Kling",
|
||||
description="Kling Lip Sync Text to Video Node. Syncs mouth movements in a video file to a text prompt. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length.",
|
||||
inputs=[
|
||||
IO.Video.Input("video"),
|
||||
@ -2411,7 +2411,7 @@ class KlingVirtualTryOnNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingVirtualTryOnNode",
|
||||
display_name="Kling Virtual Try On",
|
||||
category="image/partner/Kling",
|
||||
category="partner/image/Kling",
|
||||
description="Kling Virtual Try On Node. Input a human image and a cloth image to try on the cloth on the human. You can merge multiple clothing item pictures into one image with a white background.",
|
||||
inputs=[
|
||||
IO.Image.Input("human_image"),
|
||||
@ -2478,7 +2478,7 @@ class KlingImageGenerationNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingImageGenerationNode",
|
||||
display_name="Kling 3.0 Image",
|
||||
category="image/partner/Kling",
|
||||
category="partner/image/Kling",
|
||||
description="Kling Image Generation Node. Generate an image from a text prompt with an optional reference image.",
|
||||
inputs=[
|
||||
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"),
|
||||
@ -2615,7 +2615,7 @@ class TextToVideoWithAudio(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingTextToVideoWithAudio",
|
||||
display_name="Kling 2.6 Text to Video with Audio",
|
||||
category="video/partner/Kling",
|
||||
category="partner/video/Kling",
|
||||
inputs=[
|
||||
IO.Combo.Input("model_name", options=["kling-v2-6"]),
|
||||
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt."),
|
||||
@ -2683,7 +2683,7 @@ class ImageToVideoWithAudio(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingImageToVideoWithAudio",
|
||||
display_name="Kling 2.6 Image(First Frame) to Video with Audio",
|
||||
category="video/partner/Kling",
|
||||
category="partner/video/Kling",
|
||||
inputs=[
|
||||
IO.Combo.Input("model_name", options=["kling-v2-6"]),
|
||||
IO.Image.Input("start_frame"),
|
||||
@ -2753,7 +2753,7 @@ class MotionControl(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingMotionControl",
|
||||
display_name="Kling Motion Control",
|
||||
category="video/partner/Kling",
|
||||
category="partner/video/Kling",
|
||||
inputs=[
|
||||
IO.String.Input("prompt", multiline=True),
|
||||
IO.Image.Input("reference_image"),
|
||||
@ -2854,7 +2854,7 @@ class KlingVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingVideoNode",
|
||||
display_name="Kling 3.0 Video",
|
||||
category="video/partner/Kling",
|
||||
category="partner/video/Kling",
|
||||
description="Generate videos with Kling V3. "
|
||||
"Supports text-to-video and image-to-video with optional storyboard multi-prompt and audio generation.",
|
||||
inputs=[
|
||||
@ -3077,7 +3077,7 @@ class KlingFirstLastFrameNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingFirstLastFrameNode",
|
||||
display_name="Kling 3.0 First-Last-Frame to Video",
|
||||
category="video/partner/Kling",
|
||||
category="partner/video/Kling",
|
||||
description="Generate videos with Kling V3 using first and last frames.",
|
||||
inputs=[
|
||||
IO.String.Input("prompt", multiline=True, default=""),
|
||||
@ -3202,7 +3202,7 @@ class KlingAvatarNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingAvatarNode",
|
||||
display_name="Kling Avatar 2.0",
|
||||
category="video/partner/Kling",
|
||||
category="partner/video/Kling",
|
||||
description="Generate broadcast-style digital human videos from a single photo and an audio file.",
|
||||
inputs=[
|
||||
IO.Image.Input(
|
||||
|
||||
@ -106,7 +106,7 @@ class Krea2ImageNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Krea2ImageNode",
|
||||
display_name="Krea 2 Image",
|
||||
category="image/partner/Krea",
|
||||
category="partner/image/Krea",
|
||||
description=(
|
||||
"Generate images via Krea 2 — pick Medium (expressive illustrations) or "
|
||||
"Large (expressive photorealism). Supports an optional moodboard and up "
|
||||
@ -229,7 +229,7 @@ class Krea2StyleReferenceNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Krea2StyleReferenceNode",
|
||||
display_name="Krea 2 Style Reference",
|
||||
category="image/partner/Krea",
|
||||
category="partner/image/Krea",
|
||||
description=(
|
||||
"Add an image style reference to a Krea 2 generation. Chain multiple Krea 2 "
|
||||
"Style Reference nodes (max 10) and feed the final `style_reference` output "
|
||||
|
||||
@ -50,7 +50,7 @@ class TextToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="LtxvApiTextToVideo",
|
||||
display_name="LTXV Text To Video",
|
||||
category="video/partner/LTXV",
|
||||
category="partner/video/LTXV",
|
||||
description="Professional-quality videos with customizable duration and resolution.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=list(MODELS_MAP.keys())),
|
||||
@ -127,7 +127,7 @@ class ImageToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="LtxvApiImageToVideo",
|
||||
display_name="LTXV Image To Video",
|
||||
category="video/partner/LTXV",
|
||||
category="partner/video/LTXV",
|
||||
description="Professional-quality videos with customizable duration and resolution based on start image.",
|
||||
inputs=[
|
||||
IO.Image.Input("image", tooltip="First frame to be used for the video."),
|
||||
|
||||
@ -46,7 +46,7 @@ class LumaReferenceNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="LumaReferenceNode",
|
||||
display_name="Luma Reference",
|
||||
category="image/partner/Luma",
|
||||
category="partner/image/Luma",
|
||||
description="Holds an image and weight for use with Luma Generate Image node.",
|
||||
inputs=[
|
||||
IO.Image.Input(
|
||||
@ -85,7 +85,7 @@ class LumaConceptsNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="LumaConceptsNode",
|
||||
display_name="Luma Concepts",
|
||||
category="video/partner/Luma",
|
||||
category="partner/video/Luma",
|
||||
description="Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes.",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
@ -134,7 +134,7 @@ class LumaImageGenerationNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="LumaImageNode",
|
||||
display_name="Luma Text to Image",
|
||||
category="image/partner/Luma",
|
||||
category="partner/image/Luma",
|
||||
description="Generates images synchronously based on prompt and aspect ratio.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -278,7 +278,7 @@ class LumaImageModifyNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="LumaImageModifyNode",
|
||||
display_name="Luma Image to Image",
|
||||
category="image/partner/Luma",
|
||||
category="partner/image/Luma",
|
||||
description="Modifies images synchronously based on prompt and aspect ratio.",
|
||||
inputs=[
|
||||
IO.Image.Input(
|
||||
@ -371,7 +371,7 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="LumaVideoNode",
|
||||
display_name="Luma Text to Video",
|
||||
category="video/partner/Luma",
|
||||
category="partner/video/Luma",
|
||||
description="Generates videos synchronously based on prompt and output_size.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -472,7 +472,7 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="LumaImageToVideoNode",
|
||||
display_name="Luma Image to Video",
|
||||
category="video/partner/Luma",
|
||||
category="partner/video/Luma",
|
||||
description="Generates videos synchronously based on prompt, input images, and output_size.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -724,7 +724,7 @@ class LumaImageNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="LumaImageNode2",
|
||||
display_name="Luma UNI-1 Image",
|
||||
category="image/partner/Luma",
|
||||
category="partner/image/Luma",
|
||||
description="Generate images from text using the Luma UNI-1 model.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -853,7 +853,7 @@ class LumaImageEditNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="LumaImageEditNode2",
|
||||
display_name="Luma UNI-1 Image Edit",
|
||||
category="image/partner/Luma",
|
||||
category="partner/image/Luma",
|
||||
description="Edit an existing image with a text prompt using the Luma UNI-1 model.",
|
||||
inputs=[
|
||||
IO.Image.Input(
|
||||
|
||||
@ -61,7 +61,7 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="MagnificImageUpscalerCreativeNode",
|
||||
display_name="Magnific Image Upscale (Creative)",
|
||||
category="image/partner/Magnific",
|
||||
category="partner/image/Magnific",
|
||||
description="Prompt‑guided enhancement, stylization, and 2x/4x/8x/16x upscaling. "
|
||||
"Maximum output: 25.3 megapixels.",
|
||||
inputs=[
|
||||
@ -240,7 +240,7 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="MagnificImageUpscalerPreciseV2Node",
|
||||
display_name="Magnific Image Upscale (Precise V2)",
|
||||
category="image/partner/Magnific",
|
||||
category="partner/image/Magnific",
|
||||
description="High-fidelity upscaling with fine control over sharpness, grain, and detail. "
|
||||
"Maximum output: 10060×10060 pixels.",
|
||||
inputs=[
|
||||
@ -400,7 +400,7 @@ class MagnificImageStyleTransferNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="MagnificImageStyleTransferNode",
|
||||
display_name="Magnific Image Style Transfer",
|
||||
category="image/partner/Magnific",
|
||||
category="partner/image/Magnific",
|
||||
description="Transfer the style from a reference image to your input image.",
|
||||
inputs=[
|
||||
IO.Image.Input("image", tooltip="The image to apply style transfer to."),
|
||||
@ -549,7 +549,7 @@ class MagnificImageRelightNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="MagnificImageRelightNode",
|
||||
display_name="Magnific Image Relight",
|
||||
category="image/partner/Magnific",
|
||||
category="partner/image/Magnific",
|
||||
description="Relight an image with lighting adjustments and optional reference-based light transfer.",
|
||||
inputs=[
|
||||
IO.Image.Input("image", tooltip="The image to relight."),
|
||||
@ -789,7 +789,7 @@ class MagnificImageSkinEnhancerNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="MagnificImageSkinEnhancerNode",
|
||||
display_name="Magnific Image Skin Enhancer",
|
||||
category="image/partner/Magnific",
|
||||
category="partner/image/Magnific",
|
||||
description="Skin enhancement for portraits with multiple processing modes.",
|
||||
inputs=[
|
||||
IO.Image.Input("image", tooltip="The portrait image to enhance."),
|
||||
|
||||
@ -33,7 +33,7 @@ class MeshyTextToModelNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="MeshyTextToModelNode",
|
||||
display_name="Meshy: Text to Model",
|
||||
category="3d/partner/Meshy",
|
||||
category="partner/3d/Meshy",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["latest"]),
|
||||
IO.String.Input("prompt", multiline=True, default=""),
|
||||
@ -145,7 +145,7 @@ class MeshyRefineNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="MeshyRefineNode",
|
||||
display_name="Meshy: Refine Draft Model",
|
||||
category="3d/partner/Meshy",
|
||||
category="partner/3d/Meshy",
|
||||
description="Refine a previously created draft model.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["latest"]),
|
||||
@ -240,7 +240,7 @@ class MeshyImageToModelNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="MeshyImageToModelNode",
|
||||
display_name="Meshy: Image to Model",
|
||||
category="3d/partner/Meshy",
|
||||
category="partner/3d/Meshy",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["latest"]),
|
||||
IO.Image.Input("image"),
|
||||
@ -405,7 +405,7 @@ class MeshyMultiImageToModelNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="MeshyMultiImageToModelNode",
|
||||
display_name="Meshy: Multi-Image to Model",
|
||||
category="3d/partner/Meshy",
|
||||
category="partner/3d/Meshy",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["latest"]),
|
||||
IO.Autogrow.Input(
|
||||
@ -575,7 +575,7 @@ class MeshyRigModelNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="MeshyRigModelNode",
|
||||
display_name="Meshy: Rig Model",
|
||||
category="3d/partner/Meshy",
|
||||
category="partner/3d/Meshy",
|
||||
description="Provides a rigged character in standard formats. "
|
||||
"Auto-rigging is currently not suitable for untextured meshes, non-humanoid assets, "
|
||||
"or humanoid assets with unclear limb and body structure.",
|
||||
@ -656,7 +656,7 @@ class MeshyAnimateModelNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="MeshyAnimateModelNode",
|
||||
display_name="Meshy: Animate Model",
|
||||
category="3d/partner/Meshy",
|
||||
category="partner/3d/Meshy",
|
||||
description="Apply a specific animation action to a previously rigged character.",
|
||||
inputs=[
|
||||
IO.Custom("MESHY_RIGGED_TASK_ID").Input("rig_task_id"),
|
||||
@ -722,7 +722,7 @@ class MeshyTextureNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="MeshyTextureNode",
|
||||
display_name="Meshy: Texture Model",
|
||||
category="3d/partner/Meshy",
|
||||
category="partner/3d/Meshy",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["latest"]),
|
||||
IO.Custom("MESHY_TASK_ID").Input("meshy_task_id"),
|
||||
|
||||
@ -101,7 +101,7 @@ class MinimaxTextToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="MinimaxTextToVideoNode",
|
||||
display_name="MiniMax Text to Video",
|
||||
category="video/partner/MiniMax",
|
||||
category="partner/video/MiniMax",
|
||||
description="Generates videos synchronously based on a prompt, and optional parameters.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -163,7 +163,7 @@ class MinimaxImageToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="MinimaxImageToVideoNode",
|
||||
display_name="MiniMax Image to Video",
|
||||
category="video/partner/MiniMax",
|
||||
category="partner/video/MiniMax",
|
||||
description="Generates videos synchronously based on an image and prompt, and optional parameters.",
|
||||
inputs=[
|
||||
IO.Image.Input(
|
||||
@ -230,7 +230,7 @@ class MinimaxSubjectToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="MinimaxSubjectToVideoNode",
|
||||
display_name="MiniMax Subject to Video",
|
||||
category="video/partner/MiniMax",
|
||||
category="partner/video/MiniMax",
|
||||
description="Generates videos synchronously based on an image and prompt, and optional parameters.",
|
||||
inputs=[
|
||||
IO.Image.Input(
|
||||
@ -294,7 +294,7 @@ class MinimaxHailuoVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="MinimaxHailuoVideoNode",
|
||||
display_name="MiniMax Hailuo Video",
|
||||
category="video/partner/MiniMax",
|
||||
category="partner/video/MiniMax",
|
||||
description="Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
|
||||
@ -99,7 +99,7 @@ class OpenAIDalle2(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="OpenAIDalle2",
|
||||
display_name="OpenAI DALL·E 2",
|
||||
category="image/partner/OpenAI",
|
||||
category="partner/image/OpenAI",
|
||||
description="Generates images synchronously via OpenAI's DALL·E 2 endpoint.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -249,7 +249,7 @@ class OpenAIDalle3(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="OpenAIDalle3",
|
||||
display_name="OpenAI DALL·E 3",
|
||||
category="image/partner/OpenAI",
|
||||
category="partner/image/OpenAI",
|
||||
description="Generates images synchronously via OpenAI's DALL·E 3 endpoint.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -371,7 +371,7 @@ class OpenAIGPTImage1(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="OpenAIGPTImage1",
|
||||
display_name="OpenAI GPT Image 2",
|
||||
category="image/partner/OpenAI",
|
||||
category="partner/image/OpenAI",
|
||||
description="Generates images synchronously via OpenAI's GPT Image endpoint.",
|
||||
is_deprecated=True,
|
||||
inputs=[
|
||||
@ -695,7 +695,7 @@ class OpenAIGPTImageNodeV2(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="OpenAIGPTImageNodeV2",
|
||||
display_name="OpenAI GPT Image 2",
|
||||
category="image/partner/OpenAI",
|
||||
category="partner/image/OpenAI",
|
||||
description="Generates images via OpenAI's GPT Image endpoint.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -962,7 +962,7 @@ class OpenAIChatNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="OpenAIChatNode",
|
||||
display_name="OpenAI ChatGPT",
|
||||
category="text/partner/OpenAI",
|
||||
category="partner/text/OpenAI",
|
||||
essentials_category="Text Generation",
|
||||
description="Generate text responses from an OpenAI model.",
|
||||
inputs=[
|
||||
@ -1201,7 +1201,7 @@ class OpenAIInputFiles(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="OpenAIInputFiles",
|
||||
display_name="OpenAI ChatGPT Input Files",
|
||||
category="text/partner/OpenAI",
|
||||
category="partner/text/OpenAI",
|
||||
description="Loads and prepares input files (text, pdf, etc.) to include as inputs for the OpenAI Chat Node. The files will be read by the OpenAI model when generating a response. 🛈 TIP: Can be chained together with other OpenAI Input File nodes.",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
@ -1248,7 +1248,7 @@ class OpenAIChatConfig(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="OpenAIChatConfig",
|
||||
display_name="OpenAI ChatGPT Advanced Options",
|
||||
category="text/partner/OpenAI",
|
||||
category="partner/text/OpenAI",
|
||||
description="Allows specifying advanced configuration options for the OpenAI Chat Nodes.",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
|
||||
@ -265,7 +265,7 @@ class OpenRouterLLMNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="OpenRouterLLMNode",
|
||||
display_name="OpenRouter LLM",
|
||||
category="text/partner/OpenRouter",
|
||||
category="partner/text/OpenRouter",
|
||||
essentials_category="Text Generation",
|
||||
description=(
|
||||
"Generate text responses through OpenRouter. Routes to a curated set of popular "
|
||||
|
||||
@ -53,7 +53,7 @@ class PixverseTemplateNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="PixverseTemplateNode",
|
||||
display_name="PixVerse Template",
|
||||
category="video/partner/PixVerse",
|
||||
category="partner/video/PixVerse",
|
||||
inputs=[
|
||||
IO.Combo.Input("template", options=list(pixverse_templates.keys())),
|
||||
],
|
||||
@ -74,7 +74,7 @@ class PixverseTextToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="PixverseTextToVideoNode",
|
||||
display_name="PixVerse Text to Video",
|
||||
category="video/partner/PixVerse",
|
||||
category="partner/video/PixVerse",
|
||||
description="Generates videos based on prompt and output_size.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -192,7 +192,7 @@ class PixverseImageToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="PixverseImageToVideoNode",
|
||||
display_name="PixVerse Image to Video",
|
||||
category="video/partner/PixVerse",
|
||||
category="partner/video/PixVerse",
|
||||
description="Generates videos based on prompt and output_size.",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
@ -310,7 +310,7 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="PixverseTransitionVideoNode",
|
||||
display_name="PixVerse Transition Video",
|
||||
category="video/partner/PixVerse",
|
||||
category="partner/video/PixVerse",
|
||||
description="Generates videos based on prompt and output_size.",
|
||||
inputs=[
|
||||
IO.Image.Input("first_frame"),
|
||||
|
||||
@ -62,7 +62,7 @@ class QuiverTextToSVGNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="QuiverTextToSVGNode",
|
||||
display_name="Quiver Text to SVG",
|
||||
category="image/partner/Quiver",
|
||||
category="partner/image/Quiver",
|
||||
description="Generate an SVG from a text prompt using Quiver AI.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -177,7 +177,7 @@ class QuiverImageToSVGNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="QuiverImageToSVGNode",
|
||||
display_name="Quiver Image to SVG",
|
||||
category="image/partner/Quiver",
|
||||
category="partner/image/Quiver",
|
||||
description="Vectorize a raster image into SVG using Quiver AI.",
|
||||
inputs=[
|
||||
IO.Image.Input(
|
||||
|
||||
@ -178,7 +178,7 @@ class RecraftColorRGBNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftColorRGB",
|
||||
display_name="Recraft Color RGB",
|
||||
category="image/partner/Recraft",
|
||||
category="partner/image/Recraft",
|
||||
description="Create Recraft Color by choosing specific RGB values.",
|
||||
inputs=[
|
||||
IO.Int.Input("r", default=0, min=0, max=255, tooltip="Red value of color."),
|
||||
@ -204,7 +204,7 @@ class RecraftControlsNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftControls",
|
||||
display_name="Recraft Controls",
|
||||
category="image/partner/Recraft",
|
||||
category="partner/image/Recraft",
|
||||
description="Create Recraft Controls for customizing Recraft generation.",
|
||||
inputs=[
|
||||
IO.Custom(RecraftIO.COLOR).Input("colors", optional=True),
|
||||
@ -228,7 +228,7 @@ class RecraftStyleV3RealisticImageNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftStyleV3RealisticImage",
|
||||
display_name="Recraft Style - Realistic Image",
|
||||
category="image/partner/Recraft",
|
||||
category="partner/image/Recraft",
|
||||
description="Select realistic_image style and optional substyle.",
|
||||
inputs=[
|
||||
IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE)),
|
||||
@ -253,7 +253,7 @@ class RecraftStyleV3DigitalIllustrationNode(RecraftStyleV3RealisticImageNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftStyleV3DigitalIllustration",
|
||||
display_name="Recraft Style - Digital Illustration",
|
||||
category="image/partner/Recraft",
|
||||
category="partner/image/Recraft",
|
||||
description="Select realistic_image style and optional substyle.",
|
||||
inputs=[
|
||||
IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE)),
|
||||
@ -272,7 +272,7 @@ class RecraftStyleV3VectorIllustrationNode(RecraftStyleV3RealisticImageNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftStyleV3VectorIllustrationNode",
|
||||
display_name="Recraft Style - Realistic Image",
|
||||
category="image/partner/Recraft",
|
||||
category="partner/image/Recraft",
|
||||
description="Select realistic_image style and optional substyle.",
|
||||
inputs=[
|
||||
IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE)),
|
||||
@ -291,7 +291,7 @@ class RecraftStyleV3LogoRasterNode(RecraftStyleV3RealisticImageNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftStyleV3LogoRaster",
|
||||
display_name="Recraft Style - Logo Raster",
|
||||
category="image/partner/Recraft",
|
||||
category="partner/image/Recraft",
|
||||
description="Select realistic_image style and optional substyle.",
|
||||
inputs=[
|
||||
IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE, include_none=False)),
|
||||
@ -308,7 +308,7 @@ class RecraftStyleInfiniteStyleLibrary(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftStyleV3InfiniteStyleLibrary",
|
||||
display_name="Recraft Style - Infinite Style Library",
|
||||
category="image/partner/Recraft",
|
||||
category="partner/image/Recraft",
|
||||
description="Choose style based on preexisting UUID from Recraft's Infinite Style Library.",
|
||||
inputs=[
|
||||
IO.String.Input("style_id", default="", tooltip="UUID of style from Infinite Style Library."),
|
||||
@ -331,7 +331,7 @@ class RecraftCreateStyleNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftCreateStyleNode",
|
||||
display_name="Recraft Create Style",
|
||||
category="image/partner/Recraft",
|
||||
category="partner/image/Recraft",
|
||||
description="Create a custom style from reference images. "
|
||||
"Upload 1-5 images to use as style references. "
|
||||
"Total size of all images is limited to 5 MB.",
|
||||
@ -400,7 +400,7 @@ class RecraftTextToImageNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftTextToImageNode",
|
||||
display_name="Recraft Text to Image",
|
||||
category="image/partner/Recraft",
|
||||
category="partner/image/Recraft",
|
||||
description="Generates images synchronously based on prompt and resolution.",
|
||||
inputs=[
|
||||
IO.String.Input("prompt", multiline=True, default="", tooltip="Prompt for the image generation."),
|
||||
@ -512,7 +512,7 @@ class RecraftImageToImageNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftImageToImageNode",
|
||||
display_name="Recraft Image to Image",
|
||||
category="image/partner/Recraft",
|
||||
category="partner/image/Recraft",
|
||||
description="Modify image based on prompt and strength.",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
@ -630,7 +630,7 @@ class RecraftImageInpaintingNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftImageInpaintingNode",
|
||||
display_name="Recraft Image Inpainting",
|
||||
category="image/partner/Recraft",
|
||||
category="partner/image/Recraft",
|
||||
description="Modify image based on prompt and mask.",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
@ -732,7 +732,7 @@ class RecraftTextToVectorNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftTextToVectorNode",
|
||||
display_name="Recraft Text to Vector",
|
||||
category="image/partner/Recraft",
|
||||
category="partner/image/Recraft",
|
||||
description="Generates SVG synchronously based on prompt and resolution.",
|
||||
inputs=[
|
||||
IO.String.Input("prompt", default="", tooltip="Prompt for the image generation.", multiline=True),
|
||||
@ -832,7 +832,7 @@ class RecraftVectorizeImageNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftVectorizeImageNode",
|
||||
display_name="Recraft Vectorize Image",
|
||||
category="image/partner/Recraft",
|
||||
category="partner/image/Recraft",
|
||||
essentials_category="Image Tools",
|
||||
description="Generates SVG synchronously from an input image.",
|
||||
inputs=[
|
||||
@ -876,7 +876,7 @@ class RecraftReplaceBackgroundNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftReplaceBackgroundNode",
|
||||
display_name="Recraft Replace Background",
|
||||
category="image/partner/Recraft",
|
||||
category="partner/image/Recraft",
|
||||
description="Replace background on image, based on provided prompt.",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
@ -963,7 +963,7 @@ class RecraftRemoveBackgroundNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftRemoveBackgroundNode",
|
||||
display_name="Recraft Remove Background",
|
||||
category="image/partner/Recraft",
|
||||
category="partner/image/Recraft",
|
||||
essentials_category="Image Tools",
|
||||
description="Remove background from image, and return processed image and mask.",
|
||||
inputs=[
|
||||
@ -1012,7 +1012,7 @@ class RecraftCrispUpscaleNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftCrispUpscaleNode",
|
||||
display_name="Recraft Crisp Upscale Image",
|
||||
category="image/partner/Recraft",
|
||||
category="partner/image/Recraft",
|
||||
description="Upscale image synchronously.\n"
|
||||
"Enhances a given raster image using ‘crisp upscale’ tool, "
|
||||
"increasing image resolution, making the image sharper and cleaner.",
|
||||
@ -1058,7 +1058,7 @@ class RecraftCreativeUpscaleNode(RecraftCrispUpscaleNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftCreativeUpscaleNode",
|
||||
display_name="Recraft Creative Upscale Image",
|
||||
category="image/partner/Recraft",
|
||||
category="partner/image/Recraft",
|
||||
description="Upscale image synchronously.\n"
|
||||
"Enhances a given raster image using ‘creative upscale’ tool, "
|
||||
"boosting resolution with a focus on refining small details and faces.",
|
||||
@ -1086,7 +1086,7 @@ class RecraftV4TextToImageNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftV4TextToImageNode",
|
||||
display_name="Recraft V4 Text to Image",
|
||||
category="image/partner/Recraft",
|
||||
category="partner/image/Recraft",
|
||||
description="Generates images using Recraft V4 or V4 Pro models.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -1210,7 +1210,7 @@ class RecraftV4TextToVectorNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftV4TextToVectorNode",
|
||||
display_name="Recraft V4 Text to Vector",
|
||||
category="image/partner/Recraft",
|
||||
category="partner/image/Recraft",
|
||||
description="Generates SVG using Recraft V4 or V4 Pro models.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
|
||||
@ -109,7 +109,7 @@ class ReveImageCreateNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ReveImageCreateNode",
|
||||
display_name="Reve Image Create",
|
||||
category="image/partner/Reve",
|
||||
category="partner/image/Reve",
|
||||
description="Generate images from text descriptions using Reve.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -200,7 +200,7 @@ class ReveImageEditNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ReveImageEditNode",
|
||||
display_name="Reve Image Edit",
|
||||
category="image/partner/Reve",
|
||||
category="partner/image/Reve",
|
||||
description="Edit images using natural language instructions with Reve.",
|
||||
inputs=[
|
||||
IO.Image.Input("image", tooltip="The image to edit."),
|
||||
@ -300,7 +300,7 @@ class ReveImageRemixNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ReveImageRemixNode",
|
||||
display_name="Reve Image Remix",
|
||||
category="image/partner/Reve",
|
||||
category="partner/image/Reve",
|
||||
description="Combine reference images with text prompts to create new images using Reve.",
|
||||
inputs=[
|
||||
IO.Autogrow.Input(
|
||||
|
||||
@ -230,7 +230,7 @@ class Rodin3D_Regular(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Rodin3D_Regular",
|
||||
display_name="Rodin 3D Generate - Regular Generate",
|
||||
category="3d/partner/Rodin",
|
||||
category="partner/3d/Rodin",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Image.Input("Images"),
|
||||
@ -289,7 +289,7 @@ class Rodin3D_Detail(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Rodin3D_Detail",
|
||||
display_name="Rodin 3D Generate - Detail Generate",
|
||||
category="3d/partner/Rodin",
|
||||
category="partner/3d/Rodin",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Image.Input("Images"),
|
||||
@ -348,7 +348,7 @@ class Rodin3D_Smooth(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Rodin3D_Smooth",
|
||||
display_name="Rodin 3D Generate - Smooth Generate",
|
||||
category="3d/partner/Rodin",
|
||||
category="partner/3d/Rodin",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Image.Input("Images"),
|
||||
@ -406,7 +406,7 @@ class Rodin3D_Sketch(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Rodin3D_Sketch",
|
||||
display_name="Rodin 3D Generate - Sketch Generate",
|
||||
category="3d/partner/Rodin",
|
||||
category="partner/3d/Rodin",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Image.Input("Images"),
|
||||
@ -468,7 +468,7 @@ class Rodin3D_Gen2(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Rodin3D_Gen2",
|
||||
display_name="Rodin 3D Generate - Gen-2 Generate",
|
||||
category="3d/partner/Rodin",
|
||||
category="partner/3d/Rodin",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Image.Input("Images"),
|
||||
@ -941,7 +941,7 @@ class Rodin3D_Gen25_Image(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Rodin3D_Gen25_Image",
|
||||
display_name="Rodin 3D Gen-2.5 - Image to 3D",
|
||||
category="3d/partner/Rodin",
|
||||
category="partner/3d/Rodin",
|
||||
description=(
|
||||
"Generate a 3D model from 1-5 reference images via Rodin Gen-2.5. "
|
||||
"Pick a mode (Fast / Regular / Extreme-High) to tune quality vs. cost."
|
||||
@ -1035,7 +1035,7 @@ class Rodin3D_Gen25_Text(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Rodin3D_Gen25_Text",
|
||||
display_name="Rodin 3D Gen-2.5 - Text to 3D",
|
||||
category="3d/partner/Rodin",
|
||||
category="partner/3d/Rodin",
|
||||
description=(
|
||||
"Generate a 3D model from a text prompt via Rodin Gen-2.5. "
|
||||
"Pick a mode (Fast / Regular / Extreme-High) to tune quality vs. cost."
|
||||
|
||||
@ -140,7 +140,7 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RunwayImageToVideoNodeGen3a",
|
||||
display_name="Runway Image to Video (Gen3a Turbo)",
|
||||
category="video/partner/Runway",
|
||||
category="partner/video/Runway",
|
||||
description="Generate a video from a single starting frame using Gen3a Turbo model. "
|
||||
"Before diving in, review these best practices to ensure that "
|
||||
"your input selections will set your generation up for success: "
|
||||
@ -234,7 +234,7 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RunwayImageToVideoNodeGen4",
|
||||
display_name="Runway Image to Video (Gen4 Turbo)",
|
||||
category="video/partner/Runway",
|
||||
category="partner/video/Runway",
|
||||
description="Generate a video from a single starting frame using Gen4 Turbo model. "
|
||||
"Before diving in, review these best practices to ensure that "
|
||||
"your input selections will set your generation up for success: "
|
||||
@ -329,7 +329,7 @@ class RunwayFirstLastFrameNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RunwayFirstLastFrameNode",
|
||||
display_name="Runway First-Last-Frame to Video",
|
||||
category="video/partner/Runway",
|
||||
category="partner/video/Runway",
|
||||
description="Upload first and last keyframes, draft a prompt, and generate a video. "
|
||||
"More complex transitions, such as cases where the Last frame is completely different "
|
||||
"from the First frame, may benefit from the longer 10s duration. "
|
||||
@ -440,7 +440,7 @@ class RunwayTextToImageNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RunwayTextToImageNode",
|
||||
display_name="Runway Text to Image",
|
||||
category="image/partner/Runway",
|
||||
category="partner/image/Runway",
|
||||
description="Generate an image from a text prompt using Runway's Gen 4 model. "
|
||||
"You can also include reference image to guide the generation.",
|
||||
inputs=[
|
||||
|
||||
@ -34,7 +34,7 @@ class SoniloVideoToMusic(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="SoniloVideoToMusic",
|
||||
display_name="Sonilo Video to Music",
|
||||
category="audio/partner/Sonilo",
|
||||
category="partner/audio/Sonilo",
|
||||
description="Generate music from video content using Sonilo's AI model. "
|
||||
"Analyzes the video and creates matching music.",
|
||||
inputs=[
|
||||
@ -99,7 +99,7 @@ class SoniloTextToMusic(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="SoniloTextToMusic",
|
||||
display_name="Sonilo Text to Music",
|
||||
category="audio/partner/Sonilo",
|
||||
category="partner/audio/Sonilo",
|
||||
description="Generate music from a text prompt using Sonilo's AI model. "
|
||||
"Leave duration at 0 to let the model infer it from the prompt.",
|
||||
inputs=[
|
||||
|
||||
@ -34,7 +34,7 @@ class OpenAIVideoSora2(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="OpenAIVideoSora2",
|
||||
display_name="OpenAI Sora - Video (DEPRECATED)",
|
||||
category="video/partner/Sora",
|
||||
category="partner/video/Sora",
|
||||
description=(
|
||||
"OpenAI video and audio generation.\n\n"
|
||||
"DEPRECATION NOTICE: OpenAI will stop serving the Sora v2 API in September 2026. "
|
||||
|
||||
@ -62,7 +62,7 @@ class StabilityStableImageUltraNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="StabilityStableImageUltraNode",
|
||||
display_name="Stability AI Stable Image Ultra",
|
||||
category="image/partner/Stability AI",
|
||||
category="partner/image/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -197,7 +197,7 @@ class StabilityStableImageSD_3_5Node(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="StabilityStableImageSD_3_5Node",
|
||||
display_name="Stability AI Stable Diffusion 3.5 Image",
|
||||
category="image/partner/Stability AI",
|
||||
category="partner/image/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -354,7 +354,7 @@ class StabilityUpscaleConservativeNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="StabilityUpscaleConservativeNode",
|
||||
display_name="Stability AI Upscale Conservative",
|
||||
category="image/partner/Stability AI",
|
||||
category="partner/image/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
@ -457,7 +457,7 @@ class StabilityUpscaleCreativeNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="StabilityUpscaleCreativeNode",
|
||||
display_name="Stability AI Upscale Creative",
|
||||
category="image/partner/Stability AI",
|
||||
category="partner/image/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
@ -578,7 +578,7 @@ class StabilityUpscaleFastNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="StabilityUpscaleFastNode",
|
||||
display_name="Stability AI Upscale Fast",
|
||||
category="image/partner/Stability AI",
|
||||
category="partner/image/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
@ -630,7 +630,7 @@ class StabilityTextToAudio(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="StabilityTextToAudio",
|
||||
display_name="Stability AI Text To Audio",
|
||||
category="audio/partner/Stability AI",
|
||||
category="partner/audio/Stability AI",
|
||||
essentials_category="Audio",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
@ -708,7 +708,7 @@ class StabilityAudioToAudio(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="StabilityAudioToAudio",
|
||||
display_name="Stability AI Audio To Audio",
|
||||
category="audio/partner/Stability AI",
|
||||
category="partner/audio/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
@ -802,7 +802,7 @@ class StabilityAudioInpaint(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="StabilityAudioInpaint",
|
||||
display_name="Stability AI Audio Inpaint",
|
||||
category="audio/partner/Stability AI",
|
||||
category="partner/audio/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
|
||||
@ -52,7 +52,7 @@ class TopazImageEnhance(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="TopazImageEnhance",
|
||||
display_name="Topaz Image Enhance",
|
||||
category="image/partner/Topaz",
|
||||
category="partner/image/Topaz",
|
||||
description="Industry-standard upscaling and image enhancement.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["Reimagine"]),
|
||||
@ -235,7 +235,7 @@ class TopazVideoEnhance(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="TopazVideoEnhance",
|
||||
display_name="Topaz Video Enhance (Legacy)",
|
||||
category="video/partner/Topaz",
|
||||
category="partner/video/Topaz",
|
||||
description="Breathe new life into video with powerful upscaling and recovery technology.",
|
||||
inputs=[
|
||||
IO.Video.Input("video"),
|
||||
@ -475,7 +475,7 @@ class TopazVideoEnhanceV2(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="TopazVideoEnhanceV2",
|
||||
display_name="Topaz Video Enhance",
|
||||
category="video/partner/Topaz",
|
||||
category="partner/video/Topaz",
|
||||
description="Breathe new life into video with powerful upscaling and recovery technology.",
|
||||
inputs=[
|
||||
IO.Video.Input("video"),
|
||||
|
||||
@ -83,7 +83,7 @@ class TripoTextToModelNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="TripoTextToModelNode",
|
||||
display_name="Tripo: Text to Model",
|
||||
category="3d/partner/Tripo",
|
||||
category="partner/3d/Tripo",
|
||||
inputs=[
|
||||
IO.String.Input("prompt", multiline=True),
|
||||
IO.String.Input("negative_prompt", multiline=True, optional=True),
|
||||
@ -210,7 +210,7 @@ class TripoImageToModelNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="TripoImageToModelNode",
|
||||
display_name="Tripo: Image to Model",
|
||||
category="3d/partner/Tripo",
|
||||
category="partner/3d/Tripo",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.Combo.Input(
|
||||
@ -358,7 +358,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="TripoMultiviewToModelNode",
|
||||
display_name="Tripo: Multiview to Model",
|
||||
category="3d/partner/Tripo",
|
||||
category="partner/3d/Tripo",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.Image.Input("image_left", optional=True),
|
||||
@ -518,7 +518,7 @@ class TripoTextureNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="TripoTextureNode",
|
||||
display_name="Tripo: Texture model",
|
||||
category="3d/partner/Tripo",
|
||||
category="partner/3d/Tripo",
|
||||
inputs=[
|
||||
IO.Custom("MODEL_TASK_ID").Input("model_task_id"),
|
||||
IO.Boolean.Input("texture", default=True, optional=True),
|
||||
@ -595,7 +595,7 @@ class TripoRefineNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="TripoRefineNode",
|
||||
display_name="Tripo: Refine Draft model",
|
||||
category="3d/partner/Tripo",
|
||||
category="partner/3d/Tripo",
|
||||
description="Refine a draft model created by v1.4 Tripo models only.",
|
||||
inputs=[
|
||||
IO.Custom("MODEL_TASK_ID").Input("model_task_id", tooltip="Must be a v1.4 Tripo model"),
|
||||
@ -635,7 +635,7 @@ class TripoRigNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="TripoRigNode",
|
||||
display_name="Tripo: Rig model",
|
||||
category="3d/partner/Tripo",
|
||||
category="partner/3d/Tripo",
|
||||
inputs=[IO.Custom("MODEL_TASK_ID").Input("original_model_task_id")],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
@ -672,7 +672,7 @@ class TripoRetargetNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="TripoRetargetNode",
|
||||
display_name="Tripo: Retarget rigged model",
|
||||
category="3d/partner/Tripo",
|
||||
category="partner/3d/Tripo",
|
||||
inputs=[
|
||||
IO.Custom("RIG_TASK_ID").Input("original_model_task_id"),
|
||||
IO.Combo.Input(
|
||||
@ -737,7 +737,7 @@ class TripoConversionNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="TripoConversionNode",
|
||||
display_name="Tripo: Convert model",
|
||||
category="3d/partner/Tripo",
|
||||
category="partner/3d/Tripo",
|
||||
inputs=[
|
||||
IO.Custom("MODEL_TASK_ID,RIG_TASK_ID,RETARGET_TASK_ID").Input("original_model_task_id"),
|
||||
IO.Combo.Input("format", options=["GLTF", "USDZ", "FBX", "OBJ", "STL", "3MF"]),
|
||||
@ -1051,7 +1051,7 @@ class TripoP1TextToModelNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="TripoP1TextToModelNode",
|
||||
display_name="Tripo P1: Text to Model",
|
||||
category="3d/partner/Tripo",
|
||||
category="partner/3d/Tripo",
|
||||
description="Tripo P1 text-to-3D. Optimized for low-poly, game-ready meshes with stable topology.",
|
||||
inputs=[
|
||||
IO.String.Input("prompt", multiline=True, tooltip="Up to 1024 characters."),
|
||||
@ -1122,7 +1122,7 @@ class TripoP1ImageToModelNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="TripoP1ImageToModelNode",
|
||||
display_name="Tripo P1: Image to Model",
|
||||
category="3d/partner/Tripo",
|
||||
category="partner/3d/Tripo",
|
||||
description="Tripo P1 image-to-3D. Optimized for low-poly, game-ready meshes.",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
@ -1202,7 +1202,7 @@ class TripoP1MultiviewToModelNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="TripoP1MultiviewToModelNode",
|
||||
display_name="Tripo P1: Multiview to Model",
|
||||
category="3d/partner/Tripo",
|
||||
category="partner/3d/Tripo",
|
||||
description="Tripo P1 multiview-to-3D from 2-4 reference images in [front, left, back, right] order. "
|
||||
"Front is required; any combination of the other three may be omitted.",
|
||||
inputs=[
|
||||
|
||||
@ -45,7 +45,7 @@ class VeoVideoGenerationNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="VeoVideoGenerationNode",
|
||||
display_name="Google Veo 2 Video Generation",
|
||||
category="video/partner/Veo",
|
||||
category="partner/video/Veo",
|
||||
description="Generates videos from text prompts using Google's Veo 2 API",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -256,7 +256,7 @@ class Veo3VideoGenerationNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Veo3VideoGenerationNode",
|
||||
display_name="Google Veo 3 Video Generation",
|
||||
category="video/partner/Veo",
|
||||
category="partner/video/Veo",
|
||||
description="Generates videos from text prompts using Google's Veo 3 API",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -468,7 +468,7 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Veo3FirstLastFrameNode",
|
||||
display_name="Google Veo 3 First-Last-Frame to Video",
|
||||
category="video/partner/Veo",
|
||||
category="partner/video/Veo",
|
||||
description="Generate video using prompt and first and last frames.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
|
||||
@ -71,7 +71,7 @@ class ViduTextToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ViduTextToVideoNode",
|
||||
display_name="Vidu Text To Video Generation",
|
||||
category="video/partner/Vidu",
|
||||
category="partner/video/Vidu",
|
||||
description="Generate video from a text prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"),
|
||||
@ -169,7 +169,7 @@ class ViduImageToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ViduImageToVideoNode",
|
||||
display_name="Vidu Image To Video Generation",
|
||||
category="video/partner/Vidu",
|
||||
category="partner/video/Vidu",
|
||||
description="Generate video from image and optional prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"),
|
||||
@ -273,7 +273,7 @@ class ViduReferenceVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ViduReferenceVideoNode",
|
||||
display_name="Vidu Reference To Video Generation",
|
||||
category="video/partner/Vidu",
|
||||
category="partner/video/Vidu",
|
||||
description="Generate video from multiple images and a prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"),
|
||||
@ -388,7 +388,7 @@ class ViduStartEndToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ViduStartEndToVideoNode",
|
||||
display_name="Vidu Start End To Video Generation",
|
||||
category="video/partner/Vidu",
|
||||
category="partner/video/Vidu",
|
||||
description="Generate a video from start and end frames and a prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"),
|
||||
@ -492,7 +492,7 @@ class Vidu2TextToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Vidu2TextToVideoNode",
|
||||
display_name="Vidu2 Text-to-Video Generation",
|
||||
category="video/partner/Vidu",
|
||||
category="partner/video/Vidu",
|
||||
description="Generate video from a text prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["viduq2"]),
|
||||
@ -584,7 +584,7 @@ class Vidu2ImageToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Vidu2ImageToVideoNode",
|
||||
display_name="Vidu2 Image-to-Video Generation",
|
||||
category="video/partner/Vidu",
|
||||
category="partner/video/Vidu",
|
||||
description="Generate a video from an image and an optional prompt.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["viduq2-pro-fast", "viduq2-pro", "viduq2-turbo"]),
|
||||
@ -714,7 +714,7 @@ class Vidu2ReferenceVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Vidu2ReferenceVideoNode",
|
||||
display_name="Vidu2 Reference-to-Video Generation",
|
||||
category="video/partner/Vidu",
|
||||
category="partner/video/Vidu",
|
||||
description="Generate a video from multiple reference images and a prompt.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["viduq2"]),
|
||||
@ -849,7 +849,7 @@ class Vidu2StartEndToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Vidu2StartEndToVideoNode",
|
||||
display_name="Vidu2 Start/End Frame-to-Video Generation",
|
||||
category="video/partner/Vidu",
|
||||
category="partner/video/Vidu",
|
||||
description="Generate a video from a start frame, an end frame, and a prompt.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["viduq2-pro-fast", "viduq2-pro", "viduq2-turbo"]),
|
||||
@ -969,7 +969,7 @@ class ViduExtendVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ViduExtendVideoNode",
|
||||
display_name="Vidu Video Extension",
|
||||
category="video/partner/Vidu",
|
||||
category="partner/video/Vidu",
|
||||
description="Extend an existing video by generating additional frames.",
|
||||
inputs=[
|
||||
IO.DynamicCombo.Input(
|
||||
@ -1138,7 +1138,7 @@ class ViduMultiFrameVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ViduMultiFrameVideoNode",
|
||||
display_name="Vidu Multi-Frame Video Generation",
|
||||
category="video/partner/Vidu",
|
||||
category="partner/video/Vidu",
|
||||
description="Generate a video with multiple keyframe transitions.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["viduq2-pro", "viduq2-turbo"]),
|
||||
@ -1284,7 +1284,7 @@ class Vidu3TextToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Vidu3TextToVideoNode",
|
||||
display_name="Vidu Q3 Text-to-Video Generation",
|
||||
category="video/partner/Vidu",
|
||||
category="partner/video/Vidu",
|
||||
description="Generate video from a text prompt.",
|
||||
inputs=[
|
||||
IO.DynamicCombo.Input(
|
||||
@ -1429,7 +1429,7 @@ class Vidu3ImageToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Vidu3ImageToVideoNode",
|
||||
display_name="Vidu Q3 Image-to-Video Generation",
|
||||
category="video/partner/Vidu",
|
||||
category="partner/video/Vidu",
|
||||
description="Generate a video from an image and an optional prompt.",
|
||||
inputs=[
|
||||
IO.DynamicCombo.Input(
|
||||
@ -1571,7 +1571,7 @@ class Vidu3StartEndToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Vidu3StartEndToVideoNode",
|
||||
display_name="Vidu Q3 Start/End Frame-to-Video Generation",
|
||||
category="video/partner/Vidu",
|
||||
category="partner/video/Vidu",
|
||||
description="Generate a video from a start frame, an end frame, and a prompt.",
|
||||
inputs=[
|
||||
IO.DynamicCombo.Input(
|
||||
|
||||
@ -61,7 +61,7 @@ class WanTextToImageApi(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="WanTextToImageApi",
|
||||
display_name="Wan Text to Image",
|
||||
category="image/partner/Wan",
|
||||
category="partner/image/Wan",
|
||||
description="Generates an image based on a text prompt.",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
@ -184,7 +184,7 @@ class WanImageToImageApi(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="WanImageToImageApi",
|
||||
display_name="Wan Image to Image",
|
||||
category="image/partner/Wan",
|
||||
category="partner/image/Wan",
|
||||
description="Generates an image from one or two input images and a text prompt. "
|
||||
"The output image is currently fixed at 1.6 MP, and its aspect ratio matches the input image(s).",
|
||||
inputs=[
|
||||
@ -312,7 +312,7 @@ class WanTextToVideoApi(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="WanTextToVideoApi",
|
||||
display_name="Wan Text to Video",
|
||||
category="video/partner/Wan",
|
||||
category="partner/video/Wan",
|
||||
description="Generates a video based on a text prompt.",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
@ -495,7 +495,7 @@ class WanImageToVideoApi(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="WanImageToVideoApi",
|
||||
display_name="Wan Image to Video",
|
||||
category="video/partner/Wan",
|
||||
category="partner/video/Wan",
|
||||
description="Generates a video from the first frame and a text prompt.",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
@ -674,7 +674,7 @@ class WanReferenceVideoApi(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="WanReferenceVideoApi",
|
||||
display_name="Wan Reference to Video",
|
||||
category="video/partner/Wan",
|
||||
category="partner/video/Wan",
|
||||
description="Use the character and voice from input videos, combined with a prompt, "
|
||||
"to generate a new video that maintains character consistency.",
|
||||
inputs=[
|
||||
@ -828,7 +828,7 @@ class Wan2TextToVideoApi(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Wan2TextToVideoApi",
|
||||
display_name="Wan 2.7 Text to Video",
|
||||
category="video/partner/Wan",
|
||||
category="partner/video/Wan",
|
||||
description="Generates a video based on a text prompt using the Wan 2.7 model.",
|
||||
inputs=[
|
||||
IO.DynamicCombo.Input(
|
||||
@ -981,7 +981,7 @@ class Wan2ImageToVideoApi(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Wan2ImageToVideoApi",
|
||||
display_name="Wan 2.7 Image to Video",
|
||||
category="video/partner/Wan",
|
||||
category="partner/video/Wan",
|
||||
description="Generate a video from a first-frame image, with optional last-frame image and audio.",
|
||||
inputs=[
|
||||
IO.DynamicCombo.Input(
|
||||
@ -1152,7 +1152,7 @@ class Wan2VideoContinuationApi(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Wan2VideoContinuationApi",
|
||||
display_name="Wan 2.7 Video Continuation",
|
||||
category="video/partner/Wan",
|
||||
category="partner/video/Wan",
|
||||
description="Continue a video from where it left off, with optional last-frame control.",
|
||||
inputs=[
|
||||
IO.DynamicCombo.Input(
|
||||
@ -1319,7 +1319,7 @@ class Wan2VideoEditApi(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Wan2VideoEditApi",
|
||||
display_name="Wan 2.7 Video Edit",
|
||||
category="video/partner/Wan",
|
||||
category="partner/video/Wan",
|
||||
description="Edit a video using text instructions, reference images, or style transfer.",
|
||||
inputs=[
|
||||
IO.DynamicCombo.Input(
|
||||
@ -1477,7 +1477,7 @@ class Wan2ReferenceVideoApi(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Wan2ReferenceVideoApi",
|
||||
display_name="Wan 2.7 Reference to Video",
|
||||
category="video/partner/Wan",
|
||||
category="partner/video/Wan",
|
||||
description="Generate a video featuring a person or object from reference materials. "
|
||||
"Supports single-character performances and multi-character interactions.",
|
||||
inputs=[
|
||||
@ -1651,7 +1651,7 @@ class HappyHorseTextToVideoApi(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="HappyHorseTextToVideoApi",
|
||||
display_name="HappyHorse Text to Video",
|
||||
category="video/partner/Wan",
|
||||
category="partner/video/Wan",
|
||||
description="Generates a video based on a text prompt using the HappyHorse model.",
|
||||
inputs=[
|
||||
IO.DynamicCombo.Input(
|
||||
@ -1775,7 +1775,7 @@ class HappyHorseImageToVideoApi(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="HappyHorseImageToVideoApi",
|
||||
display_name="HappyHorse Image to Video",
|
||||
category="video/partner/Wan",
|
||||
category="partner/video/Wan",
|
||||
description="Generate a video from a first-frame image using the HappyHorse model.",
|
||||
inputs=[
|
||||
IO.DynamicCombo.Input(
|
||||
@ -1905,7 +1905,7 @@ class HappyHorseVideoEditApi(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="HappyHorseVideoEditApi",
|
||||
display_name="HappyHorse Video Edit",
|
||||
category="video/partner/Wan",
|
||||
category="partner/video/Wan",
|
||||
description="Edit a video using text instructions or reference images with the HappyHorse model. "
|
||||
"Output duration is 3-15s and matches the input video; inputs longer than 15s are truncated.",
|
||||
inputs=[
|
||||
@ -2046,7 +2046,7 @@ class HappyHorseReferenceVideoApi(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="HappyHorseReferenceVideoApi",
|
||||
display_name="HappyHorse Reference to Video",
|
||||
category="video/partner/Wan",
|
||||
category="partner/video/Wan",
|
||||
description="Generate a video featuring a person or object from reference materials with the HappyHorse "
|
||||
"model. Supports single-character performances and multi-character interactions.",
|
||||
inputs=[
|
||||
|
||||
@ -27,7 +27,7 @@ class WavespeedFlashVSRNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="WavespeedFlashVSRNode",
|
||||
display_name="FlashVSR Video Upscale",
|
||||
category="video/partner/WaveSpeed",
|
||||
category="partner/video/WaveSpeed",
|
||||
description="Fast, high-quality video upscaler that "
|
||||
"boosts resolution and restores clarity for low-resolution or blurry footage.",
|
||||
inputs=[
|
||||
@ -98,7 +98,7 @@ class WavespeedImageUpscaleNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="WavespeedImageUpscaleNode",
|
||||
display_name="WaveSpeed Image Upscale",
|
||||
category="image/partner/WaveSpeed",
|
||||
category="partner/image/WaveSpeed",
|
||||
description="Boost image resolution and quality, upscaling photos to 4K or 8K for sharp, detailed results.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["SeedVR2", "Ultimate"]),
|
||||
|
||||
@ -469,6 +469,11 @@ def _apply_video_scale(video: Input.Video, scale_dims: tuple[int, int]) -> Input
|
||||
input_container = None
|
||||
output_container = None
|
||||
|
||||
# get_stream_source() is untrimmed, so apply the trim window in this same pass.
|
||||
# start_time is normalized (>= 0); duration == 0 means "until the end".
|
||||
start_time, duration = video.get_active_trim_window()
|
||||
trimming = bool(start_time or duration)
|
||||
|
||||
try:
|
||||
input_source = video.get_stream_source()
|
||||
input_container = av.open(input_source, mode="r")
|
||||
@ -487,16 +492,45 @@ def _apply_video_scale(video: Input.Video, scale_dims: tuple[int, int]) -> Input
|
||||
audio_stream.layout = stream.layout
|
||||
break
|
||||
|
||||
in_video = input_container.streams.video[0]
|
||||
start_pts = int(start_time / in_video.time_base) if trimming else 0
|
||||
end_pts = int((start_time + duration) / in_video.time_base) if duration else None
|
||||
if start_pts:
|
||||
input_container.seek(start_pts, stream=in_video)
|
||||
|
||||
encoded = 0
|
||||
for frame in input_container.decode(video=0):
|
||||
if trimming:
|
||||
if frame.pts is None or frame.pts < start_pts:
|
||||
continue
|
||||
if end_pts is not None and frame.pts >= end_pts:
|
||||
break
|
||||
frame = frame.reformat(width=out_w, height=out_h, format="yuv420p")
|
||||
# Re-wrap as a fresh frame: dropping irregular source timestamps (VFR/AVI/GIF/...)
|
||||
# lets the encoder assign clean ones and avoids mp4 muxer errors.
|
||||
frame = av.VideoFrame.from_ndarray(frame.to_ndarray(format="yuv420p"), format="yuv420p")
|
||||
for packet in video_stream.encode(frame):
|
||||
output_container.mux(packet)
|
||||
encoded += 1
|
||||
for packet in video_stream.encode():
|
||||
output_container.mux(packet)
|
||||
|
||||
if encoded == 0:
|
||||
raise ValueError(
|
||||
f"resize produced no frames (start_time={start_time}, duration={duration} "
|
||||
"selected nothing from the source)"
|
||||
)
|
||||
|
||||
if audio_stream is not None:
|
||||
input_container.seek(0)
|
||||
for audio_frame in input_container.decode(audio=0):
|
||||
if trimming:
|
||||
if audio_frame.time is None or audio_frame.time < start_time:
|
||||
continue
|
||||
if duration and audio_frame.time > start_time + duration:
|
||||
break
|
||||
# Carry odd audio time bases the mp4 muxer rejects; reset pts, encoder assigns clean ones (MP3-in-AVI)
|
||||
audio_frame.pts = None
|
||||
for packet in audio_stream.encode(audio_frame):
|
||||
output_container.mux(packet)
|
||||
for packet in audio_stream.encode():
|
||||
|
||||
@ -65,6 +65,12 @@ class ChromaRadianceOptions(io.ComfyNode):
|
||||
tooltip="Allows overriding the default NeRF tile size. -1 means use the default (32). 0 means use non-tiling mode (may require a lot of VRAM).",
|
||||
advanced=True,
|
||||
),
|
||||
io.Boolean.Input(
|
||||
id="force_sequential_txt_ids",
|
||||
default=False,
|
||||
tooltip="Force usage of sequential text token IDs instead of zeroes. Should be used for checkpoints from 2026-05-22 to 2026-06-01 that are trained in this way but do not contain the __sequential__ key in the state dict.",
|
||||
advanced=True,
|
||||
),
|
||||
],
|
||||
outputs=[io.Model.Output()],
|
||||
)
|
||||
@ -78,11 +84,15 @@ class ChromaRadianceOptions(io.ComfyNode):
|
||||
start_sigma: float,
|
||||
end_sigma: float,
|
||||
nerf_tile_size: int,
|
||||
force_sequential_txt_ids: bool,
|
||||
) -> io.NodeOutput:
|
||||
radiance_options = {}
|
||||
if nerf_tile_size >= 0:
|
||||
radiance_options["nerf_tile_size"] = nerf_tile_size
|
||||
|
||||
if force_sequential_txt_ids:
|
||||
radiance_options["use_sequential_txt_ids"] = True
|
||||
|
||||
if not radiance_options:
|
||||
return io.NodeOutput(model)
|
||||
|
||||
|
||||
1664
comfy_extras/nodes_gaussian_splat.py
Normal file
1664
comfy_extras/nodes_gaussian_splat.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -102,11 +102,18 @@ class MathExpressionNode(io.ComfyNode):
|
||||
f"Math Expression '{expression}' must evaluate to a numeric result, "
|
||||
f"got {type(result).__name__}: {result!r}"
|
||||
)
|
||||
if not math.isfinite(result):
|
||||
try:
|
||||
float_result = float(result)
|
||||
except OverflowError:
|
||||
raise ValueError(
|
||||
f"Math Expression '{expression}' produced a result too large to "
|
||||
f"represent as a float: {result}"
|
||||
) from None
|
||||
if not math.isfinite(float_result):
|
||||
raise ValueError(
|
||||
f"Math Expression '{expression}' produced a non-finite result: {result}"
|
||||
)
|
||||
return io.NodeOutput(float(result), int(result), bool(result))
|
||||
return io.NodeOutput(float_result, int(result), bool(result))
|
||||
|
||||
|
||||
class MathExtension(ComfyExtension):
|
||||
|
||||
@ -16,7 +16,7 @@ from comfy.cli_args import args
|
||||
from comfy_api.latest import ComfyExtension, IO, Types
|
||||
|
||||
|
||||
def pack_variable_mesh_batch(vertices, faces, colors=None, uvs=None, texture=None):
|
||||
def pack_variable_mesh_batch(vertices, faces, colors=None, uvs=None, texture=None, unlit=False):
|
||||
# Pack lists of (Nᵢ, *) vertex/face/color/uv tensors into padded batched tensors,
|
||||
# stashing per-item lengths as runtime attrs so consumers can recover the real slice.
|
||||
# colors and uvs are 1:1 with vertices, so they're padded to max_vertices and read with vertex_counts.
|
||||
@ -54,7 +54,7 @@ def pack_variable_mesh_batch(vertices, faces, colors=None, uvs=None, texture=Non
|
||||
|
||||
return Types.MESH(packed_vertices, packed_faces,
|
||||
uvs=packed_uvs, vertex_colors=packed_colors, texture=texture,
|
||||
vertex_counts=vertex_counts, face_counts=face_counts)
|
||||
vertex_counts=vertex_counts, face_counts=face_counts, unlit=unlit)
|
||||
|
||||
|
||||
def get_mesh_batch_item(mesh, index):
|
||||
@ -77,7 +77,7 @@ def get_mesh_batch_item(mesh, index):
|
||||
|
||||
|
||||
def save_glb(vertices, faces, filepath, metadata=None,
|
||||
uvs=None, vertex_colors=None, texture_image=None):
|
||||
uvs=None, vertex_colors=None, texture_image=None, unlit=False):
|
||||
"""
|
||||
Save PyTorch tensor vertices and faces as a GLB file without external dependencies.
|
||||
|
||||
@ -234,6 +234,17 @@ def save_glb(vertices, faces, filepath, metadata=None,
|
||||
textures = []
|
||||
samplers = []
|
||||
materials = []
|
||||
extensions_used = []
|
||||
if unlit and texture_png_bytes is None:
|
||||
# Flat, light-independent shading (KHR_materials_unlit): COLOR_0 is shown as-is, matching how a
|
||||
# gaussian splat renders (emissive). Without this the viewer lights the mesh and washes the colours.
|
||||
materials.append({
|
||||
"pbrMetallicRoughness": {"baseColorFactor": [1.0, 1.0, 1.0, 1.0], "metallicFactor": 0.0, "roughnessFactor": 1.0},
|
||||
"extensions": {"KHR_materials_unlit": {}},
|
||||
"doubleSided": True,
|
||||
})
|
||||
extensions_used.append("KHR_materials_unlit")
|
||||
primitive["material"] = 0
|
||||
if texture_png_bytes is not None and "TEXCOORD_0" in primitive_attributes:
|
||||
buffer_views.append({
|
||||
"buffer": 0,
|
||||
@ -271,6 +282,8 @@ def save_glb(vertices, faces, filepath, metadata=None,
|
||||
gltf["textures"] = textures
|
||||
if materials:
|
||||
gltf["materials"] = materials
|
||||
if extensions_used:
|
||||
gltf["extensionsUsed"] = extensions_used
|
||||
|
||||
if metadata:
|
||||
gltf["asset"]["extras"] = metadata
|
||||
@ -376,7 +389,8 @@ class SaveGLB(IO.ComfyNode):
|
||||
save_glb(vertices_i, faces_i, os.path.join(full_output_folder, f), metadata,
|
||||
uvs=uvs_i,
|
||||
vertex_colors=v_colors,
|
||||
texture_image=tex_img)
|
||||
texture_image=tex_img,
|
||||
unlit=getattr(mesh, "unlit", False))
|
||||
results.append({
|
||||
"filename": f,
|
||||
"subfolder": subfolder,
|
||||
|
||||
270
comfy_extras/nodes_triposplat.py
Normal file
270
comfy_extras/nodes_triposplat.py
Normal file
@ -0,0 +1,270 @@
|
||||
# TripoSplat nodes: image -> 3D gaussian splat
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing_extensions import override
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.nested_tensor
|
||||
import comfy.patcher_extension
|
||||
import comfy.utils
|
||||
from comfy_api.latest import ComfyExtension, IO, Types
|
||||
|
||||
|
||||
_Q_TOKEN_LENGTH = 8192
|
||||
_LATENT_CHANNELS = 16
|
||||
_CAM_CHANNELS = 5
|
||||
_DINOV3_MEAN = [0.485, 0.456, 0.406]
|
||||
_DINOV3_STD = [0.229, 0.224, 0.225]
|
||||
_NUM_GAUSSIANS_MIN = 32768
|
||||
_NUM_GAUSSIANS_MAX = 1048576
|
||||
|
||||
|
||||
def _preprocess(image: torch.Tensor, mask: torch.Tensor, erode_radius: int, size: int) -> torch.Tensor:
|
||||
# Match original preprocessing:
|
||||
# resize min side to `size` -> erode alpha -> alpha bbox -> 1.2x square crop -> resize -> composite on black.
|
||||
rgb = image[..., :3].clamp(0, 1).movedim(-1, 0) # (3, H, W)
|
||||
alpha = mask.clamp(0, 1)[None] # (1, H, W)
|
||||
rgba = torch.cat([rgb, alpha], 0)[None] # (1, 4, H, W)
|
||||
|
||||
h, w = rgba.shape[-2:]
|
||||
s = size / min(w, h)
|
||||
rgba = comfy.utils.common_upscale(rgba, max(1, round(w * s)), max(1, round(h * s)), "lanczos", "disabled").clamp(0, 1)
|
||||
|
||||
a = rgba[:, 3:4]
|
||||
if erode_radius > 0:
|
||||
# min filter over a (2r+1) window == morphological erosion of the alpha matte.
|
||||
a = -F.max_pool2d(-a, 2 * erode_radius + 1, stride=1, padding=erode_radius)
|
||||
rgba = torch.cat([rgba[:, :3], a], 1)
|
||||
|
||||
ys, xs = torch.nonzero(a[0, 0] > 0, as_tuple=True)
|
||||
if xs.numel() == 0:
|
||||
raise ValueError("TripoSplatPreprocessImage: mask is empty (no foreground pixels).")
|
||||
x0, x1 = int(xs.min()), int(xs.max())
|
||||
y0, y1 = int(ys.min()), int(ys.max())
|
||||
cx, cy = (x0 + x1) / 2, (y0 + y1) / 2
|
||||
half = max(x1 - x0, y1 - y0) / 2 * 1.2
|
||||
left, upper, right, lower = int(cx - half), int(cy - half), int(cx + half), int(cy + half)
|
||||
|
||||
H, W = rgba.shape[-2:]
|
||||
crop = rgba.new_zeros((1, 4, lower - upper, right - left)) # out-of-bounds stays 0, matching PIL.crop
|
||||
sx0, sy0, sx1, sy1 = max(left, 0), max(upper, 0), min(right, W), min(lower, H)
|
||||
if sx1 > sx0 and sy1 > sy0:
|
||||
crop[:, :, sy0 - upper:sy1 - upper, sx0 - left:sx1 - left] = rgba[:, :, sy0:sy1, sx0:sx1]
|
||||
|
||||
crop = comfy.utils.common_upscale(crop, size, size, "lanczos", "disabled").clamp(0, 1)
|
||||
out = (crop[:, :3] * crop[:, 3:4])[0].movedim(0, -1) # composite over black == rgb * alpha
|
||||
return out.unsqueeze(0) # (1, 1024, 1024, 3)
|
||||
|
||||
|
||||
class TripoSplatPreprocessImage(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="TripoSplatPreprocessImage",
|
||||
display_name="TripoSplat Preprocess Image",
|
||||
category="3d/conditioning",
|
||||
description="Crop center each image to a square canvas on a black background and add padding.",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.Mask.Input("mask"),
|
||||
IO.Int.Input("erode_radius", default=1, min=0, max=16,
|
||||
tooltip="Erode the alpha matte by this pixel radius before cropping (avoids border bleed)."),
|
||||
IO.Int.Input("size", default=1024, min=256, max=4096, step=16,
|
||||
tooltip="Square image size. The model is trained at 1024; other sizes run but are off-distribution."),
|
||||
],
|
||||
outputs=[IO.Image.Output(display_name="image")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, image, mask, erode_radius, size) -> IO.NodeOutput:
|
||||
size = max(16, (int(size) // 16) * 16) # DINOv3 patch / Flux2 VAE stride is 16
|
||||
if mask.shape[0] != image.shape[0]:
|
||||
mask = comfy.utils.repeat_to_batch_size(mask, image.shape[0])
|
||||
if tuple(mask.shape[1:]) != tuple(image.shape[1:3]):
|
||||
mask = F.interpolate(mask[:, None].float(), size=tuple(image.shape[1:3]), mode="bilinear", align_corners=False)[:, 0]
|
||||
prepared = torch.cat([_preprocess(image[i], mask[i], erode_radius, size) for i in range(image.shape[0])], dim=0)
|
||||
return IO.NodeOutput(prepared)
|
||||
|
||||
|
||||
class TripoSplatConditioning(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="TripoSplatConditioning",
|
||||
display_name="TripoSplat Conditioning",
|
||||
category="3d/conditioning",
|
||||
description="Encode the image with DINOv3 and the Flux2 VAE into TripoSplat positive/negative "
|
||||
"conditioning, and create the fixed size noise target (latent + camera) for the KSampler",
|
||||
inputs=[
|
||||
IO.ClipVision.Input("clip_vision", tooltip="DINOv3 ViT-H/16+ image encoder"),
|
||||
IO.Vae.Input("vae", tooltip="Flux2 VAE"),
|
||||
IO.Image.Input("image"),
|
||||
],
|
||||
outputs=[
|
||||
IO.Conditioning.Output(display_name="positive"),
|
||||
IO.Conditioning.Output(display_name="negative"),
|
||||
IO.Latent.Output(display_name="latent", tooltip="The fixed size noise target (latent +camera)."),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip_vision, vae, image) -> IO.NodeOutput:
|
||||
# feature1: DINOv3 token sequence (cls + registers + patches), ImageNet-normalized, with a final non-affine layer norm on top
|
||||
comfy.model_management.load_model_gpu(clip_vision.patcher)
|
||||
device = clip_vision.load_device
|
||||
img = image.movedim(-1, 1).to(device) # (B,3,H,W) in [0,1]
|
||||
mean = torch.tensor(_DINOV3_MEAN, device=device).view(1, 3, 1, 1)
|
||||
std = torch.tensor(_DINOV3_STD, device=device).view(1, 3, 1, 1)
|
||||
img = (img - mean) / std
|
||||
seq = clip_vision.model(pixel_values=img.float())[0]
|
||||
feature1 = F.layer_norm(seq.float(), seq.shape[-1:]).to(comfy.model_management.intermediate_device())
|
||||
|
||||
# Second conditioning: the Flux2 VAE latent of the image, carried as a standard reference_latents entry
|
||||
ref = vae.encode(image).to(comfy.model_management.intermediate_device()) # (B, 128, H, W)
|
||||
b = ref.shape[0]
|
||||
|
||||
positive = [[feature1, {"reference_latents": [ref]}]]
|
||||
negative = [[torch.zeros_like(feature1), {"reference_latents": [torch.zeros_like(ref)]}]]
|
||||
|
||||
# Fixed noise target: the latent is a constant-shape (8192, 16) shape-code + a (1, 5) camera token
|
||||
dev = comfy.model_management.intermediate_device()
|
||||
latent_seq = torch.zeros([b, _Q_TOKEN_LENGTH, _LATENT_CHANNELS], device=dev)
|
||||
camera = torch.zeros([b, 1, _CAM_CHANNELS], device=dev)
|
||||
samples = comfy.nested_tensor.NestedTensor((latent_seq, camera))
|
||||
return IO.NodeOutput(positive, negative, {"samples": samples})
|
||||
|
||||
|
||||
class VAEDecodeTripoSplat(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="VAEDecodeTripoSplat",
|
||||
display_name="TripoSplat Decode",
|
||||
category="3d/latent",
|
||||
description="Decode the sampled TripoSplat latent into a 3D gaussian splat. "
|
||||
"Modify the number of gaussians to vary the density.",
|
||||
inputs=[
|
||||
IO.Latent.Input("samples"),
|
||||
IO.Vae.Input("vae", tooltip="TripoSplat VAE decoder"),
|
||||
IO.Int.Input("num_gaussians", default=262144, min=_NUM_GAUSSIANS_MIN, max=_NUM_GAUSSIANS_MAX, step=32,
|
||||
tooltip="Number of gaussians to produce (rounded to a multiple of 32). "
|
||||
"262144 matches the octree's point density; higher oversamples the same points "
|
||||
"(denser, but no new detail) and costs proportionally more VRAM/time."),
|
||||
IO.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff,
|
||||
tooltip="Seeds the octree point sampler (global RNG) for deterministic decodes."),
|
||||
],
|
||||
outputs=[IO.Splat.Output(display_name="splat")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, samples, vae, num_gaussians, seed) -> IO.NodeOutput:
|
||||
s = samples["samples"]
|
||||
latent = s.unbind()[0] if getattr(s, "is_nested", False) else s # take the latent stream, drop camera
|
||||
|
||||
decoder = vae.first_stage_model
|
||||
gpp = decoder.gaussians_per_point
|
||||
n = max(_NUM_GAUSSIANS_MIN, min(_NUM_GAUSSIANS_MAX, int(num_gaussians)))
|
||||
if n % gpp != 0:
|
||||
n = round(n / gpp) * gpp
|
||||
|
||||
dtype_size = comfy.model_management.dtype_size(vae.vae_dtype)
|
||||
hidden = decoder.gs.model_channels
|
||||
cond_tokens = latent.shape[1]
|
||||
memory_required = (cond_tokens * 4 + (n // gpp) * 10) * hidden * dtype_size
|
||||
comfy.model_management.load_models_gpu([vae.patcher], memory_required=memory_required)
|
||||
latent = latent.to(device=vae.device, dtype=vae.vae_dtype)
|
||||
generator = torch.Generator(device="cpu").manual_seed(seed)
|
||||
parts = [g.render_tensors() for g in decoder.decode(latent, num_gaussians=n, generator=generator)]
|
||||
positions, scales, rotations, opacities, sh = (torch.stack(t) for t in zip(*parts))
|
||||
return IO.NodeOutput(Types.SPLAT(positions, scales, rotations, opacities, sh))
|
||||
|
||||
|
||||
class TripoSplatSamplingPreview(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="TripoSplatSamplingPreview",
|
||||
display_name="TripoSplat Sampling Preview",
|
||||
category="3d/latent",
|
||||
description="Patch the TripoSplat model for the standard Ksampler node to show a live decoded "
|
||||
"gaussian splat preview at each step.",
|
||||
inputs=[
|
||||
IO.Model.Input("model"),
|
||||
IO.Vae.Input("vae", tooltip="TripoSplat VAE decoder"),
|
||||
IO.Int.Input("octree_level", default=5, min=2, max=8, advanced=True,
|
||||
tooltip="Octree depth for the preview decode (lower = cheaper/coarser)."),
|
||||
IO.Int.Input("num_gaussians", default=16384, min=1024, max=262144, step=32,
|
||||
tooltip="Number of gaussians to produce for the preview (rounded to a multiple of 32)."),
|
||||
IO.Float.Input("yaw", default=90.0, min=-360.0, max=360.0, step=1.0, tooltip="Preview camera yaw in degrees.", advanced=True,),
|
||||
IO.Float.Input("pitch", default=15.0, min=-89.0, max=89.0, step=1.0, tooltip="Preview camera pitch in degrees.", advanced=True,),
|
||||
IO.Int.Input("point_size", default=3, min=1, max=16,
|
||||
tooltip="Maximum splat radius in pixels. Each gaussian is sized from its scale and capped here; "
|
||||
"lower = finer/pointier, higher = chunkier."),
|
||||
],
|
||||
outputs=[IO.Model.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, vae, octree_level, num_gaussians, yaw, pitch, point_size) -> IO.NodeOutput:
|
||||
from comfy.ldm.triposplat.preview import decode_x0_to_image
|
||||
cfg = {"gaussians": num_gaussians, "level": octree_level, "yaw": yaw, "pitch": pitch,
|
||||
"point_size": point_size}
|
||||
|
||||
fsm = vae.first_stage_model
|
||||
cond_tokens = model.model.diffusion_model.q_token_length
|
||||
memory_required = (cond_tokens * 4 + (num_gaussians // fsm.gaussians_per_point) * 10) * fsm.gs.model_channels * comfy.model_management.dtype_size(vae.vae_dtype)
|
||||
|
||||
# Live preview via WrappersMP.OUTER_SAMPLE + ProgressBar
|
||||
# The wrapper augments the sampler's own callback to decode x0 -> gaussian splat -> preview image each step
|
||||
def outer_sample_wrapper(executor, *args, **kwargs):
|
||||
args = list(args)
|
||||
cb_idx = 5 # outer_sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
||||
orig_cb = args[cb_idx] if len(args) > cb_idx else kwargs.get("callback")
|
||||
state = {"ok": True, "pbar": None, "loaded": False}
|
||||
|
||||
def callback(step, x0, x, total_steps):
|
||||
if orig_cb is not None:
|
||||
orig_cb(step, x0, x, total_steps)
|
||||
if not state["ok"]:
|
||||
return
|
||||
try:
|
||||
if not state["loaded"]:
|
||||
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
|
||||
loaded_models.append(vae.patcher)
|
||||
comfy.model_management.load_models_gpu(loaded_models, memory_required=memory_required)
|
||||
state["loaded"] = True
|
||||
img = decode_x0_to_image(vae, x0, cfg)
|
||||
if state["pbar"] is None:
|
||||
state["pbar"] = comfy.utils.ProgressBar(total_steps)
|
||||
state["pbar"].update_absolute(step + 1, total_steps, ("JPEG", img, 512))
|
||||
except Exception as e:
|
||||
logging.warning("TripoSplatSamplingPreview: preview failed, disabling ({})".format(e))
|
||||
state["ok"] = False
|
||||
|
||||
if len(args) > cb_idx:
|
||||
args[cb_idx] = callback
|
||||
else:
|
||||
kwargs["callback"] = callback
|
||||
return executor(*args, **kwargs)
|
||||
|
||||
m = model.clone()
|
||||
m.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "triposplat_sampling_preview", outer_sample_wrapper)
|
||||
return IO.NodeOutput(m)
|
||||
|
||||
|
||||
class TripoSplatExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
TripoSplatPreprocessImage,
|
||||
TripoSplatConditioning,
|
||||
VAEDecodeTripoSplat,
|
||||
TripoSplatSamplingPreview,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> TripoSplatExtension:
|
||||
return TripoSplatExtension()
|
||||
@ -1,3 +1,3 @@
|
||||
# This file is automatically generated by the build process when version is
|
||||
# updated in pyproject.toml.
|
||||
__version__ = "0.22.0"
|
||||
__version__ = "0.23.0"
|
||||
|
||||
7
main.py
7
main.py
@ -464,13 +464,6 @@ def start_comfyui(asyncio_loop=None):
|
||||
folder_paths.set_temp_directory(temp_dir)
|
||||
cleanup_temp()
|
||||
|
||||
if args.windows_standalone_build:
|
||||
try:
|
||||
import new_updater
|
||||
new_updater.update_windows_updater()
|
||||
except:
|
||||
pass
|
||||
|
||||
if not asyncio_loop:
|
||||
asyncio_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(asyncio_loop)
|
||||
|
||||
@ -1,35 +0,0 @@
|
||||
import os
|
||||
import shutil
|
||||
|
||||
base_path = os.path.dirname(os.path.realpath(__file__))
|
||||
|
||||
|
||||
def update_windows_updater():
|
||||
top_path = os.path.dirname(base_path)
|
||||
updater_path = os.path.join(base_path, ".ci/update_windows/update.py")
|
||||
bat_path = os.path.join(base_path, ".ci/update_windows/update_comfyui.bat")
|
||||
|
||||
dest_updater_path = os.path.join(top_path, "update/update.py")
|
||||
dest_bat_path = os.path.join(top_path, "update/update_comfyui.bat")
|
||||
dest_bat_deps_path = os.path.join(top_path, "update/update_comfyui_and_python_dependencies.bat")
|
||||
|
||||
try:
|
||||
with open(dest_bat_path, 'rb') as f:
|
||||
contents = f.read()
|
||||
except:
|
||||
return
|
||||
|
||||
if not contents.startswith(b"..\\python_embeded\\python.exe .\\update.py"):
|
||||
return
|
||||
|
||||
shutil.copy(updater_path, dest_updater_path)
|
||||
try:
|
||||
with open(dest_bat_deps_path, 'rb') as f:
|
||||
contents = f.read()
|
||||
contents = contents.replace(b'..\\python_embeded\\python.exe .\\update.py ..\\ComfyUI\\', b'call update_comfyui.bat nopause')
|
||||
with open(dest_bat_deps_path, 'wb') as f:
|
||||
f.write(contents)
|
||||
except:
|
||||
pass
|
||||
shutil.copy(bat_path, dest_bat_path)
|
||||
print("Updated the windows standalone package updater.") # noqa: T201
|
||||
2
nodes.py
2
nodes.py
@ -2455,6 +2455,8 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_save_3d.py",
|
||||
"nodes_moge.py",
|
||||
"nodes_mediapipe.py",
|
||||
"nodes_gaussian_splat.py",
|
||||
"nodes_triposplat.py"
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "ComfyUI"
|
||||
version = "0.22.0"
|
||||
version = "0.23.0"
|
||||
readme = "README.md"
|
||||
license = { file = "LICENSE" }
|
||||
requires-python = ">=3.10"
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.44.19
|
||||
comfyui-workflow-templates==0.9.91
|
||||
comfyui-workflow-templates==0.9.92
|
||||
comfyui-embedded-docs==0.5.2
|
||||
torch
|
||||
torchsde
|
||||
|
||||
@ -197,3 +197,10 @@ class TestMathExpressionExecute:
|
||||
def test_pow_huge_exponent_raises(self):
|
||||
with pytest.raises(ValueError, match="Exponent .* exceeds maximum"):
|
||||
self._exec("pow(a, b)", a=10, b=10000000)
|
||||
|
||||
def test_huge_int_result_raises_value_error(self):
|
||||
# Exponent is within the allowed MAX_EXPONENT range, so the result is a
|
||||
# finite Python int that is nonetheless too large to convert to float.
|
||||
# This must raise a clean ValueError, not an uncaught OverflowError.
|
||||
with pytest.raises(ValueError, match="too large to represent as a float"):
|
||||
self._exec("2 ** 3999")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user