diff --git a/comfy/background_removal/birefnet.py b/comfy/background_removal/birefnet.py index df54b2b90..78a80246e 100644 --- a/comfy/background_removal/birefnet.py +++ b/comfy/background_removal/birefnet.py @@ -105,7 +105,7 @@ class WindowAttention(nn.Module): relative_position_bias = self.relative_position_bias_table[self.relative_position_index.long().view(-1)].view( self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + relative_position_bias = comfy.ops.cast_to_input(relative_position_bias.permute(2, 0, 1).contiguous(), attn) # nH, Wh*Ww, Wh*Ww attn = attn + relative_position_bias.unsqueeze(0) if mask is not None: diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 1691fca81..ce8924a11 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -9,6 +9,7 @@ import comfy.model_management import comfy.utils import comfy.clip_model import comfy.image_encoders.dino2 +import comfy.image_encoders.dino3 class Output: def __getitem__(self, key): @@ -23,12 +24,16 @@ IMAGE_ENCODERS = { "siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection, "siglip2_vision_model": comfy.clip_model.CLIPVisionModelProjection, "dinov2": comfy.image_encoders.dino2.Dinov2Model, + "dinov3": comfy.image_encoders.dino3.DINOv3ViTModel, } class ClipVisionModel(): def __init__(self, json_config): - with open(json_config) as f: - config = json.load(f) + if isinstance(json_config, dict): + config = json_config + else: + with open(json_config) as f: + config = json.load(f) self.image_size = config.get("image_size", 224) self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073]) @@ -134,6 +139,8 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False): json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json") elif 
'encoder.layer.23.layer_scale2.lambda1' in sd: json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_large.json") + elif 'layer.0.mlp.gate_proj.weight' in sd and 'layer.31.norm1.weight' in sd: # Dinov3 ViT-H/16+ (SwiGLU gated MLP, 32 layers) + json_config = comfy.image_encoders.dino3.DINOV3_VITH_CONFIG else: return None diff --git a/comfy/image_encoders/dino3.py b/comfy/image_encoders/dino3.py new file mode 100644 index 000000000..ad29b06f8 --- /dev/null +++ b/comfy/image_encoders/dino3.py @@ -0,0 +1,259 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +import comfy.ops +from comfy.ldm.modules.attention import optimized_attention_for_device +from comfy.image_encoders.dino2 import LayerScale as DINOv3ViTLayerScale + + +# DINOv3 ViT-H/16+ (SwiGLU) +DINOV3_VITH_CONFIG = { + "model_type": "dinov3", + "num_hidden_layers": 32, + "hidden_size": 1280, + "num_attention_heads": 20, + "num_register_tokens": 4, + "intermediate_size": 5120, + "layer_norm_eps": 1e-5, + "num_channels": 3, + "patch_size": 16, + "rope_theta": 100.0, + "use_gated_mlp": True, + "gated_mlp_act": "silu", + "image_size": 1024, + "image_mean": [0.485, 0.456, 0.406], + "image_std": [0.229, 0.224, 0.225], +} + + +class DINOv3ViTMLP(nn.Module): + def __init__(self, hidden_size, intermediate_size, mlp_bias, device, dtype, operations): + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype) + self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=mlp_bias, device=device, dtype=dtype) + self.act_fn = torch.nn.GELU() + + def forward(self, x): + return self.down_proj(self.act_fn(self.up_proj(x))) + + +def rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + 
def apply_rotary_pos_emb(q, k, cos, sin, **kwargs):
    """Rotate the patch-token slice of ``q`` and ``k`` with the given RoPE tables.

    The number of patch tokens is taken from ``sin.shape[-2]``. Any additional
    leading tokens on the sequence axis (dim ``-2``) of ``q``/``k`` are treated
    as prefix tokens and are passed through without rotation.

    Args:
        q: query tensor; the sequence axis is dim ``-2``.
        k: key tensor with the same sequence length as ``q``.
        cos: cosine table, broadcastable against the patch slice of ``q``/``k``.
        sin: sine table with the same shape as ``cos``.
        **kwargs: accepted and ignored.

    Returns:
        ``(q, k)`` with the same shapes as the inputs.
    """
    num_tokens = q.shape[-2]
    num_patches = sin.shape[-2]
    num_prefix_tokens = num_tokens - num_patches

    q_prefix_tokens, q_patches = q.split((num_prefix_tokens, num_patches), dim=-2)
    k_prefix_tokens, k_patches = k.split((num_prefix_tokens, num_patches), dim=-2)

    q_patches = (q_patches * cos) + (rotate_half(q_patches) * sin)
    k_patches = (k_patches * cos) + (rotate_half(k_patches) * sin)

    q = torch.cat((q_prefix_tokens, q_patches), dim=-2)
    k = torch.cat((k_prefix_tokens, k_patches), dim=-2)

    return q, k


class DINOv3ViTAttention(nn.Module):
    """Multi-head self-attention with optional rotary position embeddings.

    The key projection has no bias; query, value and output projections do.
    """

    def __init__(self, hidden_size, num_attention_heads, device, dtype, operations):
        super().__init__()
        self.embed_dim = hidden_size
        self.num_heads = num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads

        self.k_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=False, device=device, dtype=dtype)  # key_bias = False
        self.v_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype)
        self.q_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype)
        self.o_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype)

    def forward(self, hidden_states, attention_mask=None, position_embeddings=None, **kwargs):
        """Attend over ``hidden_states`` of shape (batch, tokens, embed_dim).

        Args:
            hidden_states: input tokens, (batch, tokens, embed_dim).
            attention_mask: optional mask forwarded to the attention kernel.
            position_embeddings: optional ``(cos, sin)`` pair; when given it is
                applied to queries and keys via ``apply_rotary_pos_emb``.
            **kwargs: accepted and ignored.

        Returns:
            Tensor of shape (batch, tokens, embed_dim).
        """
        batch_size, patches, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # (batch, tokens, embed_dim) -> (batch, heads, tokens, head_dim)
        query_states = query_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)

        if position_embeddings is not None:
            cos, sin = position_embeddings
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # Pick the mask-capable kernel only when a mask is actually supplied.
        # The flag was previously hard-coded to False even though the mask is
        # handed to the kernel below; with no mask the selection is unchanged.
        attn = optimized_attention_for_device(query_states.device, mask=attention_mask is not None)
        attn_output = attn(
            query_states, key_states, value_states, self.num_heads, attention_mask,
            skip_reshape=True, skip_output_reshape=True, low_precision_attention=False,
        )

        # Back to (batch, tokens, heads * head_dim) for the output projection.
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(batch_size, patches, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output


class DINOv3ViTGatedMLP(nn.Module):
    """Gated feed-forward block: ``down_proj(act(gate_proj(x)) * up_proj(x))``.

    ``act == "silu"`` selects SiLU; any other value selects GELU.
    """

    def __init__(self, hidden_size, intermediate_size, mlp_bias, device, dtype, operations, act="silu"):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.gate_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype)
        self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype)
        self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=mlp_bias, device=device, dtype=dtype)
        self.act_fn = torch.nn.SiLU() if act == "silu" else torch.nn.GELU()

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


def get_patches_center_coordinates(num_patches_h, num_patches_w, dtype, device):
    """Return patch-centre coordinates normalised to the open interval (-1, 1).

    Returns:
        Tensor of shape (num_patches_h * num_patches_w, 2) holding ``(h, w)``
        coordinates in row-major order.
    """
    # Centres sit at half-integer positions: 0.5, 1.5, ...
    coords_h = torch.arange(0.5, num_patches_h, dtype=dtype, device=device)
    coords_w = torch.arange(0.5, num_patches_w, dtype=dtype, device=device)
    coords_h = coords_h / num_patches_h
    coords_w = coords_w / num_patches_w
    coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1)
    coords = coords.flatten(0, 1)
    # Map [0, 1] to [-1, 1].
    coords = 2.0 * coords - 1.0
    return coords


class DINOv3ViTRopePositionEmbedding(nn.Module):
    """Builds per-patch RoPE cos/sin tables from the input image size.

    ``dtype`` is accepted for signature symmetry but is not used: the inverse
    frequencies are always created in float32.
    """

    inv_freq: torch.Tensor

    def __init__(self, rope_theta, hidden_size, num_attention_heads, patch_size, device, dtype):
        super().__init__()
        self.base = rope_theta
        self.head_dim = hidden_size // num_attention_heads
        self.patch_size = patch_size

        # head_dim / 4 frequencies; each is applied to both image axes and the
        # result is tiled twice in forward(), giving head_dim values per patch.
        inv_freq = 1 / self.base ** torch.arange(0, 1, 4 / self.head_dim, dtype=torch.float32, device=device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, pixel_values):
        """Return ``(cos, sin)``, each (num_patches, head_dim), in ``pixel_values.dtype``.

        ``pixel_values`` must be 4-D with height and width as its last two
        dimensions; partial patches at the border are discarded by the floor
        division.
        """
        _, _, height, width = pixel_values.shape
        num_patches_h = height // self.patch_size
        num_patches_w = width // self.patch_size

        patch_coords = get_patches_center_coordinates(num_patches_h, num_patches_w, dtype=torch.float32, device=pixel_values.device)
        # Re-home the buffer on the input device (this updates module state).
        self.inv_freq = self.inv_freq.to(pixel_values.device)
        angles = 2 * math.pi * patch_coords[:, :, None] * self.inv_freq[None, None, :]
        angles = angles.flatten(1, 2)
        angles = angles.tile(2)
        cos = torch.cos(angles).to(dtype=pixel_values.dtype)
        sin = torch.sin(angles).to(dtype=pixel_values.dtype)
        return cos, sin


class DINOv3ViTEmbeddings(nn.Module):
    """Patch embedding plus learned CLS, register and mask tokens."""

    def __init__(self, hidden_size, num_register_tokens, num_channels, patch_size, dtype, device, operations):
        super().__init__()
        self.cls_token = nn.Parameter(torch.empty(1, 1, hidden_size, device=device, dtype=dtype))
        self.mask_token = nn.Parameter(torch.empty(1, 1, hidden_size, device=device, dtype=dtype))
        self.register_tokens = nn.Parameter(torch.empty(1, num_register_tokens, hidden_size, device=device, dtype=dtype))
        self.patch_embeddings = operations.Conv2d(
            num_channels, hidden_size, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype
        )

    def forward(self, pixel_values, bool_masked_pos=None):
        """Return the token sequence ``[cls, registers..., patches...]``.

        Args:
            pixel_values: image batch fed to the patch convolution.
            bool_masked_pos: optional boolean mask over patches; masked patch
                embeddings are replaced by the mask token.

        Returns:
            Tensor of shape (batch, 1 + num_register_tokens + num_patches, hidden_size).
        """
        batch_size = pixel_values.shape[0]

        # (batch, hidden, H', W') -> (batch, H' * W', hidden)
        patch_embeddings = self.patch_embeddings(pixel_values)
        patch_embeddings = patch_embeddings.flatten(2).transpose(1, 2)

        if bool_masked_pos is not None:
            mask_token = comfy.ops.cast_to_input(self.mask_token, patch_embeddings)
            patch_embeddings = torch.where(bool_masked_pos.unsqueeze(-1), mask_token, patch_embeddings)

        cls_token = comfy.ops.cast_to_input(self.cls_token.expand(batch_size, -1, -1), patch_embeddings)
        register_tokens = comfy.ops.cast_to_input(self.register_tokens.expand(batch_size, -1, -1), patch_embeddings)
        embeddings = torch.cat([cls_token, register_tokens, patch_embeddings], dim=1)
        return embeddings


class DINOv3ViTLayer(nn.Module):
    """Pre-norm transformer layer with layer-scaled attention and MLP residuals."""

    def __init__(self, hidden_size, layer_norm_eps, use_gated_mlp, mlp_bias, intermediate_size,
                 num_attention_heads, device, dtype, operations, gated_mlp_act="silu"):
        super().__init__()
        self.norm1 = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype)
        self.attention = DINOv3ViTAttention(hidden_size, num_attention_heads, device=device, dtype=dtype, operations=operations)
        self.layer_scale1 = DINOv3ViTLayerScale(hidden_size, device=device, dtype=dtype, operations=None)

        self.norm2 = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype)
        if use_gated_mlp:
            self.mlp = DINOv3ViTGatedMLP(hidden_size, intermediate_size, mlp_bias, device=device, dtype=dtype, operations=operations, act=gated_mlp_act)
        else:
            self.mlp = DINOv3ViTMLP(hidden_size, intermediate_size=intermediate_size, mlp_bias=mlp_bias, device=device, dtype=dtype, operations=operations)
        self.layer_scale2 = DINOv3ViTLayerScale(hidden_size, device=device, dtype=dtype, operations=None)

    def forward(self, hidden_states, attention_mask=None, position_embeddings=None):
        """Apply the attention and MLP sub-blocks, each as ``x + scale(f(norm(x)))``."""
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        hidden_states = self.attention(hidden_states, attention_mask=attention_mask, position_embeddings=position_embeddings)
        hidden_states = self.layer_scale1(hidden_states)
        hidden_states = hidden_states + residual

        residual = hidden_states
        hidden_states = self.norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.layer_scale2(hidden_states)
        hidden_states = hidden_states + residual
        return hidden_states


class DINOv3ViTModel(nn.Module):
    """DINOv3 ViT encoder assembled from a plain ``config`` mapping.

    Required config keys: ``num_hidden_layers``, ``hidden_size``,
    ``num_attention_heads``, ``num_register_tokens``, ``intermediate_size``,
    ``layer_norm_eps``, ``num_channels``, ``patch_size``, ``rope_theta``.
    Optional: ``use_gated_mlp`` (default False), ``gated_mlp_act`` (default "silu").
    """

    def __init__(self, config, dtype, device, operations):
        super().__init__()
        num_hidden_layers = config["num_hidden_layers"]
        hidden_size = config["hidden_size"]
        num_attention_heads = config["num_attention_heads"]
        num_register_tokens = config["num_register_tokens"]
        intermediate_size = config["intermediate_size"]
        layer_norm_eps = config["layer_norm_eps"]
        num_channels = config["num_channels"]
        patch_size = config["patch_size"]
        rope_theta = config["rope_theta"]
        use_gated_mlp = config.get("use_gated_mlp", False)
        gated_mlp_act = config.get("gated_mlp_act", "silu")

        self.embeddings = DINOv3ViTEmbeddings(
            hidden_size, num_register_tokens, num_channels=num_channels, patch_size=patch_size,
            dtype=dtype, device=device, operations=operations
        )
        self.rope_embeddings = DINOv3ViTRopePositionEmbedding(
            rope_theta, hidden_size, num_attention_heads, patch_size=patch_size, dtype=dtype, device=device
        )
        self.layer = nn.ModuleList([
            DINOv3ViTLayer(hidden_size, layer_norm_eps, use_gated_mlp=use_gated_mlp, mlp_bias=True,
                           intermediate_size=intermediate_size, num_attention_heads=num_attention_heads,
                           dtype=dtype, device=device, operations=operations, gated_mlp_act=gated_mlp_act)
            for _ in range(num_hidden_layers)])
        self.norm = operations.LayerNorm(hidden_size, eps=layer_norm_eps, dtype=dtype, device=device)

    def get_input_embeddings(self):
        """Return the patch-embedding convolution."""
        return self.embeddings.patch_embeddings

    def forward(self, pixel_values, bool_masked_pos=None, **kwargs):
        """Encode an image batch.

        Args:
            pixel_values: image batch passed to the embeddings and RoPE modules.
            bool_masked_pos: optional patch mask forwarded to the embeddings.
            **kwargs: only ``skip_norm_elementwise`` is read; when truthy the
                final normalisation is a plain ``F.layer_norm`` without the
                learned weight and bias.

        Returns:
            ``(sequence_output, None, pooled_output, None)`` where
            ``pooled_output`` is the first (CLS) token of ``sequence_output``.
        """
        hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
        position_embeddings = self.rope_embeddings(pixel_values)

        for layer_module in self.layer:
            hidden_states = layer_module(hidden_states, position_embeddings=position_embeddings)

        if kwargs.get("skip_norm_elementwise", False):
            sequence_output = F.layer_norm(hidden_states, hidden_states.shape[-1:])
        else:
            # Moves the final norm onto the activations' device before use.
            norm = self.norm.to(hidden_states.device)
            sequence_output = norm(hidden_states)
        pooled_output = sequence_output[:, 0, :]
        return sequence_output, None, pooled_output, None
--git a/comfy/latent_formats.py b/comfy/latent_formats.py index 12a934d71..bbdfd4bc2 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -239,6 +239,16 @@ class Flux2(LatentFormat): def process_out(self, latent): return latent +class TripoSplat(LatentFormat): + # Sequence latent (B, 8192, 16) the camera token rides alongside as a second nested latent + latent_channels = 16 + + def process_in(self, latent): + return latent + + def process_out(self, latent): + return latent + class Mochi(LatentFormat): latent_channels = 12 latent_dimensions = 3 diff --git a/comfy/ldm/chroma_radiance/model.py b/comfy/ldm/chroma_radiance/model.py index 4fb56165e..86af98d36 100644 --- a/comfy/ldm/chroma_radiance/model.py +++ b/comfy/ldm/chroma_radiance/model.py @@ -38,6 +38,8 @@ class ChromaRadianceParams(ChromaParams): # None means use the same dtype as the model. nerf_embedder_dtype: Optional[torch.dtype] use_x0: bool + # Use sequential txt_ids instead of zeros + use_sequential_txt_ids: bool class ChromaRadiance(Chroma): """ @@ -162,6 +164,9 @@ class ChromaRadiance(Chroma): if params.use_x0: self.register_buffer("__x0__", torch.tensor([])) + if params.use_sequential_txt_ids: + self.register_buffer("__sequential__", torch.tensor([])) + @property def _nerf_final_layer(self) -> nn.Module: if self.params.nerf_final_head_type == "linear": @@ -313,6 +318,9 @@ class ChromaRadiance(Chroma): img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) + # Radiance after 2026-05-22 uses sequential txt_ids instead of zeros + if params.use_sequential_txt_ids: + txt_ids[:, :, 0] = torch.arange(context.shape[1], device=x.device, dtype=x.dtype).unsqueeze(0).expand(bs, -1) img_out = self.forward_orig( img, diff --git a/comfy/ldm/triposplat/gaussian.py b/comfy/ldm/triposplat/gaussian.py 
# File: comfy/ldm/triposplat/gaussian.py (new file in this patch)
# TripoSplat 3D gaussian container. Operates on already-decoded
# tensors and exposes them as render-ready tensors (render_tensors) for the generic SPLAT type.
import torch
import torch.nn.functional as F

import comfy.model_management


class GaussianModel:
    """Holds raw gaussian parameters and their activated counterparts.

    Raw values are assigned through the underscore-prefixed properties
    (``_xyz``, ``_opacity``, ``_scaling``, ...). The setters for ``_xyz``,
    ``_opacity`` and ``_scaling`` also store an activated value that is read
    back through ``get_xyz``, ``get_opacity`` and ``get_scaling``.

    Args:
        aabb: six numbers; the first three are used as an offset and the last
            three as a per-axis scale when mapping ``_xyz`` to ``get_xyz``.
        sh_degree: stored on the instance; not otherwise used in this class.
        mininum_kernel_size: added in quadrature to the activated scale.
        scaling_bias: activated-space scale bias; its inverse activation is
            added to raw scales before activation.
        opacity_bias: activated-space opacity bias, handled the same way.
        scaling_activation: ``"exp"`` or ``"softplus"``.
        device: device for the tensors created here.

    Raises:
        ValueError: if ``scaling_activation`` is not one of the supported names.
    """

    def __init__(self, aabb: list, sh_degree: int = 0, mininum_kernel_size: float = 0.0,
                 scaling_bias: float = 0.01, opacity_bias: float = 0.1,
                 scaling_activation: str = "exp", device=None):
        self.sh_degree = sh_degree
        self.mininum_kernel_size = mininum_kernel_size
        self.scaling_bias = scaling_bias
        self.opacity_bias = opacity_bias
        self.device = device
        self.aabb = torch.tensor(aabb, dtype=torch.float32, device=device)

        if scaling_activation == "exp":
            self._scaling_activation = torch.exp
            self._inverse_scaling_activation = torch.log
        elif scaling_activation == "softplus":
            self._scaling_activation = F.softplus
            # Inverse of softplus: log(exp(x) - 1) written in a stable form.
            self._inverse_scaling_activation = lambda x: x + torch.log(-torch.expm1(-x))
        else:
            # Previously fell through and failed later with an AttributeError
            # on the missing activation; fail early with a clear message.
            raise ValueError(
                f"Unsupported scaling_activation {scaling_activation!r}; expected 'exp' or 'softplus'"
            )

        self._opacity_activation = torch.sigmoid
        # Logit, the inverse of sigmoid.
        self._inverse_opacity_activation = lambda x: torch.log(x / (1 - x))

        # Biases are stored in pre-activation space so they can be added to raw values.
        self.scale_bias = self._inverse_scaling_activation(torch.tensor(self.scaling_bias)).to(self.device)
        # Identity rotation in (w, x, y, z) order.
        self.rots_bias = torch.zeros(4, device=self.device)
        self.rots_bias[0] = 1
        self.opacity_bias_val = self._inverse_opacity_activation(torch.tensor(self.opacity_bias)).to(self.device)

        self._storage = {}

    def _get_store(self, name):
        """Return the stored tensor for ``name`` or ``None`` if unset."""
        return self._storage.get(name)

    def _set_store(self, name, value):
        """Store ``value`` under ``name``."""
        self._storage[name] = value

    @property
    def _xyz(self):
        return self._get_store("_xyz")

    @_xyz.setter
    def _xyz(self, value):
        if value is None:
            self._set_store("_xyz", None)
            self._set_store("xyz", None)
            return
        self._set_store("_xyz", value)
        # Scale by aabb[3:] and shift by aabb[:3].
        self._set_store("xyz", value * self.aabb[None, 3:] + self.aabb[None, :3])

    @property
    def get_xyz(self):
        return self._get_store("xyz")

    @property
    def _features_dc(self):
        return self._get_store("_features_dc")

    @_features_dc.setter
    def _features_dc(self, value):
        self._set_store("_features_dc", value)

    @property
    def _opacity(self):
        return self._get_store("_opacity")

    @_opacity.setter
    def _opacity(self, value):
        if value is None:
            self._set_store("_opacity", None)
            self._set_store("opacity", None)
            return
        self._set_store("_opacity", value)
        self._set_store("opacity", self._opacity_activation(value + self.opacity_bias_val))

    @property
    def get_opacity(self):
        return self._get_store("opacity")

    @property
    def _scaling(self):
        return self._get_store("_scaling")

    @_scaling.setter
    def _scaling(self, value):
        if value is None:
            self._set_store("_scaling", None)
            self._set_store("scaling", None)
            return
        self._set_store("_scaling", value)
        s = self._scaling_activation(value + self.scale_bias)
        # sqrt(s^2 + k^2): combine with the minimum kernel size in quadrature.
        s = torch.square(s) + self.mininum_kernel_size ** 2
        self._set_store("scaling", torch.sqrt(s))

    @property
    def get_scaling(self):
        return self._get_store("scaling")

    @property
    def _rotation(self):
        return self._get_store("_rotation")

    @_rotation.setter
    def _rotation(self, value):
        self._set_store("_rotation", value)

    _DEFAULT_TRANSFORM = [[1, 0, 0], [0, 0, -1], [0, 1, 0]]

    def render_tensors(self):
        # Render-ready (activated, world-space) tensors for the generic SPLAT type. The axis transform
        # (a 3x3 rotation, object frame -> viewer Y-up) is baked into positions and rotations.
        # Returns float tensors on the intermediate device: positions (N,3), scales (N,3) linear,
        # rotations (N,4) wxyz, opacities (N,1) in [0,1], sh (N,K,3) coefficients.
        xyz = self.get_xyz.float()
        scaling = self.get_scaling.float()
        opacity = self.get_opacity.float()
        rotation = (self._rotation + self.rots_bias[None, :]).float()
        sh = self._features_dc.float()  # (N, K, 3)
        T = torch.as_tensor(self._DEFAULT_TRANSFORM, dtype=torch.float32, device=xyz.device)
        xyz = xyz @ T.T
        # Compose the axis transform with each gaussian's rotation, then renormalise.
        rotation = _matrix_to_quat(torch.matmul(T, _quat_to_matrix(rotation)))
        rotation = rotation / torch.linalg.norm(rotation, dim=-1, keepdim=True)
        out_device = comfy.model_management.intermediate_device()
        return (
            xyz.to(out_device).contiguous(), scaling.to(out_device).contiguous(),
            rotation.to(out_device).contiguous(), opacity.to(out_device).contiguous(),
            sh.to(out_device).contiguous(),
        )


def _quat_to_matrix(q):
    """Convert (N, 4) quaternions in (w, x, y, z) order to (N, 3, 3) rotation matrices.

    Inputs are normalised first, so they need not be unit length (but must be non-zero).
    """
    q = q / torch.linalg.norm(q, dim=-1, keepdim=True)
    w, x, y, z = q[:, 0], q[:, 1], q[:, 2], q[:, 3]
    R = torch.stack([
        1 - 2*(y*y + z*z), 2*(x*y - w*z), 2*(x*z + w*y),
        2*(x*y + w*z), 1 - 2*(x*x + z*z), 2*(y*z - w*x),
        2*(x*z - w*y), 2*(y*z + w*x), 1 - 2*(x*x + y*y),
    ], dim=-1).reshape(-1, 3, 3)
    return R


def _matrix_to_quat(R):
    """Convert (N, 3, 3) rotation matrices to unit (N, 4) quaternions in (w, x, y, z) order.

    The trace-based formula is used for every row. Rows where it degenerates
    (``s == 0``, i.e. ``trace + 1 <= 0``) are overwritten using the branch for
    the largest diagonal element.
    """
    trace = R[:, 0, 0] + R[:, 1, 1] + R[:, 2, 2]
    q = torch.zeros((R.shape[0], 4), dtype=R.dtype, device=R.device)
    s = torch.sqrt(torch.clamp(trace + 1, min=0)) * 2
    q[:, 0] = 0.25 * s
    # Avoid dividing by zero; those rows are replaced by a fallback branch below.
    denom = torch.where(s != 0, s, torch.ones_like(s))
    q[:, 1] = (R[:, 2, 1] - R[:, 1, 2]) / denom
    q[:, 2] = (R[:, 0, 2] - R[:, 2, 0]) / denom
    q[:, 3] = (R[:, 1, 0] - R[:, 0, 1]) / denom
    # Fallback: R[0, 0] is the largest diagonal entry.
    m01 = (R[:, 0, 0] >= R[:, 1, 1]) & (R[:, 0, 0] >= R[:, 2, 2]) & (s == 0)
    s1 = torch.sqrt(torch.clamp(1 + R[:, 0, 0] - R[:, 1, 1] - R[:, 2, 2], min=0)) * 2
    q[m01, 0] = (R[m01, 2, 1] - R[m01, 1, 2]) / s1[m01]
    q[m01, 1] = 0.25 * s1[m01]
    q[m01, 2] = (R[m01, 0, 1] + R[m01, 1, 0]) / s1[m01]
    q[m01, 3] = (R[m01, 0, 2] + R[m01, 2, 0]) / s1[m01]
    # Fallback: R[1, 1] is the largest diagonal entry.
    m11 = (R[:, 1, 1] > R[:, 0, 0]) & (R[:, 1, 1] >= R[:, 2, 2]) & (s == 0)
    s2 = torch.sqrt(torch.clamp(1 + R[:, 1, 1] - R[:, 0, 0] - R[:, 2, 2], min=0)) * 2
    q[m11, 0] = (R[m11, 0, 2] - R[m11, 2, 0]) / s2[m11]
    q[m11, 1] = (R[m11, 0, 1] + R[m11, 1, 0]) / s2[m11]
    q[m11, 2] = 0.25 * s2[m11]
    q[m11, 3] = (R[m11, 1, 2] + R[m11, 2, 1]) / s2[m11]
    # Fallback: R[2, 2] is the largest diagonal entry.
    m21 = (R[:, 2, 2] > R[:, 0, 0]) & (R[:, 2, 2] > R[:, 1, 1]) & (s == 0)
    s3 = torch.sqrt(torch.clamp(1 + R[:, 2, 2] - R[:, 0, 0] - R[:, 1, 1], min=0)) * 2
    q[m21, 0] = (R[m21, 1, 0] - R[m21, 0, 1]) / s3[m21]
    q[m21, 1] = (R[m21, 0, 2] + R[m21, 2, 0]) / s3[m21]
    q[m21, 2] = (R[m21, 1, 2] + R[m21, 2, 1]) / s3[m21]
    q[m21, 3] = 0.25 * s3[m21]
    return q / torch.linalg.norm(q, dim=-1, keepdim=True)


def build_gaussian_models(decoder, points_pred: dict, pred: dict):
    # Assemble GaussianModels from the elastic decoder layout. decoder is the ElasticGaussianFixedlenDecoder
    # (carries layout / rep_config / _get_offset)
    #
    # Reads points_pred["points"] and pred["features"]; returns one GaussianModel per batch entry.
    x = points_pred
    offset = decoder._get_offset(pred['features'])
    h = pred["features"]
    ret = []
    for i in range(h.shape[0]):
        g = GaussianModel(
            sh_degree=0,
            aabb=[-0.5, -0.5, -0.5, 1.0, 1.0, 1.0],
            mininum_kernel_size=decoder.rep_config['filter_kernel_size_3d'],
            scaling_bias=decoder.rep_config['scaling_bias'],
            opacity_bias=decoder.rep_config['opacity_bias'],
            scaling_activation=decoder.rep_config['scaling_activation'],
            device=h.device,
        )
        _x = x["points"][i, :, None, :]
        for k, v in decoder.layout.items():
            if k == '_xyz':
                # Positions are the anchor points plus the decoded offsets.
                setattr(g, k, (offset[i] + _x).flatten(0, 1))
            elif k in ('_xyz_center', '_offset_scale'):
                continue
            else:
                # Slice this attribute's feature range and apply its rep_config['lr'] multiplier.
                feats = h[i][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape']).flatten(0, 1)
                setattr(g, k, feats * decoder.rep_config['lr'][k])
        ret.append(g)
    return ret


# File: comfy/ldm/triposplat/model.py (new file in this patch) follows.
# TripoSplat flow-matching denoiser (LatentSeqMMFlowModel).
Registered as a ModelType.FLOW arch and +# driven by the standard KSampler; jointly denoises the (B, 8192, 16) latent and a (B, 1, 5) camera token +# carried as a 2-element nested latent. +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +import comfy.model_management +import comfy.patcher_extension +import comfy.rmsnorm +from comfy.ldm.modules.attention import optimized_attention +from comfy.ldm.flux.math import apply_rope + + +class MultiHeadRMSNorm(nn.Module): + def __init__(self, dim, heads, dtype=None, device=None): + super().__init__() + self.gamma = nn.Parameter(torch.empty(heads, dim, dtype=dtype, device=device)) + + def forward(self, x): + x = comfy.rmsnorm.rms_norm(x) + return x * comfy.model_management.cast_to(self.gamma, x.dtype, x.device) + + +# Positional embeddings + +class RePo3DRotaryEmbedding(nn.Module): + def __init__(self, model_channels, num_heads, head_dim, repo_hidden_ratio=0.125, max_freq=16.0, + dtype=None, device=None, operations=None): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + repo_hidden_size = int(model_channels * repo_hidden_ratio) + self.norm = operations.LayerNorm(model_channels, dtype=dtype, device=device) + self.gate_map = operations.Linear(model_channels, repo_hidden_size, bias=False, dtype=dtype, device=device) + self.content_map = operations.Linear(model_channels, repo_hidden_size, bias=False, dtype=dtype, device=device) + self.act = nn.SiLU() + self.final_map = operations.Linear(repo_hidden_size, 3 * num_heads, bias=False, dtype=dtype, device=device) + self.dim_0 = 2 * (head_dim // 6) + self.dim_1 = 2 * (head_dim // 6) + self.dim_2 = head_dim - self.dim_0 - self.dim_1 + dims = [self.dim_0, self.dim_1, self.dim_2] + freqs_list = [] + for d in dims: + freq_dim = d // 2 + freqs_list.append(torch.linspace(1.0, float(max_freq), steps=freq_dim, dtype=torch.float32)) + self.freqs_0 = nn.Parameter(freqs_list[0]) + self.freqs_1 = nn.Parameter(freqs_list[1]) + 
self.freqs_2 = nn.Parameter(freqs_list[2]) + + def forward(self, hidden_states): + h = self.norm(hidden_states) + feat = self.act(self.gate_map(h)) * self.content_map(h) + out = self.final_map(feat) + B, L, _ = out.shape + delta_pos = out.reshape(B, L, self.num_heads, 3) + f0 = comfy.model_management.cast_to(self.freqs_0, torch.float32, out.device) + f1 = comfy.model_management.cast_to(self.freqs_1, torch.float32, out.device) + f2 = comfy.model_management.cast_to(self.freqs_2, torch.float32, out.device) + ang_0 = delta_pos[..., 0].unsqueeze(-1) * f0 * torch.pi + ang_1 = delta_pos[..., 1].unsqueeze(-1) * f1 * torch.pi + ang_2 = delta_pos[..., 2].unsqueeze(-1) * f2 * torch.pi + ang = torch.cat([ang_0, ang_1, ang_2], dim=-1).float() # (B, L, heads, head_dim/2) + cos, sin = ang.cos(), ang.sin() + return torch.stack([cos, -sin, sin, cos], dim=-1).reshape(*ang.shape, 2, 2) + + +class PcdAbsolutePositionEmbedder(nn.Module): + # Sinusoidal absolute position embedding. Two fixed schedules are used in TripoSplat: + # "pow2" (flow-model latent anchors) and "log2" (octree / gaussian decoders). 
+ def __init__(self, channels: int, in_channels: int = 3, max_res: int = 16, schedule: str = "pow2"): + super().__init__() + self.channels = channels + self.in_channels = in_channels + self.max_res = max_res + self.schedule = schedule + self.freq_dim = channels // in_channels // 2 + + def _freqs(self, device): + if self.schedule == "pow2": + freqs_2exp = torch.arange(self.max_res, dtype=torch.float32, device=device) + res_dim = max(0, self.freq_dim - self.max_res) + freqs_res = (torch.arange(res_dim, dtype=torch.float32, device=device) / max(res_dim, 1) * self.max_res + if res_dim > 0 else torch.empty(0, device=device)) + freqs = torch.cat([freqs_2exp, freqs_res], dim=0)[:self.freq_dim] + return torch.pow(2.0, freqs) * 2.0 # *2 folds this schedule's 2*pi into the shared *pi below + logs = torch.linspace(0.0, float(self.max_res), steps=self.freq_dim, dtype=torch.float32, device=device) + return torch.pow(2.0, logs) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + orig_dtype = x.dtype + x = x.float() + *dims, D = x.shape + out = torch.outer(x.reshape(-1), self._freqs(x.device)) * torch.pi + out = torch.cat([out.sin(), out.cos()], dim=-1).reshape(*dims, -1) + if out.shape[-1] < self.channels: + out = torch.cat([out, torch.zeros(*dims, self.channels - out.shape[-1], + device=out.device, dtype=out.dtype)], dim=-1) + return out.to(orig_dtype) + + +def attention(q, k, v, transformer_options=None): + # q, k, v: (B, L, heads, dim) -> (B, L, heads, dim). Shared optimized_attention call convention. 
+ out = optimized_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), heads=q.shape[2], + skip_reshape=True, skip_output_reshape=True, low_precision_attention=False, + transformer_options=transformer_options) + return out.transpose(1, 2) + + +# Transformer building blocks + +class MLP(nn.Module): + def __init__(self, in_channels, hidden_channels, out_channels, dtype=None, device=None, operations=None): + super().__init__() + self.mlp = nn.Sequential( + operations.Linear(in_channels, hidden_channels, dtype=dtype, device=device), + nn.GELU(approximate="tanh"), + operations.Linear(hidden_channels, out_channels, dtype=dtype, device=device), + ) + + def forward(self, x): + return self.mlp(x) + + +class RopeMultiHeadAttention(nn.Module): + def __init__(self, channels, num_heads, qkv_bias=True, qk_rms_norm=False, use_rope=False, + dtype=None, device=None, operations=None): + super().__init__() + self.channels = channels + self.num_heads = num_heads + self.head_dim = channels // num_heads + self.qk_rms_norm = qk_rms_norm + self.use_rope = use_rope + self.qkv = operations.Linear(channels, channels * 3, bias=qkv_bias, dtype=dtype, device=device) + if self.qk_rms_norm: + self.q_norm = MultiHeadRMSNorm(self.head_dim, num_heads, dtype=dtype, device=device) + self.k_norm = MultiHeadRMSNorm(self.head_dim, num_heads, dtype=dtype, device=device) + self.out = operations.Linear(channels, channels, dtype=dtype, device=device) + + def forward(self, x, rope_emb=None, transformer_options=None): + B, L, C = x.shape + qkv = self.qkv(x).reshape(B, L, 3, self.num_heads, self.head_dim) + q, k, v = qkv.unbind(2) + if self.use_rope: + q, k = apply_rope(q, k, rope_emb) + if self.qk_rms_norm: + q = self.q_norm(q) + k = self.k_norm(k) + h = attention(q, k, v, transformer_options) # (B, L, heads, dim) + return self.out(h.reshape(B, L, C)) + + +class UnifiedTransformerBlock(nn.Module): + def __init__(self, channels, num_heads, mlp_ratio=4.0, + use_rope=False, qk_rms_norm=False, 
class TimestepEmbedder(nn.Module):
    """Turn scalar timesteps into vectors: sinusoidal features followed by a Linear-SiLU-Linear MLP."""

    def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
        super().__init__()
        factory_kwargs = dict(dtype=dtype, device=device)
        self.mlp = nn.Sequential(
            operations.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs),
            nn.SiLU(),
            operations.Linear(hidden_size, hidden_size, bias=True, **factory_kwargs),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """Return the (len(t), dim) sinusoidal embedding of the 1-D tensor ``t``.

        The first ``dim // 2`` columns are cosines, the next ``dim // 2`` are sines; an odd
        ``dim`` gets one trailing zero column.
        """
        n_freqs = dim // 2
        # Frequencies are built in float32 on the CPU and only then moved to t's device.
        exponent = -np.log(max_period) * torch.arange(start=0, end=n_freqs, dtype=torch.float32) / n_freqs
        freqs = torch.exp(exponent).to(device=t.device)
        phase = t[:, None].float() * freqs[None]
        columns = [torch.cos(phase), torch.sin(phase)]
        if dim % 2:
            columns.append(torch.zeros_like(columns[0][:, :1]))
        return torch.cat(columns, dim=-1)

    def forward(self, t):
        features = self.timestep_embedding(t, self.frequency_embedding_size)
        # Feed the MLP in the dtype of its first linear layer's weights.
        mlp_dtype = self.mlp[0].weight.dtype
        return self.mlp(features.to(mlp_dtype))
self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(model_channels, 6 * model_channels, bias=True, **factory_kwargs)) + + self.input_layer = operations.Linear(in_channels, model_channels, **factory_kwargs) + self.cond_embedder = operations.Linear(cond_channels, model_channels, **factory_kwargs) + self.cond_embedder2 = operations.Linear(cond2_channels, model_channels, **factory_kwargs) if cond2_channels is not None else None + + # Fixed Sobol (low-discrepancy) 3D anchor positions for the latent tokens, used as positional encoding. + # The embedder is parameter-free and the anchors are fixed, precompute once. + sobol_seq = torch.quasirandom.SobolEngine(dimension=3, scramble=True, seed=123).draw(q_token_length) + pos_emb = PcdAbsolutePositionEmbedder(model_channels)(sobol_seq.unsqueeze(0)) + self.register_buffer("pos_emb", pos_emb, persistent=False) + + # RePo3DRotaryEmbedding layers for the refiner and main blocks + repo_kwargs = dict(num_heads=self.num_heads, head_dim=num_head_channels, **op_kwargs) + self.noise_repo_layers = nn.ModuleList( + [RePo3DRotaryEmbedding(model_channels, **repo_kwargs) for _ in range(num_refiner_blocks)]) + self.context_repo_layers = nn.ModuleList( + [RePo3DRotaryEmbedding(model_channels, **repo_kwargs) for _ in range(num_refiner_blocks)]) + self.repo_layers = nn.ModuleList( + [RePo3DRotaryEmbedding(model_channels, **repo_kwargs) for _ in range(num_blocks)]) + + # Refiner blocks + block_kwargs = dict(num_heads=self.num_heads, mlp_ratio=self.mlp_ratio, use_rope=True, qk_rms_norm=self.qk_rms_norm, **op_kwargs) + self.noise_refiner = nn.ModuleList( + [UnifiedTransformerBlock(model_channels, modulation=True, share_mod=self.share_mod, **block_kwargs) for _ in range(num_refiner_blocks)]) + self.context_refiner = nn.ModuleList( + [UnifiedTransformerBlock(model_channels, modulation=False, **block_kwargs) for _ in range(num_refiner_blocks)]) + + self.cam_refiner = MLP(self.cam_channels, model_channels, model_channels, **op_kwargs) + 
def _forward(self, x, t, context=None, ref_latents=None, transformer_options={}, **kwargs):
    """Run the flow model on a nested latent and return ``(latent_out, camera_out)``.

    x is the unpacked nested latent: [latent (B,8192,in_channels), camera (B,1,cam_channels)].
    context is the first conditioning feature sequence (feature1); ref_latents, when given
    and a second embedder exists, is added to the embedded context.
    """
    latent, camera = x[0], x[1]

    tokens = self.input_layer(latent)
    cond = self.cond_embedder(context)
    if ref_latents is not None and self.cond_embedder2 is not None:
        # Flatten the Flux2 VAE latent (B,128,h,w) to a token sequence and front-pad it to the
        # context length (the pad count = the context's prefix tokens: DINOv3 cls + registers).
        image_tokens = ref_latents[0].flatten(2).transpose(1, 2)
        front_pad = context.shape[1] - image_tokens.shape[1]
        image_tokens = F.pad(image_tokens, (0, 0, front_pad, 0))
        cond = cond + self.cond_embedder2(image_tokens.to(cond.dtype))

    t_emb = self.t_embedder(t)
    if self.share_mod:
        t_mod = self.adaLN_modulation(t_emb)
    else:
        t_mod = t_emb

    tokens = tokens + self.pos_emb.to(latent)

    # Refine the noisy latent tokens (timestep-modulated) and the context tokens (unmodulated).
    for block, repo in zip(self.noise_refiner, self.noise_repo_layers):
        tokens = block(tokens, mod=t_mod, rotary_emb=repo(tokens), transformer_options=transformer_options)

    for block, repo in zip(self.context_refiner, self.context_repo_layers):
        cond = block(cond, mod=None, rotary_emb=repo(cond), transformer_options=transformer_options)

    cam = camera.to(latent)
    cam_tokens = self.cam_refiner(cam)
    # Joint sequence layout: latent tokens first, then context, then camera tokens last.
    h = torch.cat([tokens, cond, cam_tokens], dim=1)

    for block, repo in zip(self.blocks, self.repo_layers):
        h = block(h, mod=t_mod, rotary_emb=repo(h), transformer_options=transformer_options)

    # Slice the latent (front) and camera (back) tokens back out; normalize in float32.
    norm_shape = h.shape[-1:]
    n_latent = latent.shape[1]
    n_cam = cam.shape[1]
    tokens = F.layer_norm(h[:, :n_latent].float(), norm_shape).to(latent)
    cam_tokens = F.layer_norm(h[:, -n_cam:].float(), norm_shape).to(latent)

    table = comfy.model_management.cast_to(self.shift_table, t_emb.dtype, t_emb.device)
    shift, scale = (table + t_emb.unsqueeze(1)).chunk(2, dim=1)
    scale = 1 + scale
    tokens = torch.addcmul(shift, tokens, scale)
    cam_tokens = torch.addcmul(shift, cam_tokens, scale)

    return self.out_layer(tokens), self.cam_out_layer(cam_tokens)
def render_splat(xyz, rgb, scale, opacity=None, yaw=35.0, pitch=30.0, size=320, min_px=2, gain=1.0,
                 max_px=9, min_opacity=0.0, fov=35.0, dist=2.2):
    """Render gaussian centers as z-buffered filled disks into a ``size`` x ``size`` RGB PIL image.

    Centers are projected with a perspective camera and each one is painted as a filled disk whose
    screen radius follows the gaussian's world-space scale, composited nearest-wins.

    Args:
        xyz: (N, 3) array of gaussian centers in the object frame.
        rgb: (N, >=3) array of colors; values are clipped to [0, 1] and the first 3 channels used.
        scale: (N,) array of per-gaussian world-space radii.
        opacity: optional (N,) array; only consulted when ``min_opacity > 0``.
        yaw, pitch: camera orbit angles in degrees.
        size: output image edge length in pixels.
        min_px, max_px: clamp for the on-screen disk radius in pixels.
        gain: multiplier on the projected footprint.
        min_opacity: gaussians with opacity at or below this are culled.
        fov: field of view in degrees.
        dist: offset added to the view-space depth (camera distance).

    Returns:
        A PIL image; all black when nothing is visible.
    """
    pts = xyz.astype(np.float32) @ _OBJECT_TO_VIEWER.T
    v = pts @ _view_matrix(yaw, pitch).T
    zc = v[:, 2] + dist
    keep = zc > 1e-2  # drop points at or behind the camera plane
    if opacity is not None and min_opacity > 0.0:  # culls gaussians with very low opacity
        keep = keep & (opacity > min_opacity)
    v, zc, scale = v[keep], zc[keep], scale[keep]
    col = (np.clip(rgb, 0, 1)[:, :3] * 255).astype(np.uint8)[keep]
    if v.shape[0] == 0:
        return Image.fromarray(np.zeros((size, size, 3), np.uint8))
    f = (size / 2) / np.tan(np.radians(fov) / 2)
    cx = size / 2 + f * v[:, 0] / zc
    cy = size / 2 + f * v[:, 1] / zc
    radius = np.clip(np.round(f * scale / zc * gain), min_px, max_px).astype(np.int32)

    # Expand each splat to its disk pixels, bucketed by integer radius so it stays vectorized.
    px, py, pz, pc = [], [], [], []
    for r in range(int(radius.min()), int(radius.max()) + 1):
        m = radius == r
        if not m.any():
            continue
        dy, dx = np.mgrid[-r:r + 1, -r:r + 1]
        disk = (dx * dx + dy * dy) <= r * r
        ox, oy = dx[disk], dy[disk]
        px.append((cx[m, None] + ox).ravel())
        py.append((cy[m, None] + oy).ravel())
        pz.append(np.repeat(zc[m], ox.size))
        pc.append(np.repeat(col[m], ox.size, axis=0))
    px, py = np.concatenate(px), np.concatenate(py)
    pz, pc = np.concatenate(pz), np.concatenate(pc)

    # Discard disk pixels that fall outside the image. Clamping them to the border instead would
    # smear every off-screen splat onto the edge rows/columns. For in-range coordinates floor()
    # yields the same pixel index as the previous clip-and-truncate.
    xi = np.floor(px).astype(np.int64)
    yi = np.floor(py).astype(np.int64)
    inside = (xi >= 0) & (xi < size) & (yi >= 0) & (yi < size)
    if not inside.any():
        return Image.fromarray(np.zeros((size, size, 3), np.uint8))
    xi, yi, pz, pc = xi[inside], yi[inside], pz[inside], pc[inside]

    # Nearest-wins z-buffer: pack (quantized depth, source index), per-pixel min picks the closest
    # splat, then decode the winning index back to its color.
    pid = yi * size + xi
    q = np.clip((pz * 1024.0).astype(np.int64), 0, (1 << 20) - 1)  # near = small
    key = (q << 32) | np.arange(pid.size, dtype=np.int64)
    buf = np.full(size * size, 1 << 62, np.int64)  # sentinel larger than any packed key
    np.minimum.at(buf, pid, key)
    img = np.zeros((size * size, 3), np.uint8)
    hit = buf < (1 << 62)
    img[hit] = pc[buf[hit] & 0xFFFFFFFF]
    return Image.fromarray(img.reshape(size, size, 3))
+ latent = _extract_latent(x0) + fsm = decoder.first_stage_model + gaussian = fsm.decode(latent.to(decoder.device, decoder.vae_dtype), + num_gaussians=cfg.get("gaussians", 16384), level=cfg.get("level", 5))[0] + xyz = gaussian.get_xyz.float().cpu().numpy() + rgb = gaussian._features_dc.float().cpu().numpy()[:, 0, :] * _C0 + 0.5 + scale = gaussian.get_scaling.float().cpu().numpy().max(axis=1) # per-splat world radius (largest axis) + opacity = gaussian.get_opacity.float().cpu().numpy()[:, 0] + return render_splat(xyz, rgb, scale, opacity=opacity, yaw=cfg.get("yaw", 35.0), pitch=cfg.get("pitch", 30.0), + size=cfg.get("size", 320), min_px=1, gain=1.0, max_px=cfg.get("point_size", 3), + min_opacity=0.01) diff --git a/comfy/ldm/triposplat/vae.py b/comfy/ldm/triposplat/vae.py new file mode 100644 index 000000000..e5ed9fd36 --- /dev/null +++ b/comfy/ldm/triposplat/vae.py @@ -0,0 +1,382 @@ +# TripoSplat gaussian decoder ("VAE"): an octree probability decoder picks point coords, then an +# elastic-gaussian decoder predicts per-point gaussian params. OctreeGaussianDecoder.decode() returns +# one gaussian model per batch item. By default the octree sampler uses the global torch RNG like upstream, so seed it (or pass a CPU torch.Generator to decode()) for repeatable decodes. 
def sample_probs(probs, counts, generator=None):
    """Systematic resampling: distribute ``counts[r]`` draws across the P bins of row ``r``.

    Args:
        probs: tensor of shape (*batch, P) with non-negative weights (negatives are clamped to 0).
            Rows do not need to sum to 1; an all-zero row is treated as uniform.
        counts: integer tensor of shape (*batch) with the number of draws per row.
        generator: optional CPU ``torch.Generator`` for the per-row random offset; when None the
            global torch RNG is used.

    Returns:
        Long tensor of shape (*batch, P) whose rows sum to the corresponding ``counts`` entry.
    """
    batch_shape = counts.shape
    R = counts.numel()
    P = probs.size(-1)
    device = probs.device
    probs = probs.reshape(R, P).to(torch.float32).clamp_min(0)
    counts = counts.reshape(R).to(device=device, dtype=torch.long)

    # No rows at all: counts.max() below would raise on an empty tensor.
    if R == 0:
        return counts.new_zeros(*batch_shape, P)

    # Normalize every row to sum to 1. Dividing by clamp_min(1) left rows with a sum in (0, 1)
    # un-normalized, so their surplus draws landed in the last bin. Divide by the true sum
    # instead and only substitute 1 where the sum is exactly zero (those rows become uniform).
    row_sums = probs.sum(1, keepdim=True)
    zero_row = row_sums == 0
    safe_sums = torch.where(zero_row, torch.ones_like(row_sums), row_sums)
    probs = torch.where(zero_row, probs.new_tensor(1.0 / P), probs / safe_sums)
    cdf = probs.cumsum(dim=1).clamp(max=1.0 - 1e-12)

    Nmax = int(counts.max())
    if Nmax <= 0:
        return counts.new_zeros(*batch_shape, P)
    cnt = counts.clamp_min(1).float().unsqueeze(1)  # (R, 1)
    grid = torch.arange(Nmax, device=device, dtype=torch.float32).unsqueeze(0)  # (1, Nmax)
    # One random offset per row, evenly spaced samples after it (drawn on the CPU, then moved).
    u = (torch.rand(R, 1, generator=generator).to(device) + grid) / cnt  # (R, Nmax)
    idx = torch.searchsorted(cdf, u.clamp(max=1.0 - 1e-12)).clamp_max(P - 1)
    weight = (grid < counts.unsqueeze(1)).to(cdf.dtype)  # mask out j >= counts[r]
    out = torch.zeros(R, P, dtype=torch.float32, device=device)
    out.scatter_add_(1, idx, weight)
    return out.to(torch.long).view(*batch_shape, P)
class ModulatedTransformerCrossOnlyBlock(nn.Module):
    """Transformer block with cross-attention and an MLP, both adaLN-modulated (shift/scale/gate)."""

    def __init__(self, channels, ctx_channels, num_heads, mlp_ratio=4.0, share_mod=False,
                 qk_rms_norm_cross=True, qkv_bias=True, dtype=None, device=None, operations=None):
        super().__init__()
        factory_kwargs = dict(dtype=dtype, device=device)
        self.share_mod = share_mod
        # The norms carry no affine parameters; shift and scale come from the modulation vector.
        self.norm1 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.norm2 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.cross_attn = MultiHeadAttention(
            channels, ctx_channels=ctx_channels, num_heads=num_heads, type="cross",
            qkv_bias=qkv_bias, qk_rms_norm=qk_rms_norm_cross, operations=operations, **factory_kwargs)
        hidden_channels = int(channels * mlp_ratio)
        self.mlp = MLP(channels, hidden_channels, channels, operations=operations, **factory_kwargs)
        if not share_mod:
            # A per-block projection exists only when the modulation is not shared by the parent.
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(), operations.Linear(channels, 6 * channels, bias=True, **factory_kwargs))

    def forward(self, x, mod, context):
        """Apply modulated cross-attention then a modulated MLP; ``mod`` is split into 6 chunks on dim 1."""
        params = mod if self.share_mod else self.adaLN_modulation(mod)
        shift_attn, scale_attn, gate_attn, shift_ff, scale_ff, gate_ff = (
            part.unsqueeze(1) for part in params.chunk(6, dim=1))
        attn_in = torch.addcmul(shift_attn, self.norm1(x), 1 + scale_attn)
        x = torch.addcmul(x, self.cross_attn(attn_in, context), gate_attn)
        ff_in = torch.addcmul(shift_ff, self.norm2(x), 1 + scale_ff)
        return torch.addcmul(x, self.mlp(ff_in), gate_ff)
def forward(self, x, l, cond):
    """Predict 8-way child occupancy for each input coordinate.

    Args:
        x: (B, L, 3) tensor of coordinates.
        l: level values passed to the level embedder.
        cond: conditioning tensor given to every block.

    Returns:
        Dict with ``"logits"`` and ``"probs"`` (softmax of the logits over the last axis).
    """
    work_dtype = next(self.parameters()).dtype
    batch, length, _ = x.shape
    # Sum a learned projection of the coordinates and their positional embedding.
    projected = self.in_proj(x.to(work_dtype))
    positional = self.pos_embedder(x.reshape(-1, 3)).reshape(batch, length, -1).to(work_dtype)
    hidden = self.input_layer(projected + positional)

    level_mod = self.l_embedder(l)
    if self.share_mod:
        level_mod = self.adaLN_modulation(level_mod)

    cond = cond.to(work_dtype)
    for blk in self.blocks:
        hidden = blk(hidden, level_mod, cond)

    # Final parameter-free layer norm is computed in float32, then cast back.
    hidden = F.layer_norm(hidden.float(), hidden.shape[-1:]).to(work_dtype)
    logits = self.out_proj(hidden)
    probs = torch.softmax(logits, dim=-1)
    return {"logits": logits, "probs": probs}
class TransformerCrossBlock(nn.Module):
    """Pre-norm transformer block: self-attention, cross-attention over a context, then an MLP."""

    def __init__(self, channels, ctx_channels, num_heads, mlp_ratio=4.0,
                 qk_rms_norm=True, qk_rms_norm_cross=True, qkv_bias=True,
                 dtype=None, device=None, operations=None):
        super().__init__()
        factory_kwargs = dict(dtype=dtype, device=device)
        # Only the norm ahead of cross-attention (norm2) has learnable affine parameters.
        self.norm1 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.norm2 = operations.LayerNorm(channels, elementwise_affine=True, eps=1e-6, **factory_kwargs)
        self.norm3 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.self_attn = MultiHeadAttention(
            channels, num_heads=num_heads, type="self", qkv_bias=qkv_bias,
            qk_rms_norm=qk_rms_norm, operations=operations, **factory_kwargs)
        self.cross_attn = MultiHeadAttention(
            channels, ctx_channels=ctx_channels, num_heads=num_heads, type="cross",
            qkv_bias=qkv_bias, qk_rms_norm=qk_rms_norm_cross, operations=operations, **factory_kwargs)
        hidden_channels = int(channels * mlp_ratio)
        self.mlp = MLP(channels, hidden_channels, channels, operations=operations, **factory_kwargs)

    def forward(self, x, context):
        """Return ``x`` after three residual updates (self-attn, cross-attn with ``context``, MLP)."""
        attended = self.self_attn(self.norm1(x))
        x = x + attended
        crossed = self.cross_attn(self.norm2(x), context)
        x = x + crossed
        return x + self.mlp(self.norm3(x))
model_channels, dtype=dtype, device=device) + self.pos_embedder = PcdAbsolutePositionEmbedder(channels=model_channels, in_channels=3, max_res=10, schedule="log2") + self.out_proj = operations.Linear(model_channels, self.out_channels, dtype=dtype, device=device) + self._build_perturbation() + + def _calc_layout(self): + ng = self.rep_config['num_gaussians'] + self.layout = { + '_xyz': {'shape': (ng, 3), 'size': ng * 3}, + '_features_dc': {'shape': (ng, 1, 3), 'size': ng * 3}, + '_scaling': {'shape': (ng, 3), 'size': ng * 3}, + '_rotation': {'shape': (ng, 4), 'size': ng * 4}, + '_opacity': {'shape': (ng, 1), 'size': ng}, + } + self.layout['_offset_scale'] = {'shape': (ng, 1), 'size': ng} + start = 0 + for k, v in self.layout.items(): + v['range'] = (start, start + v['size']) + start += v['size'] + return start + + def _build_perturbation(self): + ng = self.rep_config['num_gaussians'] + perturbation = torch.tensor([hammersley_sequence(3, i, ng) for i in range(ng)]).float() + perturbation = torch.atanh((perturbation * 2 - 1) / self.rep_config['perturbe_size']) + self.register_buffer('points_offset_perturbation', perturbation) + base = torch.tensor(self.rep_config['offset_scale']) + self.register_buffer('base_offset_scale', torch.log(torch.exp(base) - 1.0)) + + def _get_offset(self, h): + B = h.shape[0] + r = self.layout['_offset_scale']['range'] + _offset_scale = F.softplus( + h[:, :, r[0]:r[1]].reshape(B, -1, *self.layout['_offset_scale']['shape']) + + comfy.model_management.cast_to(self.base_offset_scale, h.dtype, h.device)) + + r = self.layout['_xyz']['range'] + offset = h[:, :, r[0]:r[1]].reshape(B, -1, *self.layout['_xyz']['shape']) + offset = offset * self.rep_config['lr']['_xyz'] + if self.rep_config['perturb_offset']: + offset = offset + comfy.model_management.cast_to(self.points_offset_perturbation, offset.dtype, offset.device) + offset = torch.tanh(offset) * 0.5 * self.rep_config['perturbe_size'] + offset = offset * _offset_scale + return offset + + def 
class OctreeGaussianDecoder(nn.Module):
    """Combined decoder: an octree sampler picks points, then a gaussian decoder predicts their params."""

    _MAX_VOXEL_LEVEL = 8

    def __init__(self, dtype=None, device=None, operations=None):
        super().__init__()
        ops = comfy.ops.disable_weight_init if operations is None else operations
        shared_kwargs = dict(dtype=dtype, device=device, operations=ops)
        self.octree = OctreeProbabilityFixedlenDecoder(**shared_kwargs)
        self.gs = ElasticGaussianFixedlenDecoder(**shared_kwargs)

    @property
    def gaussians_per_point(self) -> int:
        """Number of gaussians the gaussian decoder emits for each sampled point."""
        return self.gs.rep_config['num_gaussians']

    def decode(self, latent: torch.Tensor, num_gaussians: int, level: int = None, generator=None):
        """Decode ``latent`` into gaussian models, one per batch item.

        level defaults to the full octree depth; a lower level is cheaper (coarser) for live previews.
        generator (a CPU torch.Generator) makes the octree sampling reproducible without touching
        the global RNG.
        """
        depth = level
        if depth is None:
            depth = self._MAX_VOXEL_LEVEL
        # Each sampled point yields several gaussians, so sample proportionally fewer points (>= 1).
        token_count = num_gaussians // self.gaussians_per_point
        if token_count < 1:
            token_count = 1
        sampled = OctreeProbabilityFixedlenDecoder.sample(
            self.octree, latent, num_points=token_count, level=depth, temperature=1.0, generator=generator,
        )
        features = self.gs(x=sampled, cond=latent)
        return build_gaussian_models(self.gs, sampled, features)
+ if ref_latents is not None: + out['ref_latents'] = comfy.conds.CONDList(list(ref_latents)) + latent_shapes = kwargs.get("latent_shapes", None) # {latent, camera} nested latent + if latent_shapes is not None: + out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes) + return out + + class HiDream(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hidream.model.HiDreamImageTransformer2DModel) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 2043c96ed..8e1ad2d3c 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -355,6 +355,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["use_x0"] = True else: dit_config["use_x0"] = False + if "{}__sequential__".format(key_prefix) in state_dict_keys: # sequential txt_ids + dit_config["use_sequential_txt_ids"] = True + else: + dit_config["use_sequential_txt_ids"] = False else: dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys @@ -718,6 +722,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["guidance_cond_proj_dim"] = None#f"{key_prefix}t_embedder.cond_proj.weight" in state_dict_keys return dit_config + if '{}cam_out_layer.weight'.format(key_prefix) in state_dict_keys and '{}repo_layers.0.final_map.weight'.format(key_prefix) in state_dict_keys: # TripoSplat + return {"image_model": "triposplat"} + if '{}t_embedder1.mlp.0.weight'.format(key_prefix) in state_dict_keys and '{}x_embedder.proj1.weight'.format(key_prefix) in state_dict_keys: # HiDream-O1 return {"image_model": "hidream_o1"} diff --git a/comfy/sd.py b/comfy/sd.py index 7d6998aff..08a48b5fa 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -16,6 +16,7 @@ import comfy.ldm.cosmos.vae import 
comfy.ldm.wan.vae import comfy.ldm.wan.vae2_2 import comfy.ldm.hunyuan3d.vae +import comfy.ldm.triposplat.vae import comfy.ldm.ace.vae.music_dcae_pipeline import comfy.ldm.cogvideo.vae import comfy.ldm.hunyuan_video.vae @@ -908,6 +909,16 @@ class VAE: #Force cast it for --disable-dynamic-vram users until there is a true core fix. if not comfy.memory_management.aimdo_enabled: self.disable_offload = True + elif "gs.base_offset_scale" in sd and "octree.out_proj.weight" in sd: # TripoSplat octree gaussian decoder + self.first_stage_model = comfy.ldm.triposplat.vae.OctreeGaussianDecoder() + self.latent_channels = 16 + self.latent_dim = 1 + self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] + # The generic VAE.encode/decode path isn't used: VAEDecodeTripoSplat calls the gaussian + # decoder directly (structured GaussianSplat objects, not a tensor) and reserves VRAM itself from num_gaussians. + def _no_generic_io(*args, **kwargs): + raise RuntimeError("TripoSplat gaussian decoder: use the 'TripoSplat Decode' (VAEDecodeTripoSplat)") + self.memory_used_encode = self.memory_used_decode = _no_generic_io else: logging.warning("WARNING: No VAE weights detected, VAE not initalized.") self.first_stage_model = None diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 7ad8f2e4a..223601319 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1547,6 +1547,30 @@ class Hunyuan3Dv2mini(Hunyuan3Dv2): latent_format = latent_formats.Hunyuan3Dv2mini +class TripoSplat(supported_models_base.BASE): + # Image -> 3D gaussian splat flow denoiser + unet_config = { + "image_model": "triposplat", + } + + unet_extra_config = {} + + sampling_settings = { + "shift": 3.0, + } + + memory_usage_factor = 0.6 + + latent_format = latent_formats.TripoSplat + + supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32] + + def get_model(self, state_dict, prefix="", device=None): + return model_base.TripoSplat(self, device=device) + + 
def clip_target(self, state_dict={}): + return None + class HiDream(supported_models_base.BASE): unet_config = { "image_model": "hidream", @@ -2210,6 +2234,7 @@ models = [ Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, + TripoSplat, HiDream, HiDreamO1, Chroma, diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py index e0a585b10..294ad425e 100644 --- a/comfy_api/latest/__init__.py +++ b/comfy_api/latest/__init__.py @@ -5,7 +5,7 @@ from comfy_api.internal.singleton import ProxiedSingleton from comfy_api.internal.async_to_sync import create_sync_class from ._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput from ._input_impl import VideoFromFile, VideoFromComponents -from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL, File3D +from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL, SPLAT, File3D from . import _io_public as io from . import _ui_public as ui from comfy_execution.utils import get_executing_context @@ -143,6 +143,7 @@ class Types: VideoComponents = VideoComponents MESH = MESH VOXEL = VOXEL + SPLAT = SPLAT File3D = File3D diff --git a/comfy_api/latest/_input/video_types.py b/comfy_api/latest/_input/video_types.py index 451e9526e..8fff52c16 100644 --- a/comfy_api/latest/_input/video_types.py +++ b/comfy_api/latest/_input/video_types.py @@ -65,6 +65,12 @@ class VideoInput(ABC): buffer.seek(0) return buffer + def get_active_trim_window(self) -> tuple[float, float]: + """Return the active trim as ``(start_time, duration)`` in seconds (start_time normalized + to ``>= 0``; ``duration == 0`` means "until the end"). Default: no trim; trimmable subclasses override. + """ + return 0.0, 0.0 + # Provide a default implementation, but subclasses can provide optimized versions # if possible. 
def get_dimensions(self) -> tuple[int, int]: diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py index 99e67d363..4a12ff9c1 100644 --- a/comfy_api/latest/_input_impl/video_types.py +++ b/comfy_api/latest/_input_impl/video_types.py @@ -75,6 +75,12 @@ class VideoFromFile(VideoInput): self.__file.seek(0) return self.__file + def get_active_trim_window(self) -> tuple[float, float]: + start_time = self.__start_time + if start_time < 0: + start_time = max(self._get_raw_duration() + start_time, 0.0) + return float(start_time), float(self.__duration) + def get_dimensions(self) -> tuple[int, int]: """ Returns the dimensions of the video input. diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index e03bafcde..a3aa508ce 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -28,7 +28,7 @@ if TYPE_CHECKING: from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class, prune_dict, shallow_clone_class) from comfy_execution.graph_utils import ExecutionBlocker -from ._util import MESH, VOXEL, SVG as _SVG, File3D +from ._util import MESH, VOXEL, SPLAT, SVG as _SVG, File3D class FolderType(str, Enum): @@ -684,6 +684,10 @@ class Voxel(ComfyTypeIO): class Mesh(ComfyTypeIO): Type = MESH +@comfytype(io_type="SPLAT") +class Splat(ComfyTypeIO): + Type = SPLAT + @comfytype(io_type="FILE_3D") class File3DAny(ComfyTypeIO): @@ -2320,6 +2324,7 @@ __all__ = [ "LossMap", "Voxel", "Mesh", + "Splat", "File3DAny", "File3DGLB", "File3DGLTF", diff --git a/comfy_api/latest/_util/__init__.py b/comfy_api/latest/_util/__init__.py index 115baf392..b27f5a97e 100644 --- a/comfy_api/latest/_util/__init__.py +++ b/comfy_api/latest/_util/__init__.py @@ -1,5 +1,5 @@ from .video_types import VideoContainer, VideoCodec, VideoComponents -from .geometry_types import VOXEL, MESH, File3D +from .geometry_types import VOXEL, MESH, SPLAT, File3D from .image_types import SVG 
__all__ = [ @@ -9,6 +9,7 @@ __all__ = [ "VideoComponents", "VOXEL", "MESH", + "SPLAT", "File3D", "SVG", ] diff --git a/comfy_api/latest/_util/geometry_types.py b/comfy_api/latest/_util/geometry_types.py index cdde60b10..84a18d69a 100644 --- a/comfy_api/latest/_util/geometry_types.py +++ b/comfy_api/latest/_util/geometry_types.py @@ -11,13 +11,32 @@ class VOXEL: self.data = data +class SPLAT: + """A batch of 3D Gaussian splats in render-ready (activated, world-space) form. + + Tensors are (B, N, ...) and zero-padded to a common N across the batch; `counts` (B,) holds the + real per-item lengths (None when rows are uniform and no slicing is needed). SH coefficients are + stored as (B, N, K, 3) with K = (sh_degree + 1)**2; the DC (diffuse) term is sh[..., 0, :]. + """ + + def __init__(self, positions: torch.Tensor, scales: torch.Tensor, rotations: torch.Tensor, + opacities: torch.Tensor, sh: torch.Tensor, counts: torch.Tensor | None = None): + self.positions = positions # (B, N, 3) world-space centers + self.scales = scales # (B, N, 3) linear (positive) per-axis std + self.rotations = rotations # (B, N, 4) quaternion wxyz (normalized) + self.opacities = opacities # (B, N, 1) in [0, 1] + self.sh = sh # (B, N, K, 3) spherical-harmonic color coefficients + self.counts = counts # (B,) real lengths, or None + + class MESH: def __init__(self, vertices: torch.Tensor, faces: torch.Tensor, uvs: torch.Tensor | None = None, vertex_colors: torch.Tensor | None = None, texture: torch.Tensor | None = None, vertex_counts: torch.Tensor | None = None, - face_counts: torch.Tensor | None = None): + face_counts: torch.Tensor | None = None, + unlit: bool = False): assert (vertex_counts is None) == (face_counts is None), \ "vertex_counts and face_counts must be provided together (both or neither)" @@ -30,6 +49,8 @@ class MESH: # these hold the real per-item lengths (B,). None means rows are uniform and no slicing is needed. 
self.vertex_counts = vertex_counts self.face_counts = face_counts + # Render flat / emissive (no scene lighting) when saved, e.g. for gaussian-splat-derived meshes. + self.unlit = unlit class File3D: diff --git a/comfy_api_nodes/apis/bfl.py b/comfy_api_nodes/apis/bfl.py index f0665fa09..2ad651122 100644 --- a/comfy_api_nodes/apis/bfl.py +++ b/comfy_api_nodes/apis/bfl.py @@ -1,71 +1,71 @@ from enum import Enum -from typing import Any, Dict, Optional +from typing import Any -from pydantic import BaseModel, Field, confloat, conint - - -class BFLOutputFormat(str, Enum): - png = 'png' - jpeg = 'jpeg' +from pydantic import BaseModel, Field class BFLFluxExpandImageRequest(BaseModel): - prompt: str = Field(..., description='The description of the changes you want to make. This text guides the expansion process, allowing you to specify features, styles, or modifications for the expanded areas.') - prompt_upsampling: Optional[bool] = Field( - None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.' - ) - seed: Optional[int] = Field(None, description='The seed value for reproducibility.') - top: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the top of the image') - bottom: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the bottom of the image') - left: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the left side of the image') - right: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the right side of the image') - steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process') - guidance: confloat(ge=1.5, le=100) = Field(..., description='Guidance strength for the image generation process') - safety_tolerance: Optional[conint(ge=0, le=6)] = Field( - 6, description='Tolerance level for input and output moderation. 
Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.' - ) - output_format: Optional[BFLOutputFormat] = Field( - BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png'] - ) - image: str = Field(None, description='A Base64-encoded string representing the image you wish to expand') + prompt: str = Field(...) + prompt_upsampling: bool | None = Field(None) + seed: int | None = Field(None) + top: int = Field(...) + bottom: int = Field(...) + left: int = Field(...) + right: int = Field(...) + steps: int = Field(...) + guidance: float = Field(...) + safety_tolerance: int = Field(6) + output_format: str = Field("png") + image: str = Field(None, description="A Base64-encoded string representing the image you wish to expand") class BFLFluxFillImageRequest(BaseModel): - prompt: str = Field(..., description='The description of the changes you want to make. This text guides the expansion process, allowing you to specify features, styles, or modifications for the expanded areas.') - prompt_upsampling: Optional[bool] = Field( - None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.' + prompt: str = Field(...) + prompt_upsampling: bool | None = Field(None) + seed: int | None = Field(None) + steps: int = Field(...) + guidance: float = Field(...) + safety_tolerance: int = Field(6) + output_format: str = Field("png") + image: str = Field( + None, description="Base64-encoded string representing the image to modify. 
Can contain alpha mask if desired.", ) - seed: Optional[int] = Field(None, description='The seed value for reproducibility.') - steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process') - guidance: confloat(ge=1.5, le=100) = Field(..., description='Guidance strength for the image generation process') - safety_tolerance: Optional[conint(ge=0, le=6)] = Field( - 6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.' + mask: str = Field( + None, description="Base64-encoded string representing the mask of the areas you wish to modify." ) - output_format: Optional[BFLOutputFormat] = Field( - BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png'] + + +class BFLFluxEraseRequest(BaseModel): + image: str = Field(..., description="A Base64-encoded string representing the image to erase from.") + mask: str = Field( + ..., + description="A Base64-encoded black/white mask matching the input dimensions; " + "white (255) marks areas to remove, black (0) marks areas to preserve.", ) - image: str = Field(None, description='A Base64-encoded string representing the image you wish to modify. Can contain alpha mask if desired.') - mask: str = Field(None, description='A Base64-encoded string representing the mask of the areas you with to modify.') + dilate_pixels: int = Field(10) + output_format: str = Field("png") + + +class BFLFluxVTORequest(BaseModel): + prompt: str = Field( + ..., description="Natural-language styling instruction. Required field, but may be an empty string." 
+ ) + person: str = Field(..., description="A Base64-encoded string representing the person image.") + garment: str = Field(..., description="A Base64-encoded string representing the garment reference image.") + seed: int | None = Field(None) + safety_tolerance: int = Field(5) + output_format: str = Field("png") class BFLFluxProGenerateRequest(BaseModel): - prompt: str = Field(..., description='The text prompt for image generation.') - prompt_upsampling: Optional[bool] = Field( - None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.' - ) - seed: Optional[int] = Field(None, description='The seed value for reproducibility.') - width: conint(ge=256, le=1440) = Field(1024, description='Width of the generated image in pixels. Must be a multiple of 32.') - height: conint(ge=256, le=1440) = Field(768, description='Height of the generated image in pixels. Must be a multiple of 32.') - safety_tolerance: Optional[conint(ge=0, le=6)] = Field( - 6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.' - ) - output_format: Optional[BFLOutputFormat] = Field( - BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png'] - ) - image_prompt: Optional[str] = Field(None, description='Optional image to remix in base64 format') - # image_prompt_strength: Optional[confloat(ge=0.0, le=1.0)] = Field( - # None, description='Blend between the prompt and the image prompt.' - # ) + prompt: str = Field(...) 
+ prompt_upsampling: bool | None = Field(None) + seed: int | None = Field(None) + width: int = Field(1024, description="Must be a multiple of 32.") + height: int = Field(768, description="Must be a multiple of 32.") + safety_tolerance: int = Field(6) + output_format: str = Field("png") + image_prompt: str | None = Field(None, description="Optional image to remix in base64 format") class Flux2ProGenerateRequest(BaseModel): @@ -83,55 +83,37 @@ class Flux2ProGenerateRequest(BaseModel): input_image_7: str | None = Field(None, description="Base64 encoded image for image-to-image generation") input_image_8: str | None = Field(None, description="Base64 encoded image for image-to-image generation") input_image_9: str | None = Field(None, description="Base64 encoded image for image-to-image generation") - safety_tolerance: int | None = Field( - 5, description="Tolerance level for input and output moderation. Value 0 being most strict.", ge=0, le=5 - ) - output_format: str | None = Field( - "png", description="Output format for the generated image. Can be 'jpeg' or 'png'." - ) + safety_tolerance: int = Field(5) + output_format: str = Field("png") class BFLFluxKontextProGenerateRequest(BaseModel): - prompt: str = Field(..., description='The text prompt for what you wannt to edit.') - input_image: Optional[str] = Field(None, description='Image to edit in base64 format') - seed: Optional[int] = Field(None, description='The seed value for reproducibility.') - guidance: confloat(ge=0.1, le=99.0) = Field(..., description='Guidance strength for the image generation process') - steps: conint(ge=1, le=150) = Field(..., description='Number of steps for the image generation process') - safety_tolerance: Optional[conint(ge=0, le=2)] = Field( - 2, description='Tolerance level for input and output moderation. Between 0 and 2, 0 being most strict, 6 being least strict. Defaults to 2.' 
- ) - output_format: Optional[BFLOutputFormat] = Field( - BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png'] - ) - aspect_ratio: Optional[str] = Field(None, description='Aspect ratio of the image between 21:9 and 9:21.') - prompt_upsampling: Optional[bool] = Field( - None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.' - ) + prompt: str = Field(...) + input_image: str | None = Field(None, description="Image to edit in base64 format") + seed: int | None = Field(None) + guidance: float = Field(...) + steps: int = Field(...) + safety_tolerance: int = Field(2) + output_format: str = Field("png") + aspect_ratio: str | None = Field(None) + prompt_upsampling: bool | None = Field(None) class BFLFluxProUltraGenerateRequest(BaseModel): - prompt: str = Field(..., description='The text prompt for image generation.') - prompt_upsampling: Optional[bool] = Field( - None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.' - ) - seed: Optional[int] = Field(None, description='The seed value for reproducibility.') - aspect_ratio: Optional[str] = Field(None, description='Aspect ratio of the image between 21:9 and 9:21.') - safety_tolerance: Optional[conint(ge=0, le=6)] = Field( - 6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.' - ) - output_format: Optional[BFLOutputFormat] = Field( - BFLOutputFormat.png, description="Output format for the generated image. 
Can be 'jpeg' or 'png'.", examples=['png'] - ) - raw: Optional[bool] = Field(None, description='Generate less processed, more natural-looking images.') - image_prompt: Optional[str] = Field(None, description='Optional image to remix in base64 format') - image_prompt_strength: Optional[confloat(ge=0.0, le=1.0)] = Field( - None, description='Blend between the prompt and the image prompt.' - ) + prompt: str = Field(...) + prompt_upsampling: bool | None = Field(None) + seed: int | None = Field(None) + aspect_ratio: str | None = Field(None) + safety_tolerance: int = Field(6) + output_format: str = Field("png") + raw: bool | None = Field(None) + image_prompt: str | None = Field(None, description="Optional image to remix in base64 format") + image_prompt_strength: float | None = Field(None) class BFLFluxProGenerateResponse(BaseModel): - id: str = Field(..., description="The unique identifier for the generation task.") - polling_url: str = Field(..., description="URL to poll for the generation result.") + id: str = Field(...) + polling_url: str = Field(...) cost: float | None = Field(None, description="Price in cents") @@ -145,7 +127,7 @@ class BFLStatus(str, Enum): class BFLFluxStatusResponse(BaseModel): - id: str = Field(..., description="The unique identifier for the generation task.") - status: BFLStatus = Field(..., description="The status of the task.") - result: Optional[Dict[str, Any]] = Field(None, description="The result of the task (null if not completed).") - progress: Optional[float] = Field(None, description="The progress of the task (0.0 to 1.0).", ge=0.0, le=1.0) + id: str = Field(...) + status: BFLStatus = Field(...) 
+ result: dict[str, Any] | None = Field(None) + progress: float | None = Field(None, ge=0.0, le=1.0) diff --git a/comfy_api_nodes/nodes_anthropic.py b/comfy_api_nodes/nodes_anthropic.py index 7805c96ce..87a870553 100644 --- a/comfy_api_nodes/nodes_anthropic.py +++ b/comfy_api_nodes/nodes_anthropic.py @@ -155,7 +155,7 @@ class ClaudeNode(IO.ComfyNode): return IO.Schema( node_id="ClaudeNode", display_name="Anthropic Claude", - category="text/partner/Anthropic", + category="partner/text/Anthropic", essentials_category="Text Generation", description="Generate text responses with Anthropic's Claude models. " "Provide a text prompt and optionally one or more images for multimodal context.", diff --git a/comfy_api_nodes/nodes_beeble.py b/comfy_api_nodes/nodes_beeble.py index f1082884c..d863c2130 100644 --- a/comfy_api_nodes/nodes_beeble.py +++ b/comfy_api_nodes/nodes_beeble.py @@ -206,7 +206,7 @@ class BeebleSwitchXVideoEdit(IO.ComfyNode): return IO.Schema( node_id="BeebleSwitchXVideoEdit", display_name="Beeble SwitchX Video Edit", - category="video/partner/Beeble", + category="partner/video/Beeble", description=( "Edit a video with Beeble SwitchX. Switches anything in the scene (background, " "lighting, costume) while preserving the original subject's pixels and motion. " @@ -302,7 +302,7 @@ class BeebleSwitchXImageEdit(IO.ComfyNode): return IO.Schema( node_id="BeebleSwitchXImageEdit", display_name="Beeble SwitchX Image Edit", - category="image/partner/Beeble", + category="partner/image/Beeble", description=( "Edit a single image with Beeble SwitchX. Switches anything in the scene " "(background, lighting, costume) while preserving the original subject's pixels. 
" diff --git a/comfy_api_nodes/nodes_bfl.py b/comfy_api_nodes/nodes_bfl.py index f1a5dc5f0..79961ff9d 100644 --- a/comfy_api_nodes/nodes_bfl.py +++ b/comfy_api_nodes/nodes_bfl.py @@ -4,17 +4,20 @@ from typing_extensions import override from comfy_api.latest import IO, ComfyExtension, Input from comfy_api_nodes.apis.bfl import ( + BFLFluxEraseRequest, BFLFluxExpandImageRequest, BFLFluxFillImageRequest, BFLFluxKontextProGenerateRequest, BFLFluxProGenerateResponse, BFLFluxProUltraGenerateRequest, BFLFluxStatusResponse, + BFLFluxVTORequest, BFLStatus, Flux2ProGenerateRequest, ) from comfy_api_nodes.util import ( ApiEndpoint, + convert_mask_to_image, download_url_to_image_tensor, get_number_of_images, poll_op, @@ -22,19 +25,11 @@ from comfy_api_nodes.util import ( sync_op, tensor_to_base64_string, validate_aspect_ratio_string, + validate_image_dimensions, validate_string, ) -def convert_mask_to_image(mask: Input.Image): - """ - Make mask have the expected amount of dims (4) and channels (3) to be recognized as an image. 
- """ - mask = mask.unsqueeze(-1) - mask = torch.cat([mask] * 3, dim=-1) - return mask - - class FluxProUltraImageNode(IO.ComfyNode): @classmethod @@ -42,7 +37,7 @@ class FluxProUltraImageNode(IO.ComfyNode): return IO.Schema( node_id="FluxProUltraImageNode", display_name="Flux 1.1 [pro] Ultra Image", - category="image/partner/BFL", + category="partner/image/BFL", description="Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.", inputs=[ IO.String.Input( @@ -160,7 +155,7 @@ class FluxKontextProImageNode(IO.ComfyNode): return IO.Schema( node_id=cls.NODE_ID, display_name=cls.DISPLAY_NAME, - category="image/partner/BFL", + category="partner/image/BFL", description="Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.", inputs=[ IO.String.Input( @@ -282,7 +277,7 @@ class FluxProExpandNode(IO.ComfyNode): return IO.Schema( node_id="FluxProExpandNode", display_name="Flux.1 Expand Image", - category="image/partner/BFL", + category="partner/image/BFL", description="Outpaints image based on prompt.", inputs=[ IO.Image.Input("image"), @@ -419,7 +414,7 @@ class FluxProFillNode(IO.ComfyNode): return IO.Schema( node_id="FluxProFillNode", display_name="Flux.1 Fill Image", - category="image/partner/BFL", + category="partner/image/BFL", description="Inpaints image based on mask and prompt.", inputs=[ IO.Image.Input("image"), @@ -519,6 +514,163 @@ class FluxProFillNode(IO.ComfyNode): return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) +class FluxEraseNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="FluxEraseNode", + display_name="Flux Erase Image", + category="partner/image/BFL", + description="Removes the masked object from an image and reconstructs the background. 
" + "Paint the mask over what you want to erase.", + inputs=[ + IO.Image.Input("image"), + IO.Mask.Input("mask", tooltip="White areas are removed; black areas are preserved."), + IO.Int.Input( + "dilate_pixels", + default=10, + min=0, + max=25, + tooltip="Expands the mask boundaries to ensure clean coverage of the object's edges.", + ), + ], + outputs=[IO.Image.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"range_usd","min_usd":0.03,"max_usd":0.06,"format":{"approximate":true}}""", + ), + ) + + @classmethod + async def execute( + cls, + image: Input.Image, + mask: Input.Image, + dilate_pixels: int = 10, + ) -> IO.NodeOutput: + validate_image_dimensions(image, min_width=256, min_height=256) + mask = resize_mask_to_image(mask, image) + mask = tensor_to_base64_string(convert_mask_to_image(mask)) + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/bfl/v1/flux-tools/erase-v1", method="POST"), + response_model=BFLFluxProGenerateResponse, + data=BFLFluxEraseRequest( + image=tensor_to_base64_string(image[:, :, :, :3]), # make sure image will have alpha channel removed + mask=mask, + dilate_pixels=dilate_pixels, + ), + ) + + def price_extractor(_r: BaseModel) -> float | None: + return None if initial_response.cost is None else initial_response.cost / 100 + + response = await poll_op( + cls, + ApiEndpoint(initial_response.polling_url), + response_model=BFLFluxStatusResponse, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + price_extractor=price_extractor, + completed_statuses=[BFLStatus.ready], + failed_statuses=[ + BFLStatus.request_moderated, + BFLStatus.content_moderated, + BFLStatus.error, + BFLStatus.task_not_found, + ], + queued_statuses=[], + ) + return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) + + +class FluxVTONode(IO.ComfyNode): + + @classmethod 
+ def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="FluxVTONode", + display_name="Flux Virtual Try-On", + category="partner/image/BFL", + description="Virtual try-on: dresses the person in the provided garment.", + inputs=[ + IO.Image.Input("person", tooltip="Image of the person to dress."), + IO.Image.Input("garment", tooltip="Image of the garment to apply."), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Optional natural-language styling instruction (e.g. how the garment should fit).", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", + ), + ], + outputs=[IO.Image.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"range_usd","min_usd":0.0375,"max_usd":0.075,"format":{"approximate":true}}""", + ), + ) + + @classmethod + async def execute( + cls, + person: Input.Image, + garment: Input.Image, + prompt: str = "", + seed: int = 0, + ) -> IO.NodeOutput: + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/bfl/v1/flux-tools/vto-v1", method="POST"), + response_model=BFLFluxProGenerateResponse, + data=BFLFluxVTORequest( + prompt=prompt, + person=tensor_to_base64_string(person[:, :, :, :3]), + garment=tensor_to_base64_string(garment[:, :, :, :3]), + seed=seed, + ), + ) + + def price_extractor(_r: BaseModel) -> float | None: + return None if initial_response.cost is None else initial_response.cost / 100 + + response = await poll_op( + cls, + ApiEndpoint(initial_response.polling_url), + response_model=BFLFluxStatusResponse, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + price_extractor=price_extractor, + completed_statuses=[BFLStatus.ready], + failed_statuses=[ + BFLStatus.request_moderated, + BFLStatus.content_moderated, + 
BFLStatus.error, + BFLStatus.task_not_found, + ], + queued_statuses=[], + ) + return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) + + class Flux2ProImageNode(IO.ComfyNode): NODE_ID = "Flux2ProImageNode" @@ -545,7 +697,7 @@ class Flux2ProImageNode(IO.ComfyNode): return IO.Schema( node_id=cls.NODE_ID, display_name=cls.DISPLAY_NAME, - category="image/partner/BFL", + category="partner/image/BFL", description="Generates images synchronously based on prompt and resolution.", inputs=[ IO.String.Input( @@ -716,7 +868,7 @@ class Flux2ImageNode(IO.ComfyNode): return IO.Schema( node_id="Flux2ImageNode", display_name="Flux.2 Image", - category="image/partner/BFL", + category="partner/image/BFL", description="Generate images via Flux.2 [pro] or Flux.2 [max] from a prompt and optional reference images.", inputs=[ IO.String.Input( @@ -853,6 +1005,8 @@ class BFLExtension(ComfyExtension): FluxKontextMaxImageNode, FluxProExpandNode, FluxProFillNode, + FluxEraseNode, + FluxVTONode, Flux2ProImageNode, Flux2MaxImageNode, Flux2ImageNode, diff --git a/comfy_api_nodes/nodes_bria.py b/comfy_api_nodes/nodes_bria.py index 53e763210..69b0233af 100644 --- a/comfy_api_nodes/nodes_bria.py +++ b/comfy_api_nodes/nodes_bria.py @@ -31,7 +31,7 @@ class BriaImageEditNode(IO.ComfyNode): return IO.Schema( node_id="BriaImageEditNode", display_name="Bria FIBO Image Edit", - category="image/partner/Bria", + category="partner/image/Bria", description="Edit images using Bria latest model", inputs=[ IO.Combo.Input("model", options=["FIBO"]), @@ -169,7 +169,7 @@ class BriaRemoveImageBackground(IO.ComfyNode): return IO.Schema( node_id="BriaRemoveImageBackground", display_name="Bria Remove Image Background", - category="image/partner/Bria", + category="partner/image/Bria", description="Remove the background from an image using Bria RMBG 2.0.", inputs=[ IO.Image.Input("image"), @@ -245,7 +245,7 @@ class BriaRemoveVideoBackground(IO.ComfyNode): return IO.Schema( 
node_id="BriaRemoveVideoBackground", display_name="Bria Remove Video Background", - category="video/partner/Bria", + category="partner/video/Bria", description="Remove the background from a video using Bria. ", inputs=[ IO.Video.Input("video"), diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index 3711bac1d..d8885a7e5 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -368,7 +368,7 @@ class ByteDanceImageNode(IO.ComfyNode): return IO.Schema( node_id="ByteDanceImageNode", display_name="ByteDance Image", - category="image/partner/ByteDance", + category="partner/image/ByteDance", description="Generate images using ByteDance models via api based on prompt", inputs=[ IO.Combo.Input("model", options=["seedream-3-0-t2i-250415"]), @@ -492,7 +492,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode): return IO.Schema( node_id="ByteDanceSeedreamNode", display_name="ByteDance Seedream 4.5 & 5.0", - category="image/partner/ByteDance", + category="partner/image/ByteDance", description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.", inputs=[ IO.Combo.Input( @@ -754,7 +754,7 @@ class ByteDanceSeedreamNodeV2(IO.ComfyNode): return IO.Schema( node_id="ByteDanceSeedreamNodeV2", display_name="ByteDance Seedream 4.5 & 5.0", - category="image/partner/ByteDance", + category="partner/image/ByteDance", description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.", inputs=[ IO.String.Input( @@ -920,7 +920,7 @@ class ByteDanceTextToVideoNode(IO.ComfyNode): return IO.Schema( node_id="ByteDanceTextToVideoNode", display_name="ByteDance Text to Video", - category="video/partner/ByteDance", + category="partner/video/ByteDance", description="Generate video using ByteDance models via api based on prompt", inputs=[ IO.Combo.Input( @@ -1048,7 +1048,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode): return IO.Schema( 
node_id="ByteDanceImageToVideoNode", display_name="ByteDance Image to Video", - category="video/partner/ByteDance", + category="partner/video/ByteDance", description="Generate video using ByteDance models via api based on image and prompt", inputs=[ IO.Combo.Input( @@ -1185,7 +1185,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode): return IO.Schema( node_id="ByteDanceFirstLastFrameNode", display_name="ByteDance First-Last-Frame to Video", - category="video/partner/ByteDance", + category="partner/video/ByteDance", description="Generate video using prompt and first and last frames.", inputs=[ IO.Combo.Input( @@ -1333,7 +1333,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode): return IO.Schema( node_id="ByteDanceImageReferenceNode", display_name="ByteDance Reference Images to Video", - category="video/partner/ByteDance", + category="partner/video/ByteDance", description="Generate video using prompt and reference images.", inputs=[ IO.Combo.Input( @@ -1576,7 +1576,7 @@ class ByteDance2TextToVideoNode(IO.ComfyNode): return IO.Schema( node_id="ByteDance2TextToVideoNode", display_name="ByteDance Seedance 2.0 Text to Video", - category="video/partner/ByteDance", + category="partner/video/ByteDance", description="Generate video using Seedance 2.0 models based on a text prompt.", inputs=[ IO.DynamicCombo.Input( @@ -1677,7 +1677,7 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode): return IO.Schema( node_id="ByteDance2FirstLastFrameNode", display_name="ByteDance Seedance 2.0 First-Last-Frame to Video", - category="video/partner/ByteDance", + category="partner/video/ByteDance", description="Generate video using Seedance 2.0 from a first frame image and optional last frame image.", inputs=[ IO.DynamicCombo.Input( @@ -1944,7 +1944,7 @@ class ByteDance2ReferenceNode(IO.ComfyNode): return IO.Schema( node_id="ByteDance2ReferenceNode", display_name="ByteDance Seedance 2.0 Reference to Video", - category="video/partner/ByteDance", + category="partner/video/ByteDance", 
description="Generate, edit, or extend video using Seedance 2.0 with reference images, " "videos, and audio. Supports multimodal reference, video editing, and video extension.", inputs=[ @@ -2241,7 +2241,7 @@ class ByteDanceCreateImageAsset(IO.ComfyNode): return IO.Schema( node_id="ByteDanceCreateImageAsset", display_name="ByteDance Create Image Asset", - category="image/partner/ByteDance", + category="partner/image/ByteDance", description=( "Create a Seedance 2.0 personal image asset. Uploads the input image and " "registers it in the given asset group. If group_id is empty, runs a real-person " @@ -2308,7 +2308,7 @@ class ByteDanceCreateVideoAsset(IO.ComfyNode): return IO.Schema( node_id="ByteDanceCreateVideoAsset", display_name="ByteDance Create Video Asset", - category="video/partner/ByteDance", + category="partner/video/ByteDance", description=( "Create a Seedance 2.0 personal video asset. Uploads the input video and " "registers it in the given asset group. If group_id is empty, runs a real-person " diff --git a/comfy_api_nodes/nodes_bytedance_llm.py b/comfy_api_nodes/nodes_bytedance_llm.py index 007cac45f..cb41defa0 100644 --- a/comfy_api_nodes/nodes_bytedance_llm.py +++ b/comfy_api_nodes/nodes_bytedance_llm.py @@ -144,7 +144,7 @@ class ByteDanceSeedNode(IO.ComfyNode): return IO.Schema( node_id="ByteDanceSeedNode", display_name="ByteDance Seed", - category="text/partner/ByteDance", + category="partner/text/ByteDance", essentials_category="Text Generation", description="Generate text responses with ByteDance's Seed 2.0 models. 
" "Provide a text prompt and optionally one or more images or videos for multimodal context.", diff --git a/comfy_api_nodes/nodes_elevenlabs.py b/comfy_api_nodes/nodes_elevenlabs.py index 37eeb2601..eba578a45 100644 --- a/comfy_api_nodes/nodes_elevenlabs.py +++ b/comfy_api_nodes/nodes_elevenlabs.py @@ -69,7 +69,7 @@ class ElevenLabsSpeechToText(IO.ComfyNode): return IO.Schema( node_id="ElevenLabsSpeechToText", display_name="ElevenLabs Speech to Text", - category="audio/partner/ElevenLabs", + category="partner/audio/ElevenLabs", description="Transcribe audio to text. " "Supports automatic language detection, speaker diarization, and audio event tagging.", inputs=[ @@ -210,7 +210,7 @@ class ElevenLabsVoiceSelector(IO.ComfyNode): return IO.Schema( node_id="ElevenLabsVoiceSelector", display_name="ElevenLabs Voice Selector", - category="audio/partner/ElevenLabs", + category="partner/audio/ElevenLabs", description="Select a predefined ElevenLabs voice for text-to-speech generation.", inputs=[ IO.Combo.Input( @@ -239,7 +239,7 @@ class ElevenLabsTextToSpeech(IO.ComfyNode): return IO.Schema( node_id="ElevenLabsTextToSpeech", display_name="ElevenLabs Text to Speech", - category="audio/partner/ElevenLabs", + category="partner/audio/ElevenLabs", description="Convert text to speech.", inputs=[ IO.Custom(ELEVENLABS_VOICE).Input( @@ -414,7 +414,7 @@ class ElevenLabsAudioIsolation(IO.ComfyNode): return IO.Schema( node_id="ElevenLabsAudioIsolation", display_name="ElevenLabs Voice Isolation", - category="audio/partner/ElevenLabs", + category="partner/audio/ElevenLabs", description="Remove background noise from audio, isolating vocals or speech.", inputs=[ IO.Audio.Input( @@ -459,7 +459,7 @@ class ElevenLabsTextToSoundEffects(IO.ComfyNode): return IO.Schema( node_id="ElevenLabsTextToSoundEffects", display_name="ElevenLabs Text to Sound Effects", - category="audio/partner/ElevenLabs", + category="partner/audio/ElevenLabs", description="Generate sound effects from text descriptions.", 
inputs=[ IO.String.Input( @@ -555,7 +555,7 @@ class ElevenLabsInstantVoiceClone(IO.ComfyNode): return IO.Schema( node_id="ElevenLabsInstantVoiceClone", display_name="ElevenLabs Instant Voice Clone", - category="audio/partner/ElevenLabs", + category="partner/audio/ElevenLabs", description="Create a cloned voice from audio samples. " "Provide 1-8 audio recordings of the voice to clone.", inputs=[ @@ -658,7 +658,7 @@ class ElevenLabsSpeechToSpeech(IO.ComfyNode): return IO.Schema( node_id="ElevenLabsSpeechToSpeech", display_name="ElevenLabs Speech to Speech", - category="audio/partner/ElevenLabs", + category="partner/audio/ElevenLabs", description="Transform speech from one voice to another while preserving the original content and emotion.", inputs=[ IO.Custom(ELEVENLABS_VOICE).Input( @@ -793,7 +793,7 @@ class ElevenLabsTextToDialogue(IO.ComfyNode): return IO.Schema( node_id="ElevenLabsTextToDialogue", display_name="ElevenLabs Text to Dialogue", - category="audio/partner/ElevenLabs", + category="partner/audio/ElevenLabs", description="Generate multi-speaker dialogue from text. Each dialogue entry has its own text and voice.", inputs=[ IO.Float.Input( diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index 3cfd541b2..e75ef3835 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -300,7 +300,7 @@ class GeminiNode(IO.ComfyNode): return IO.Schema( node_id="GeminiNode", display_name="Google Gemini", - category="text/partner/Gemini", + category="partner/text/Gemini", description="Generate text responses with Google's Gemini AI model. 
" "You can provide multiple types of inputs (text, images, audio, video) " "as context for generating more relevant and meaningful responses.", @@ -541,7 +541,7 @@ class GeminiInputFiles(IO.ComfyNode): return IO.Schema( node_id="GeminiInputFiles", display_name="Gemini Input Files", - category="text/partner/Gemini", + category="partner/text/Gemini", description="Loads and prepares input files to include as inputs for Gemini LLM nodes. " "The files will be read by the Gemini model when generating a response. " "The contents of the text file count toward the token limit. " @@ -598,7 +598,7 @@ class GeminiImage(IO.ComfyNode): return IO.Schema( node_id="GeminiImageNode", display_name="Nano Banana (Google Gemini Image)", - category="image/partner/Gemini", + category="partner/image/Gemini", description="Edit images synchronously via Google API.", inputs=[ IO.String.Input( @@ -731,7 +731,7 @@ class GeminiImage2(IO.ComfyNode): return IO.Schema( node_id="GeminiImage2Node", display_name="Nano Banana Pro (Google Gemini Image)", - category="image/partner/Gemini", + category="partner/image/Gemini", description="Generate or edit images synchronously via Google Vertex API.", inputs=[ IO.String.Input( @@ -869,7 +869,7 @@ class GeminiNanoBanana2(IO.ComfyNode): return IO.Schema( node_id="GeminiNanoBanana2", display_name="Nano Banana 2", - category="image/partner/Gemini", + category="partner/image/Gemini", description="Generate or edit images synchronously via Google Vertex API.", inputs=[ IO.String.Input( @@ -1085,7 +1085,7 @@ class GeminiNanoBanana2V2(IO.ComfyNode): return IO.Schema( node_id="GeminiNanoBanana2V2", display_name="Nano Banana 2", - category="image/partner/Gemini", + category="partner/image/Gemini", description="Generate or edit images synchronously via Google Vertex API.", inputs=[ IO.String.Input( diff --git a/comfy_api_nodes/nodes_grok.py b/comfy_api_nodes/nodes_grok.py index a41da42f3..2ae529813 100644 --- a/comfy_api_nodes/nodes_grok.py +++ 
b/comfy_api_nodes/nodes_grok.py @@ -29,6 +29,11 @@ from comfy_api_nodes.util import ( ) +_GROK_VIDEO_MODEL_API_IDS = { + "grok-imagine-video-1.5": "grok-imagine-video-1.5-preview", +} + + def _extract_grok_price(response) -> float | None: if response.usage and response.usage.cost_in_usd_ticks is not None: return response.usage.cost_in_usd_ticks / 10_000_000_000 @@ -49,7 +54,7 @@ class GrokImageNode(IO.ComfyNode): return IO.Schema( node_id="GrokImageNode", display_name="Grok Image", - category="image/partner/Grok", + category="partner/image/Grok", description="Generate images using Grok based on a text prompt", inputs=[ IO.Combo.Input( @@ -223,7 +228,7 @@ class GrokImageEditNode(IO.ComfyNode): return IO.Schema( node_id="GrokImageEditNode", display_name="Grok Image Edit", - category="image/partner/Grok", + category="partner/image/Grok", description="Modify an existing image based on a text prompt", inputs=[ IO.Combo.Input( @@ -364,7 +369,7 @@ class GrokImageEditNodeV2(IO.ComfyNode): return IO.Schema( node_id="GrokImageEditNodeV2", display_name="Grok Image Edit", - category="image/partner/Grok", + category="partner/image/Grok", description="Modify an existing image based on a text prompt", inputs=[ IO.String.Input( @@ -501,10 +506,14 @@ class GrokVideoNode(IO.ComfyNode): return IO.Schema( node_id="GrokVideoNode", display_name="Grok Video", - category="video/partner/Grok", + category="partner/video/Grok", description="Generate video from a prompt or an image", inputs=[ - IO.Combo.Input("model", options=["grok-imagine-video"]), + IO.Combo.Input( + "model", + options=["grok-imagine-video", "grok-imagine-video-1.5"], + tooltip="grok-imagine-video-1.5 currently always requires an input image.", + ), IO.String.Input( "prompt", multiline=True, @@ -540,7 +549,11 @@ class GrokVideoNode(IO.ComfyNode): tooltip="Seed to determine if node should re-run; " "actual results are nondeterministic regardless of seed.", ), - IO.Image.Input("image", optional=True), + IO.Image.Input( + 
"image", + optional=True, + tooltip="Optional starting image for grok-imagine-video. Required for grok-imagine-video-1.5.", + ), ], outputs=[ IO.Video.Output(), @@ -552,12 +565,16 @@ class GrokVideoNode(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"], inputs=["image"]), + depends_on=IO.PriceBadgeDepends(widgets=["model", "duration", "resolution"], inputs=["image"]), expr=""" ( - $rate := widgets.resolution = "720p" ? 0.07 : 0.05; + $is15 := $contains(widgets.model, "1.5"); + $rate := $is15 + ? (widgets.resolution = "720p" ? 0.2002 : 0.1144) + : (widgets.resolution = "720p" ? 0.07 : 0.05); + $imgCost := $is15 ? 0.0143 : 0.002; $base := $rate * widgets.duration; - {"type":"usd","usd": inputs.image.connected ? $base + 0.002 : $base} + {"type":"usd","usd": inputs.image.connected ? $base + $imgCost : $base} ) """, ), @@ -574,6 +591,8 @@ class GrokVideoNode(IO.ComfyNode): seed: int, image: Input.Image | None = None, ) -> IO.NodeOutput: + if image is None and model == "grok-imagine-video-1.5": + raise ValueError(f"The '{model}' model requires an input image; connect one to the 'image' input.") image_url = None if image is not None: if get_number_of_images(image) != 1: @@ -584,7 +603,7 @@ class GrokVideoNode(IO.ComfyNode): cls, ApiEndpoint(path="/proxy/xai/v1/videos/generations", method="POST"), data=VideoGenerationRequest( - model=model, + model=_GROK_VIDEO_MODEL_API_IDS.get(model, model), image=image_url, prompt=prompt, resolution=resolution, @@ -599,7 +618,7 @@ class GrokVideoNode(IO.ComfyNode): ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"), status_extractor=lambda r: r.status if r.status is not None else "complete", response_model=VideoStatusResponse, - price_extractor=_extract_grok_price, + price_extractor=_extract_grok_video_price if model == "grok-imagine-video-1.5" else _extract_grok_price, ) return IO.NodeOutput(await 
download_url_to_video_output(response.video.url)) @@ -611,7 +630,7 @@ class GrokVideoEditNode(IO.ComfyNode): return IO.Schema( node_id="GrokVideoEditNode", display_name="Grok Video Edit", - category="video/partner/Grok", + category="partner/video/Grok", description="Edit an existing video based on a text prompt.", inputs=[ IO.Combo.Input("model", options=["grok-imagine-video"]), @@ -689,7 +708,7 @@ class GrokVideoReferenceNode(IO.ComfyNode): return IO.Schema( node_id="GrokVideoReferenceNode", display_name="Grok Reference-to-Video", - category="video/partner/Grok", + category="partner/video/Grok", description="Generate video guided by reference images as style and content references.", inputs=[ IO.String.Input( @@ -822,7 +841,7 @@ class GrokVideoExtendNode(IO.ComfyNode): return IO.Schema( node_id="GrokVideoExtendNode", display_name="Grok Video Extend", - category="video/partner/Grok", + category="partner/video/Grok", description="Extend an existing video with a seamless continuation based on a text prompt.", inputs=[ IO.String.Input( diff --git a/comfy_api_nodes/nodes_hitpaw.py b/comfy_api_nodes/nodes_hitpaw.py index 22e679c29..062d3cf1d 100644 --- a/comfy_api_nodes/nodes_hitpaw.py +++ b/comfy_api_nodes/nodes_hitpaw.py @@ -71,7 +71,7 @@ class HitPawGeneralImageEnhance(IO.ComfyNode): return IO.Schema( node_id="HitPawGeneralImageEnhance", display_name="HitPaw General Image Enhance", - category="image/partner/HitPaw", + category="partner/image/HitPaw", description="Upscale low-resolution images to super-resolution, eliminate artifacts and noise. " f"Maximum output: {MAX_MP_GENERATIVE} megapixels.", inputs=[ @@ -201,7 +201,7 @@ class HitPawVideoEnhance(IO.ComfyNode): return IO.Schema( node_id="HitPawVideoEnhance", display_name="HitPaw Video Enhance", - category="video/partner/HitPaw", + category="partner/video/HitPaw", description="Upscale low-resolution videos to high resolution, eliminate artifacts and noise. 
" "Prices shown are per second of video.", inputs=[ diff --git a/comfy_api_nodes/nodes_hunyuan3d.py b/comfy_api_nodes/nodes_hunyuan3d.py index 826a3bd2d..fcd27b7fb 100644 --- a/comfy_api_nodes/nodes_hunyuan3d.py +++ b/comfy_api_nodes/nodes_hunyuan3d.py @@ -123,7 +123,7 @@ class TencentTextToModelNode(IO.ComfyNode): return IO.Schema( node_id="TencentTextToModelNode", display_name="Hunyuan3D: Text to Model", - category="3d/partner/Tencent", + category="partner/3d/Tencent", essentials_category="3D", inputs=[ IO.Combo.Input( @@ -242,7 +242,7 @@ class TencentImageToModelNode(IO.ComfyNode): return IO.Schema( node_id="TencentImageToModelNode", display_name="Hunyuan3D: Image(s) to Model", - category="3d/partner/Tencent", + category="partner/3d/Tencent", essentials_category="3D", inputs=[ IO.Combo.Input( @@ -415,7 +415,7 @@ class TencentModelTo3DUVNode(IO.ComfyNode): return IO.Schema( node_id="TencentModelTo3DUVNode", display_name="Hunyuan3D: Model to UV", - category="3d/partner/Tencent", + category="partner/3d/Tencent", description="Perform UV unfolding on a 3D model to generate UV texture. 
" "Input model must have less than 30000 faces.", inputs=[ @@ -505,7 +505,7 @@ class Tencent3DTextureEditNode(IO.ComfyNode): return IO.Schema( node_id="Tencent3DTextureEditNode", display_name="Hunyuan3D: 3D Texture Edit", - category="3d/partner/Tencent", + category="partner/3d/Tencent", description="After inputting the 3D model, perform 3D model texture redrawing.", inputs=[ IO.MultiType.Input( @@ -594,7 +594,7 @@ class Tencent3DPartNode(IO.ComfyNode): return IO.Schema( node_id="Tencent3DPartNode", display_name="Hunyuan3D: 3D Part", - category="3d/partner/Tencent", + category="partner/3d/Tencent", description="Automatically perform component identification and generation based on the model structure.", inputs=[ IO.MultiType.Input( @@ -666,7 +666,7 @@ class TencentSmartTopologyNode(IO.ComfyNode): return IO.Schema( node_id="TencentSmartTopologyNode", display_name="Hunyuan3D: Smart Topology", - category="3d/partner/Tencent", + category="partner/3d/Tencent", description="Perform smart retopology on a 3D model. 
" "Supports GLB/OBJ formats; max 200MB; recommended for high-poly models.", inputs=[ diff --git a/comfy_api_nodes/nodes_ideogram.py b/comfy_api_nodes/nodes_ideogram.py index edd9b9435..8018c3902 100644 --- a/comfy_api_nodes/nodes_ideogram.py +++ b/comfy_api_nodes/nodes_ideogram.py @@ -234,7 +234,7 @@ class IdeogramV1(IO.ComfyNode): return IO.Schema( node_id="IdeogramV1", display_name="Ideogram V1", - category="image/partner/Ideogram", + category="partner/image/Ideogram", description="Generates images using the Ideogram V1 model.", inputs=[ IO.String.Input( @@ -360,7 +360,7 @@ class IdeogramV2(IO.ComfyNode): return IO.Schema( node_id="IdeogramV2", display_name="Ideogram V2", - category="image/partner/Ideogram", + category="partner/image/Ideogram", description="Generates images using the Ideogram V2 model.", inputs=[ IO.String.Input( @@ -526,7 +526,7 @@ class IdeogramV3(IO.ComfyNode): return IO.Schema( node_id="IdeogramV3", display_name="Ideogram V3", - category="image/partner/Ideogram", + category="partner/image/Ideogram", description="Generates images using the Ideogram V3 model. 
" "Supports both regular image generation from text prompts and image editing with mask.", inputs=[ diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 9925ec548..d11e42540 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -642,7 +642,7 @@ class KlingCameraControls(IO.ComfyNode): return IO.Schema( node_id="KlingCameraControls", display_name="Kling Camera Controls", - category="video/partner/Kling", + category="partner/video/Kling", description="Allows specifying configuration options for Kling Camera Controls and motion control effects.", inputs=[ IO.Combo.Input("camera_control_type", options=KlingCameraControlType), @@ -762,7 +762,7 @@ class KlingTextToVideoNode(IO.ComfyNode): return IO.Schema( node_id="KlingTextToVideoNode", display_name="Kling Text to Video", - category="video/partner/Kling", + category="partner/video/Kling", description="Kling Text to Video Node", inputs=[ IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"), @@ -849,7 +849,7 @@ class OmniProTextToVideoNode(IO.ComfyNode): return IO.Schema( node_id="KlingOmniProTextToVideoNode", display_name="Kling 3.0 Omni Text to Video", - category="video/partner/Kling", + category="partner/video/Kling", description="Use text prompts to generate videos with the latest Kling model.", inputs=[ IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]), @@ -998,7 +998,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode): return IO.Schema( node_id="KlingOmniProFirstLastFrameNode", display_name="Kling 3.0 Omni First-Last-Frame to Video", - category="video/partner/Kling", + category="partner/video/Kling", description="Use a start frame, an optional end frame, or reference images with the latest Kling model.", inputs=[ IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]), @@ -1205,7 +1205,7 @@ class OmniProImageToVideoNode(IO.ComfyNode): return IO.Schema( node_id="KlingOmniProImageToVideoNode", 
display_name="Kling 3.0 Omni Image to Video", - category="video/partner/Kling", + category="partner/video/Kling", description="Use up to 7 reference images to generate a video with the latest Kling model.", inputs=[ IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]), @@ -1374,7 +1374,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode): return IO.Schema( node_id="KlingOmniProVideoToVideoNode", display_name="Kling 3.0 Omni Video to Video", - category="video/partner/Kling", + category="partner/video/Kling", description="Use a video and up to 4 reference images to generate a video with the latest Kling model.", inputs=[ IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]), @@ -1485,7 +1485,7 @@ class OmniProEditVideoNode(IO.ComfyNode): return IO.Schema( node_id="KlingOmniProEditVideoNode", display_name="Kling 3.0 Omni Edit Video", - category="video/partner/Kling", + category="partner/video/Kling", essentials_category="Video Generation", description="Edit an existing video with the latest model from Kling.", inputs=[ @@ -1593,7 +1593,7 @@ class OmniProImageNode(IO.ComfyNode): return IO.Schema( node_id="KlingOmniProImageNode", display_name="Kling 3.0 Omni Image", - category="image/partner/Kling", + category="partner/image/Kling", description="Create or edit images with the latest model from Kling.", inputs=[ IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-image-o1"]), @@ -1721,7 +1721,7 @@ class KlingCameraControlT2VNode(IO.ComfyNode): return IO.Schema( node_id="KlingCameraControlT2VNode", display_name="Kling Text to Video (Camera Control)", - category="video/partner/Kling", + category="partner/video/Kling", description="Transform text into cinematic videos with professional camera movements that simulate real-world cinematography. 
Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original text.", inputs=[ IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"), @@ -1783,7 +1783,7 @@ class KlingImage2VideoNode(IO.ComfyNode): return IO.Schema( node_id="KlingImage2VideoNode", display_name="Kling Image(First Frame) to Video", - category="video/partner/Kling", + category="partner/video/Kling", inputs=[ IO.Image.Input("start_frame", tooltip="The reference image used to generate the video."), IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"), @@ -1882,7 +1882,7 @@ class KlingCameraControlI2VNode(IO.ComfyNode): return IO.Schema( node_id="KlingCameraControlI2VNode", display_name="Kling Image to Video (Camera Control)", - category="video/partner/Kling", + category="partner/video/Kling", description="Transform still images into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original image.", inputs=[ IO.Image.Input( @@ -1953,7 +1953,7 @@ class KlingStartEndFrameNode(IO.ComfyNode): return IO.Schema( node_id="KlingStartEndFrameNode", display_name="Kling Start-End Frame to Video", - category="video/partner/Kling", + category="partner/video/Kling", description="Generate a video sequence that transitions between your provided start and end images. The node creates all frames in between, producing a smooth transformation from the first frame to the last.", inputs=[ IO.Image.Input( @@ -2047,7 +2047,7 @@ class KlingVideoExtendNode(IO.ComfyNode): return IO.Schema( node_id="KlingVideoExtendNode", display_name="Kling Video Extend", - category="video/partner/Kling", + category="partner/video/Kling", description="Kling Video Extend Node. Extend videos made by other Kling nodes. 
The video_id is created by using other Kling Nodes.", inputs=[ IO.String.Input( @@ -2128,7 +2128,7 @@ class KlingDualCharacterVideoEffectNode(IO.ComfyNode): return IO.Schema( node_id="KlingDualCharacterVideoEffectNode", display_name="Kling Dual Character Video Effects", - category="video/partner/Kling", + category="partner/video/Kling", description="Achieve different special effects when generating a video based on the effect_scene. First image will be positioned on left side, second on right side of the composite.", inputs=[ IO.Image.Input("image_left", tooltip="Left side image"), @@ -2218,7 +2218,7 @@ class KlingSingleImageVideoEffectNode(IO.ComfyNode): return IO.Schema( node_id="KlingSingleImageVideoEffectNode", display_name="Kling Video Effects", - category="video/partner/Kling", + category="partner/video/Kling", description="Achieve different special effects when generating a video based on the effect_scene.", inputs=[ IO.Image.Input( @@ -2291,7 +2291,7 @@ class KlingLipSyncAudioToVideoNode(IO.ComfyNode): return IO.Schema( node_id="KlingLipSyncAudioToVideoNode", display_name="Kling Lip Sync Video with Audio", - category="video/partner/Kling", + category="partner/video/Kling", essentials_category="Video Generation", description="Kling Lip Sync Audio to Video Node. Syncs mouth movements in a video file to the audio content of an audio file. When using, ensure that the audio contains clearly distinguishable vocals and that the video contains a distinct face. The audio file should not be larger than 5MB. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length.", inputs=[ @@ -2343,7 +2343,7 @@ class KlingLipSyncTextToVideoNode(IO.ComfyNode): return IO.Schema( node_id="KlingLipSyncTextToVideoNode", display_name="Kling Lip Sync Video with Text", - category="video/partner/Kling", + category="partner/video/Kling", description="Kling Lip Sync Text to Video Node. 
Syncs mouth movements in a video file to a text prompt. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length.", inputs=[ IO.Video.Input("video"), @@ -2411,7 +2411,7 @@ class KlingVirtualTryOnNode(IO.ComfyNode): return IO.Schema( node_id="KlingVirtualTryOnNode", display_name="Kling Virtual Try On", - category="image/partner/Kling", + category="partner/image/Kling", description="Kling Virtual Try On Node. Input a human image and a cloth image to try on the cloth on the human. You can merge multiple clothing item pictures into one image with a white background.", inputs=[ IO.Image.Input("human_image"), @@ -2478,7 +2478,7 @@ class KlingImageGenerationNode(IO.ComfyNode): return IO.Schema( node_id="KlingImageGenerationNode", display_name="Kling 3.0 Image", - category="image/partner/Kling", + category="partner/image/Kling", description="Kling Image Generation Node. Generate an image from a text prompt with an optional reference image.", inputs=[ IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"), @@ -2615,7 +2615,7 @@ class TextToVideoWithAudio(IO.ComfyNode): return IO.Schema( node_id="KlingTextToVideoWithAudio", display_name="Kling 2.6 Text to Video with Audio", - category="video/partner/Kling", + category="partner/video/Kling", inputs=[ IO.Combo.Input("model_name", options=["kling-v2-6"]), IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt."), @@ -2683,7 +2683,7 @@ class ImageToVideoWithAudio(IO.ComfyNode): return IO.Schema( node_id="KlingImageToVideoWithAudio", display_name="Kling 2.6 Image(First Frame) to Video with Audio", - category="video/partner/Kling", + category="partner/video/Kling", inputs=[ IO.Combo.Input("model_name", options=["kling-v2-6"]), IO.Image.Input("start_frame"), @@ -2753,7 +2753,7 @@ class MotionControl(IO.ComfyNode): return IO.Schema( node_id="KlingMotionControl", display_name="Kling Motion Control", - 
category="video/partner/Kling", + category="partner/video/Kling", inputs=[ IO.String.Input("prompt", multiline=True), IO.Image.Input("reference_image"), @@ -2854,7 +2854,7 @@ class KlingVideoNode(IO.ComfyNode): return IO.Schema( node_id="KlingVideoNode", display_name="Kling 3.0 Video", - category="video/partner/Kling", + category="partner/video/Kling", description="Generate videos with Kling V3. " "Supports text-to-video and image-to-video with optional storyboard multi-prompt and audio generation.", inputs=[ @@ -3077,7 +3077,7 @@ class KlingFirstLastFrameNode(IO.ComfyNode): return IO.Schema( node_id="KlingFirstLastFrameNode", display_name="Kling 3.0 First-Last-Frame to Video", - category="video/partner/Kling", + category="partner/video/Kling", description="Generate videos with Kling V3 using first and last frames.", inputs=[ IO.String.Input("prompt", multiline=True, default=""), @@ -3202,7 +3202,7 @@ class KlingAvatarNode(IO.ComfyNode): return IO.Schema( node_id="KlingAvatarNode", display_name="Kling Avatar 2.0", - category="video/partner/Kling", + category="partner/video/Kling", description="Generate broadcast-style digital human videos from a single photo and an audio file.", inputs=[ IO.Image.Input( diff --git a/comfy_api_nodes/nodes_krea.py b/comfy_api_nodes/nodes_krea.py index be04a272b..34369f05f 100644 --- a/comfy_api_nodes/nodes_krea.py +++ b/comfy_api_nodes/nodes_krea.py @@ -106,7 +106,7 @@ class Krea2ImageNode(IO.ComfyNode): return IO.Schema( node_id="Krea2ImageNode", display_name="Krea 2 Image", - category="image/partner/Krea", + category="partner/image/Krea", description=( "Generate images via Krea 2 — pick Medium (expressive illustrations) or " "Large (expressive photorealism). 
Supports an optional moodboard and up " @@ -229,7 +229,7 @@ class Krea2StyleReferenceNode(IO.ComfyNode): return IO.Schema( node_id="Krea2StyleReferenceNode", display_name="Krea 2 Style Reference", - category="image/partner/Krea", + category="partner/image/Krea", description=( "Add an image style reference to a Krea 2 generation. Chain multiple Krea 2 " "Style Reference nodes (max 10) and feed the final `style_reference` output " diff --git a/comfy_api_nodes/nodes_ltxv.py b/comfy_api_nodes/nodes_ltxv.py index 01791d354..878e04b4e 100644 --- a/comfy_api_nodes/nodes_ltxv.py +++ b/comfy_api_nodes/nodes_ltxv.py @@ -50,7 +50,7 @@ class TextToVideoNode(IO.ComfyNode): return IO.Schema( node_id="LtxvApiTextToVideo", display_name="LTXV Text To Video", - category="video/partner/LTXV", + category="partner/video/LTXV", description="Professional-quality videos with customizable duration and resolution.", inputs=[ IO.Combo.Input("model", options=list(MODELS_MAP.keys())), @@ -127,7 +127,7 @@ class ImageToVideoNode(IO.ComfyNode): return IO.Schema( node_id="LtxvApiImageToVideo", display_name="LTXV Image To Video", - category="video/partner/LTXV", + category="partner/video/LTXV", description="Professional-quality videos with customizable duration and resolution based on start image.", inputs=[ IO.Image.Input("image", tooltip="First frame to be used for the video."), diff --git a/comfy_api_nodes/nodes_luma.py b/comfy_api_nodes/nodes_luma.py index 08ae9904c..0d31ac77e 100644 --- a/comfy_api_nodes/nodes_luma.py +++ b/comfy_api_nodes/nodes_luma.py @@ -46,7 +46,7 @@ class LumaReferenceNode(IO.ComfyNode): return IO.Schema( node_id="LumaReferenceNode", display_name="Luma Reference", - category="image/partner/Luma", + category="partner/image/Luma", description="Holds an image and weight for use with Luma Generate Image node.", inputs=[ IO.Image.Input( @@ -85,7 +85,7 @@ class LumaConceptsNode(IO.ComfyNode): return IO.Schema( node_id="LumaConceptsNode", display_name="Luma Concepts", - 
category="video/partner/Luma", + category="partner/video/Luma", description="Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes.", inputs=[ IO.Combo.Input( @@ -134,7 +134,7 @@ class LumaImageGenerationNode(IO.ComfyNode): return IO.Schema( node_id="LumaImageNode", display_name="Luma Text to Image", - category="image/partner/Luma", + category="partner/image/Luma", description="Generates images synchronously based on prompt and aspect ratio.", inputs=[ IO.String.Input( @@ -278,7 +278,7 @@ class LumaImageModifyNode(IO.ComfyNode): return IO.Schema( node_id="LumaImageModifyNode", display_name="Luma Image to Image", - category="image/partner/Luma", + category="partner/image/Luma", description="Modifies images synchronously based on prompt and aspect ratio.", inputs=[ IO.Image.Input( @@ -371,7 +371,7 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode): return IO.Schema( node_id="LumaVideoNode", display_name="Luma Text to Video", - category="video/partner/Luma", + category="partner/video/Luma", description="Generates videos synchronously based on prompt and output_size.", inputs=[ IO.String.Input( @@ -472,7 +472,7 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode): return IO.Schema( node_id="LumaImageToVideoNode", display_name="Luma Image to Video", - category="video/partner/Luma", + category="partner/video/Luma", description="Generates videos synchronously based on prompt, input images, and output_size.", inputs=[ IO.String.Input( @@ -724,7 +724,7 @@ class LumaImageNode(IO.ComfyNode): return IO.Schema( node_id="LumaImageNode2", display_name="Luma UNI-1 Image", - category="image/partner/Luma", + category="partner/image/Luma", description="Generate images from text using the Luma UNI-1 model.", inputs=[ IO.String.Input( @@ -853,7 +853,7 @@ class LumaImageEditNode(IO.ComfyNode): return IO.Schema( node_id="LumaImageEditNode2", display_name="Luma UNI-1 Image Edit", - category="image/partner/Luma", + category="partner/image/Luma", description="Edit 
an existing image with a text prompt using the Luma UNI-1 model.", inputs=[ IO.Image.Input( diff --git a/comfy_api_nodes/nodes_magnific.py b/comfy_api_nodes/nodes_magnific.py index a6aeb194a..4ce4735df 100644 --- a/comfy_api_nodes/nodes_magnific.py +++ b/comfy_api_nodes/nodes_magnific.py @@ -61,7 +61,7 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode): return IO.Schema( node_id="MagnificImageUpscalerCreativeNode", display_name="Magnific Image Upscale (Creative)", - category="image/partner/Magnific", + category="partner/image/Magnific", description="Prompt‑guided enhancement, stylization, and 2x/4x/8x/16x upscaling. " "Maximum output: 25.3 megapixels.", inputs=[ @@ -240,7 +240,7 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode): return IO.Schema( node_id="MagnificImageUpscalerPreciseV2Node", display_name="Magnific Image Upscale (Precise V2)", - category="image/partner/Magnific", + category="partner/image/Magnific", description="High-fidelity upscaling with fine control over sharpness, grain, and detail. 
" "Maximum output: 10060×10060 pixels.", inputs=[ @@ -400,7 +400,7 @@ class MagnificImageStyleTransferNode(IO.ComfyNode): return IO.Schema( node_id="MagnificImageStyleTransferNode", display_name="Magnific Image Style Transfer", - category="image/partner/Magnific", + category="partner/image/Magnific", description="Transfer the style from a reference image to your input image.", inputs=[ IO.Image.Input("image", tooltip="The image to apply style transfer to."), @@ -549,7 +549,7 @@ class MagnificImageRelightNode(IO.ComfyNode): return IO.Schema( node_id="MagnificImageRelightNode", display_name="Magnific Image Relight", - category="image/partner/Magnific", + category="partner/image/Magnific", description="Relight an image with lighting adjustments and optional reference-based light transfer.", inputs=[ IO.Image.Input("image", tooltip="The image to relight."), @@ -789,7 +789,7 @@ class MagnificImageSkinEnhancerNode(IO.ComfyNode): return IO.Schema( node_id="MagnificImageSkinEnhancerNode", display_name="Magnific Image Skin Enhancer", - category="image/partner/Magnific", + category="partner/image/Magnific", description="Skin enhancement for portraits with multiple processing modes.", inputs=[ IO.Image.Input("image", tooltip="The portrait image to enhance."), diff --git a/comfy_api_nodes/nodes_meshy.py b/comfy_api_nodes/nodes_meshy.py index 4fb670404..3a24f1095 100644 --- a/comfy_api_nodes/nodes_meshy.py +++ b/comfy_api_nodes/nodes_meshy.py @@ -33,7 +33,7 @@ class MeshyTextToModelNode(IO.ComfyNode): return IO.Schema( node_id="MeshyTextToModelNode", display_name="Meshy: Text to Model", - category="3d/partner/Meshy", + category="partner/3d/Meshy", inputs=[ IO.Combo.Input("model", options=["latest"]), IO.String.Input("prompt", multiline=True, default=""), @@ -145,7 +145,7 @@ class MeshyRefineNode(IO.ComfyNode): return IO.Schema( node_id="MeshyRefineNode", display_name="Meshy: Refine Draft Model", - category="3d/partner/Meshy", + category="partner/3d/Meshy", description="Refine a 
previously created draft model.", inputs=[ IO.Combo.Input("model", options=["latest"]), @@ -240,7 +240,7 @@ class MeshyImageToModelNode(IO.ComfyNode): return IO.Schema( node_id="MeshyImageToModelNode", display_name="Meshy: Image to Model", - category="3d/partner/Meshy", + category="partner/3d/Meshy", inputs=[ IO.Combo.Input("model", options=["latest"]), IO.Image.Input("image"), @@ -405,7 +405,7 @@ class MeshyMultiImageToModelNode(IO.ComfyNode): return IO.Schema( node_id="MeshyMultiImageToModelNode", display_name="Meshy: Multi-Image to Model", - category="3d/partner/Meshy", + category="partner/3d/Meshy", inputs=[ IO.Combo.Input("model", options=["latest"]), IO.Autogrow.Input( @@ -575,7 +575,7 @@ class MeshyRigModelNode(IO.ComfyNode): return IO.Schema( node_id="MeshyRigModelNode", display_name="Meshy: Rig Model", - category="3d/partner/Meshy", + category="partner/3d/Meshy", description="Provides a rigged character in standard formats. " "Auto-rigging is currently not suitable for untextured meshes, non-humanoid assets, " "or humanoid assets with unclear limb and body structure.", @@ -656,7 +656,7 @@ class MeshyAnimateModelNode(IO.ComfyNode): return IO.Schema( node_id="MeshyAnimateModelNode", display_name="Meshy: Animate Model", - category="3d/partner/Meshy", + category="partner/3d/Meshy", description="Apply a specific animation action to a previously rigged character.", inputs=[ IO.Custom("MESHY_RIGGED_TASK_ID").Input("rig_task_id"), @@ -722,7 +722,7 @@ class MeshyTextureNode(IO.ComfyNode): return IO.Schema( node_id="MeshyTextureNode", display_name="Meshy: Texture Model", - category="3d/partner/Meshy", + category="partner/3d/Meshy", inputs=[ IO.Combo.Input("model", options=["latest"]), IO.Custom("MESHY_TASK_ID").Input("meshy_task_id"), diff --git a/comfy_api_nodes/nodes_minimax.py b/comfy_api_nodes/nodes_minimax.py index 338584148..6250af146 100644 --- a/comfy_api_nodes/nodes_minimax.py +++ b/comfy_api_nodes/nodes_minimax.py @@ -101,7 +101,7 @@ class 
MinimaxTextToVideoNode(IO.ComfyNode): return IO.Schema( node_id="MinimaxTextToVideoNode", display_name="MiniMax Text to Video", - category="video/partner/MiniMax", + category="partner/video/MiniMax", description="Generates videos synchronously based on a prompt, and optional parameters.", inputs=[ IO.String.Input( @@ -163,7 +163,7 @@ class MinimaxImageToVideoNode(IO.ComfyNode): return IO.Schema( node_id="MinimaxImageToVideoNode", display_name="MiniMax Image to Video", - category="video/partner/MiniMax", + category="partner/video/MiniMax", description="Generates videos synchronously based on an image and prompt, and optional parameters.", inputs=[ IO.Image.Input( @@ -230,7 +230,7 @@ class MinimaxSubjectToVideoNode(IO.ComfyNode): return IO.Schema( node_id="MinimaxSubjectToVideoNode", display_name="MiniMax Subject to Video", - category="video/partner/MiniMax", + category="partner/video/MiniMax", description="Generates videos synchronously based on an image and prompt, and optional parameters.", inputs=[ IO.Image.Input( @@ -294,7 +294,7 @@ class MinimaxHailuoVideoNode(IO.ComfyNode): return IO.Schema( node_id="MinimaxHailuoVideoNode", display_name="MiniMax Hailuo Video", - category="video/partner/MiniMax", + category="partner/video/MiniMax", description="Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model.", inputs=[ IO.String.Input( diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index 48c739dfe..0fe5fb9d0 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -99,7 +99,7 @@ class OpenAIDalle2(IO.ComfyNode): return IO.Schema( node_id="OpenAIDalle2", display_name="OpenAI DALL·E 2", - category="image/partner/OpenAI", + category="partner/image/OpenAI", description="Generates images synchronously via OpenAI's DALL·E 2 endpoint.", inputs=[ IO.String.Input( @@ -249,7 +249,7 @@ class OpenAIDalle3(IO.ComfyNode): return IO.Schema( node_id="OpenAIDalle3", display_name="OpenAI 
DALL·E 3", - category="image/partner/OpenAI", + category="partner/image/OpenAI", description="Generates images synchronously via OpenAI's DALL·E 3 endpoint.", inputs=[ IO.String.Input( @@ -371,7 +371,7 @@ class OpenAIGPTImage1(IO.ComfyNode): return IO.Schema( node_id="OpenAIGPTImage1", display_name="OpenAI GPT Image 2", - category="image/partner/OpenAI", + category="partner/image/OpenAI", description="Generates images synchronously via OpenAI's GPT Image endpoint.", is_deprecated=True, inputs=[ @@ -695,7 +695,7 @@ class OpenAIGPTImageNodeV2(IO.ComfyNode): return IO.Schema( node_id="OpenAIGPTImageNodeV2", display_name="OpenAI GPT Image 2", - category="image/partner/OpenAI", + category="partner/image/OpenAI", description="Generates images via OpenAI's GPT Image endpoint.", inputs=[ IO.String.Input( @@ -962,7 +962,7 @@ class OpenAIChatNode(IO.ComfyNode): return IO.Schema( node_id="OpenAIChatNode", display_name="OpenAI ChatGPT", - category="text/partner/OpenAI", + category="partner/text/OpenAI", essentials_category="Text Generation", description="Generate text responses from an OpenAI model.", inputs=[ @@ -1201,7 +1201,7 @@ class OpenAIInputFiles(IO.ComfyNode): return IO.Schema( node_id="OpenAIInputFiles", display_name="OpenAI ChatGPT Input Files", - category="text/partner/OpenAI", + category="partner/text/OpenAI", description="Loads and prepares input files (text, pdf, etc.) to include as inputs for the OpenAI Chat Node. The files will be read by the OpenAI model when generating a response. 
🛈 TIP: Can be chained together with other OpenAI Input File nodes.", inputs=[ IO.Combo.Input( @@ -1248,7 +1248,7 @@ class OpenAIChatConfig(IO.ComfyNode): return IO.Schema( node_id="OpenAIChatConfig", display_name="OpenAI ChatGPT Advanced Options", - category="text/partner/OpenAI", + category="partner/text/OpenAI", description="Allows specifying advanced configuration options for the OpenAI Chat Nodes.", inputs=[ IO.Combo.Input( diff --git a/comfy_api_nodes/nodes_openrouter.py b/comfy_api_nodes/nodes_openrouter.py index d2ebbef0d..ba98133f0 100644 --- a/comfy_api_nodes/nodes_openrouter.py +++ b/comfy_api_nodes/nodes_openrouter.py @@ -265,7 +265,7 @@ class OpenRouterLLMNode(IO.ComfyNode): return IO.Schema( node_id="OpenRouterLLMNode", display_name="OpenRouter LLM", - category="text/partner/OpenRouter", + category="partner/text/OpenRouter", essentials_category="Text Generation", description=( "Generate text responses through OpenRouter. Routes to a curated set of popular " diff --git a/comfy_api_nodes/nodes_pixverse.py b/comfy_api_nodes/nodes_pixverse.py index 3861cfedd..4c8b723b9 100644 --- a/comfy_api_nodes/nodes_pixverse.py +++ b/comfy_api_nodes/nodes_pixverse.py @@ -53,7 +53,7 @@ class PixverseTemplateNode(IO.ComfyNode): return IO.Schema( node_id="PixverseTemplateNode", display_name="PixVerse Template", - category="video/partner/PixVerse", + category="partner/video/PixVerse", inputs=[ IO.Combo.Input("template", options=list(pixverse_templates.keys())), ], @@ -74,7 +74,7 @@ class PixverseTextToVideoNode(IO.ComfyNode): return IO.Schema( node_id="PixverseTextToVideoNode", display_name="PixVerse Text to Video", - category="video/partner/PixVerse", + category="partner/video/PixVerse", description="Generates videos based on prompt and output_size.", inputs=[ IO.String.Input( @@ -192,7 +192,7 @@ class PixverseImageToVideoNode(IO.ComfyNode): return IO.Schema( node_id="PixverseImageToVideoNode", display_name="PixVerse Image to Video", - category="video/partner/PixVerse", + 
category="partner/video/PixVerse", description="Generates videos based on prompt and output_size.", inputs=[ IO.Image.Input("image"), @@ -310,7 +310,7 @@ class PixverseTransitionVideoNode(IO.ComfyNode): return IO.Schema( node_id="PixverseTransitionVideoNode", display_name="PixVerse Transition Video", - category="video/partner/PixVerse", + category="partner/video/PixVerse", description="Generates videos based on prompt and output_size.", inputs=[ IO.Image.Input("first_frame"), diff --git a/comfy_api_nodes/nodes_quiver.py b/comfy_api_nodes/nodes_quiver.py index ad045a7ef..34929fa0c 100644 --- a/comfy_api_nodes/nodes_quiver.py +++ b/comfy_api_nodes/nodes_quiver.py @@ -62,7 +62,7 @@ class QuiverTextToSVGNode(IO.ComfyNode): return IO.Schema( node_id="QuiverTextToSVGNode", display_name="Quiver Text to SVG", - category="image/partner/Quiver", + category="partner/image/Quiver", description="Generate an SVG from a text prompt using Quiver AI.", inputs=[ IO.String.Input( @@ -177,7 +177,7 @@ class QuiverImageToSVGNode(IO.ComfyNode): return IO.Schema( node_id="QuiverImageToSVGNode", display_name="Quiver Image to SVG", - category="image/partner/Quiver", + category="partner/image/Quiver", description="Vectorize a raster image into SVG using Quiver AI.", inputs=[ IO.Image.Input( diff --git a/comfy_api_nodes/nodes_recraft.py b/comfy_api_nodes/nodes_recraft.py index 07387821d..c44942f50 100644 --- a/comfy_api_nodes/nodes_recraft.py +++ b/comfy_api_nodes/nodes_recraft.py @@ -178,7 +178,7 @@ class RecraftColorRGBNode(IO.ComfyNode): return IO.Schema( node_id="RecraftColorRGB", display_name="Recraft Color RGB", - category="image/partner/Recraft", + category="partner/image/Recraft", description="Create Recraft Color by choosing specific RGB values.", inputs=[ IO.Int.Input("r", default=0, min=0, max=255, tooltip="Red value of color."), @@ -204,7 +204,7 @@ class RecraftControlsNode(IO.ComfyNode): return IO.Schema( node_id="RecraftControls", display_name="Recraft Controls", - 
category="image/partner/Recraft", + category="partner/image/Recraft", description="Create Recraft Controls for customizing Recraft generation.", inputs=[ IO.Custom(RecraftIO.COLOR).Input("colors", optional=True), @@ -228,7 +228,7 @@ class RecraftStyleV3RealisticImageNode(IO.ComfyNode): return IO.Schema( node_id="RecraftStyleV3RealisticImage", display_name="Recraft Style - Realistic Image", - category="image/partner/Recraft", + category="partner/image/Recraft", description="Select realistic_image style and optional substyle.", inputs=[ IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE)), @@ -253,7 +253,7 @@ class RecraftStyleV3DigitalIllustrationNode(RecraftStyleV3RealisticImageNode): return IO.Schema( node_id="RecraftStyleV3DigitalIllustration", display_name="Recraft Style - Digital Illustration", - category="image/partner/Recraft", + category="partner/image/Recraft", description="Select realistic_image style and optional substyle.", inputs=[ IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE)), @@ -272,7 +272,7 @@ class RecraftStyleV3VectorIllustrationNode(RecraftStyleV3RealisticImageNode): return IO.Schema( node_id="RecraftStyleV3VectorIllustrationNode", display_name="Recraft Style - Realistic Image", - category="image/partner/Recraft", + category="partner/image/Recraft", description="Select realistic_image style and optional substyle.", inputs=[ IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE)), @@ -291,7 +291,7 @@ class RecraftStyleV3LogoRasterNode(RecraftStyleV3RealisticImageNode): return IO.Schema( node_id="RecraftStyleV3LogoRaster", display_name="Recraft Style - Logo Raster", - category="image/partner/Recraft", + category="partner/image/Recraft", description="Select realistic_image style and optional substyle.", inputs=[ IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE, include_none=False)), @@ -308,7 +308,7 @@ class RecraftStyleInfiniteStyleLibrary(IO.ComfyNode): return 
IO.Schema( node_id="RecraftStyleV3InfiniteStyleLibrary", display_name="Recraft Style - Infinite Style Library", - category="image/partner/Recraft", + category="partner/image/Recraft", description="Choose style based on preexisting UUID from Recraft's Infinite Style Library.", inputs=[ IO.String.Input("style_id", default="", tooltip="UUID of style from Infinite Style Library."), @@ -331,7 +331,7 @@ class RecraftCreateStyleNode(IO.ComfyNode): return IO.Schema( node_id="RecraftCreateStyleNode", display_name="Recraft Create Style", - category="image/partner/Recraft", + category="partner/image/Recraft", description="Create a custom style from reference images. " "Upload 1-5 images to use as style references. " "Total size of all images is limited to 5 MB.", @@ -400,7 +400,7 @@ class RecraftTextToImageNode(IO.ComfyNode): return IO.Schema( node_id="RecraftTextToImageNode", display_name="Recraft Text to Image", - category="image/partner/Recraft", + category="partner/image/Recraft", description="Generates images synchronously based on prompt and resolution.", inputs=[ IO.String.Input("prompt", multiline=True, default="", tooltip="Prompt for the image generation."), @@ -512,7 +512,7 @@ class RecraftImageToImageNode(IO.ComfyNode): return IO.Schema( node_id="RecraftImageToImageNode", display_name="Recraft Image to Image", - category="image/partner/Recraft", + category="partner/image/Recraft", description="Modify image based on prompt and strength.", inputs=[ IO.Image.Input("image"), @@ -630,7 +630,7 @@ class RecraftImageInpaintingNode(IO.ComfyNode): return IO.Schema( node_id="RecraftImageInpaintingNode", display_name="Recraft Image Inpainting", - category="image/partner/Recraft", + category="partner/image/Recraft", description="Modify image based on prompt and mask.", inputs=[ IO.Image.Input("image"), @@ -732,7 +732,7 @@ class RecraftTextToVectorNode(IO.ComfyNode): return IO.Schema( node_id="RecraftTextToVectorNode", display_name="Recraft Text to Vector", - 
category="image/partner/Recraft", + category="partner/image/Recraft", description="Generates SVG synchronously based on prompt and resolution.", inputs=[ IO.String.Input("prompt", default="", tooltip="Prompt for the image generation.", multiline=True), @@ -832,7 +832,7 @@ class RecraftVectorizeImageNode(IO.ComfyNode): return IO.Schema( node_id="RecraftVectorizeImageNode", display_name="Recraft Vectorize Image", - category="image/partner/Recraft", + category="partner/image/Recraft", essentials_category="Image Tools", description="Generates SVG synchronously from an input image.", inputs=[ @@ -876,7 +876,7 @@ class RecraftReplaceBackgroundNode(IO.ComfyNode): return IO.Schema( node_id="RecraftReplaceBackgroundNode", display_name="Recraft Replace Background", - category="image/partner/Recraft", + category="partner/image/Recraft", description="Replace background on image, based on provided prompt.", inputs=[ IO.Image.Input("image"), @@ -963,7 +963,7 @@ class RecraftRemoveBackgroundNode(IO.ComfyNode): return IO.Schema( node_id="RecraftRemoveBackgroundNode", display_name="Recraft Remove Background", - category="image/partner/Recraft", + category="partner/image/Recraft", essentials_category="Image Tools", description="Remove background from image, and return processed image and mask.", inputs=[ @@ -1012,7 +1012,7 @@ class RecraftCrispUpscaleNode(IO.ComfyNode): return IO.Schema( node_id="RecraftCrispUpscaleNode", display_name="Recraft Crisp Upscale Image", - category="image/partner/Recraft", + category="partner/image/Recraft", description="Upscale image synchronously.\n" "Enhances a given raster image using ‘crisp upscale’ tool, " "increasing image resolution, making the image sharper and cleaner.", @@ -1058,7 +1058,7 @@ class RecraftCreativeUpscaleNode(RecraftCrispUpscaleNode): return IO.Schema( node_id="RecraftCreativeUpscaleNode", display_name="Recraft Creative Upscale Image", - category="image/partner/Recraft", + category="partner/image/Recraft", description="Upscale 
image synchronously.\n" "Enhances a given raster image using ‘creative upscale’ tool, " "boosting resolution with a focus on refining small details and faces.", @@ -1086,7 +1086,7 @@ class RecraftV4TextToImageNode(IO.ComfyNode): return IO.Schema( node_id="RecraftV4TextToImageNode", display_name="Recraft V4 Text to Image", - category="image/partner/Recraft", + category="partner/image/Recraft", description="Generates images using Recraft V4 or V4 Pro models.", inputs=[ IO.String.Input( @@ -1210,7 +1210,7 @@ class RecraftV4TextToVectorNode(IO.ComfyNode): return IO.Schema( node_id="RecraftV4TextToVectorNode", display_name="Recraft V4 Text to Vector", - category="image/partner/Recraft", + category="partner/image/Recraft", description="Generates SVG using Recraft V4 or V4 Pro models.", inputs=[ IO.String.Input( diff --git a/comfy_api_nodes/nodes_reve.py b/comfy_api_nodes/nodes_reve.py index 2b15eadd7..177349a8b 100644 --- a/comfy_api_nodes/nodes_reve.py +++ b/comfy_api_nodes/nodes_reve.py @@ -109,7 +109,7 @@ class ReveImageCreateNode(IO.ComfyNode): return IO.Schema( node_id="ReveImageCreateNode", display_name="Reve Image Create", - category="image/partner/Reve", + category="partner/image/Reve", description="Generate images from text descriptions using Reve.", inputs=[ IO.String.Input( @@ -200,7 +200,7 @@ class ReveImageEditNode(IO.ComfyNode): return IO.Schema( node_id="ReveImageEditNode", display_name="Reve Image Edit", - category="image/partner/Reve", + category="partner/image/Reve", description="Edit images using natural language instructions with Reve.", inputs=[ IO.Image.Input("image", tooltip="The image to edit."), @@ -300,7 +300,7 @@ class ReveImageRemixNode(IO.ComfyNode): return IO.Schema( node_id="ReveImageRemixNode", display_name="Reve Image Remix", - category="image/partner/Reve", + category="partner/image/Reve", description="Combine reference images with text prompts to create new images using Reve.", inputs=[ IO.Autogrow.Input( diff --git 
a/comfy_api_nodes/nodes_rodin.py b/comfy_api_nodes/nodes_rodin.py index e14955661..0375a2123 100644 --- a/comfy_api_nodes/nodes_rodin.py +++ b/comfy_api_nodes/nodes_rodin.py @@ -230,7 +230,7 @@ class Rodin3D_Regular(IO.ComfyNode): return IO.Schema( node_id="Rodin3D_Regular", display_name="Rodin 3D Generate - Regular Generate", - category="3d/partner/Rodin", + category="partner/3d/Rodin", description=cleandoc(cls.__doc__ or ""), inputs=[ IO.Image.Input("Images"), @@ -289,7 +289,7 @@ class Rodin3D_Detail(IO.ComfyNode): return IO.Schema( node_id="Rodin3D_Detail", display_name="Rodin 3D Generate - Detail Generate", - category="3d/partner/Rodin", + category="partner/3d/Rodin", description=cleandoc(cls.__doc__ or ""), inputs=[ IO.Image.Input("Images"), @@ -348,7 +348,7 @@ class Rodin3D_Smooth(IO.ComfyNode): return IO.Schema( node_id="Rodin3D_Smooth", display_name="Rodin 3D Generate - Smooth Generate", - category="3d/partner/Rodin", + category="partner/3d/Rodin", description=cleandoc(cls.__doc__ or ""), inputs=[ IO.Image.Input("Images"), @@ -406,7 +406,7 @@ class Rodin3D_Sketch(IO.ComfyNode): return IO.Schema( node_id="Rodin3D_Sketch", display_name="Rodin 3D Generate - Sketch Generate", - category="3d/partner/Rodin", + category="partner/3d/Rodin", description=cleandoc(cls.__doc__ or ""), inputs=[ IO.Image.Input("Images"), @@ -468,7 +468,7 @@ class Rodin3D_Gen2(IO.ComfyNode): return IO.Schema( node_id="Rodin3D_Gen2", display_name="Rodin 3D Generate - Gen-2 Generate", - category="3d/partner/Rodin", + category="partner/3d/Rodin", description=cleandoc(cls.__doc__ or ""), inputs=[ IO.Image.Input("Images"), @@ -941,7 +941,7 @@ class Rodin3D_Gen25_Image(IO.ComfyNode): return IO.Schema( node_id="Rodin3D_Gen25_Image", display_name="Rodin 3D Gen-2.5 - Image to 3D", - category="3d/partner/Rodin", + category="partner/3d/Rodin", description=( "Generate a 3D model from 1-5 reference images via Rodin Gen-2.5. " "Pick a mode (Fast / Regular / Extreme-High) to tune quality vs. cost." 
@@ -1035,7 +1035,7 @@ class Rodin3D_Gen25_Text(IO.ComfyNode): return IO.Schema( node_id="Rodin3D_Gen25_Text", display_name="Rodin 3D Gen-2.5 - Text to 3D", - category="3d/partner/Rodin", + category="partner/3d/Rodin", description=( "Generate a 3D model from a text prompt via Rodin Gen-2.5. " "Pick a mode (Fast / Regular / Extreme-High) to tune quality vs. cost." diff --git a/comfy_api_nodes/nodes_runway.py b/comfy_api_nodes/nodes_runway.py index 7357c733e..b9c5c81a1 100644 --- a/comfy_api_nodes/nodes_runway.py +++ b/comfy_api_nodes/nodes_runway.py @@ -140,7 +140,7 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode): return IO.Schema( node_id="RunwayImageToVideoNodeGen3a", display_name="Runway Image to Video (Gen3a Turbo)", - category="video/partner/Runway", + category="partner/video/Runway", description="Generate a video from a single starting frame using Gen3a Turbo model. " "Before diving in, review these best practices to ensure that " "your input selections will set your generation up for success: " @@ -234,7 +234,7 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode): return IO.Schema( node_id="RunwayImageToVideoNodeGen4", display_name="Runway Image to Video (Gen4 Turbo)", - category="video/partner/Runway", + category="partner/video/Runway", description="Generate a video from a single starting frame using Gen4 Turbo model. " "Before diving in, review these best practices to ensure that " "your input selections will set your generation up for success: " @@ -329,7 +329,7 @@ class RunwayFirstLastFrameNode(IO.ComfyNode): return IO.Schema( node_id="RunwayFirstLastFrameNode", display_name="Runway First-Last-Frame to Video", - category="video/partner/Runway", + category="partner/video/Runway", description="Upload first and last keyframes, draft a prompt, and generate a video. " "More complex transitions, such as cases where the Last frame is completely different " "from the First frame, may benefit from the longer 10s duration. 
" @@ -440,7 +440,7 @@ class RunwayTextToImageNode(IO.ComfyNode): return IO.Schema( node_id="RunwayTextToImageNode", display_name="Runway Text to Image", - category="image/partner/Runway", + category="partner/image/Runway", description="Generate an image from a text prompt using Runway's Gen 4 model. " "You can also include reference image to guide the generation.", inputs=[ diff --git a/comfy_api_nodes/nodes_sonilo.py b/comfy_api_nodes/nodes_sonilo.py index bc31a0074..9ce896ed0 100644 --- a/comfy_api_nodes/nodes_sonilo.py +++ b/comfy_api_nodes/nodes_sonilo.py @@ -34,7 +34,7 @@ class SoniloVideoToMusic(IO.ComfyNode): return IO.Schema( node_id="SoniloVideoToMusic", display_name="Sonilo Video to Music", - category="audio/partner/Sonilo", + category="partner/audio/Sonilo", description="Generate music from video content using Sonilo's AI model. " "Analyzes the video and creates matching music.", inputs=[ @@ -99,7 +99,7 @@ class SoniloTextToMusic(IO.ComfyNode): return IO.Schema( node_id="SoniloTextToMusic", display_name="Sonilo Text to Music", - category="audio/partner/Sonilo", + category="partner/audio/Sonilo", description="Generate music from a text prompt using Sonilo's AI model. " "Leave duration at 0 to let the model infer it from the prompt.", inputs=[ diff --git a/comfy_api_nodes/nodes_sora.py b/comfy_api_nodes/nodes_sora.py index 83cfca495..4ff1d649f 100644 --- a/comfy_api_nodes/nodes_sora.py +++ b/comfy_api_nodes/nodes_sora.py @@ -34,7 +34,7 @@ class OpenAIVideoSora2(IO.ComfyNode): return IO.Schema( node_id="OpenAIVideoSora2", display_name="OpenAI Sora - Video (DEPRECATED)", - category="video/partner/Sora", + category="partner/video/Sora", description=( "OpenAI video and audio generation.\n\n" "DEPRECATION NOTICE: OpenAI will stop serving the Sora v2 API in September 2026. 
" diff --git a/comfy_api_nodes/nodes_stability.py b/comfy_api_nodes/nodes_stability.py index a1753d647..9eaba173b 100644 --- a/comfy_api_nodes/nodes_stability.py +++ b/comfy_api_nodes/nodes_stability.py @@ -62,7 +62,7 @@ class StabilityStableImageUltraNode(IO.ComfyNode): return IO.Schema( node_id="StabilityStableImageUltraNode", display_name="Stability AI Stable Image Ultra", - category="image/partner/Stability AI", + category="partner/image/Stability AI", description=cleandoc(cls.__doc__ or ""), inputs=[ IO.String.Input( @@ -197,7 +197,7 @@ class StabilityStableImageSD_3_5Node(IO.ComfyNode): return IO.Schema( node_id="StabilityStableImageSD_3_5Node", display_name="Stability AI Stable Diffusion 3.5 Image", - category="image/partner/Stability AI", + category="partner/image/Stability AI", description=cleandoc(cls.__doc__ or ""), inputs=[ IO.String.Input( @@ -354,7 +354,7 @@ class StabilityUpscaleConservativeNode(IO.ComfyNode): return IO.Schema( node_id="StabilityUpscaleConservativeNode", display_name="Stability AI Upscale Conservative", - category="image/partner/Stability AI", + category="partner/image/Stability AI", description=cleandoc(cls.__doc__ or ""), inputs=[ IO.Image.Input("image"), @@ -457,7 +457,7 @@ class StabilityUpscaleCreativeNode(IO.ComfyNode): return IO.Schema( node_id="StabilityUpscaleCreativeNode", display_name="Stability AI Upscale Creative", - category="image/partner/Stability AI", + category="partner/image/Stability AI", description=cleandoc(cls.__doc__ or ""), inputs=[ IO.Image.Input("image"), @@ -578,7 +578,7 @@ class StabilityUpscaleFastNode(IO.ComfyNode): return IO.Schema( node_id="StabilityUpscaleFastNode", display_name="Stability AI Upscale Fast", - category="image/partner/Stability AI", + category="partner/image/Stability AI", description=cleandoc(cls.__doc__ or ""), inputs=[ IO.Image.Input("image"), @@ -630,7 +630,7 @@ class StabilityTextToAudio(IO.ComfyNode): return IO.Schema( node_id="StabilityTextToAudio", display_name="Stability AI 
Text To Audio", - category="audio/partner/Stability AI", + category="partner/audio/Stability AI", essentials_category="Audio", description=cleandoc(cls.__doc__ or ""), inputs=[ @@ -708,7 +708,7 @@ class StabilityAudioToAudio(IO.ComfyNode): return IO.Schema( node_id="StabilityAudioToAudio", display_name="Stability AI Audio To Audio", - category="audio/partner/Stability AI", + category="partner/audio/Stability AI", description=cleandoc(cls.__doc__ or ""), inputs=[ IO.Combo.Input( @@ -802,7 +802,7 @@ class StabilityAudioInpaint(IO.ComfyNode): return IO.Schema( node_id="StabilityAudioInpaint", display_name="Stability AI Audio Inpaint", - category="audio/partner/Stability AI", + category="partner/audio/Stability AI", description=cleandoc(cls.__doc__ or ""), inputs=[ IO.Combo.Input( diff --git a/comfy_api_nodes/nodes_topaz.py b/comfy_api_nodes/nodes_topaz.py index d0906ee44..f7ef4cbf6 100644 --- a/comfy_api_nodes/nodes_topaz.py +++ b/comfy_api_nodes/nodes_topaz.py @@ -52,7 +52,7 @@ class TopazImageEnhance(IO.ComfyNode): return IO.Schema( node_id="TopazImageEnhance", display_name="Topaz Image Enhance", - category="image/partner/Topaz", + category="partner/image/Topaz", description="Industry-standard upscaling and image enhancement.", inputs=[ IO.Combo.Input("model", options=["Reimagine"]), @@ -235,7 +235,7 @@ class TopazVideoEnhance(IO.ComfyNode): return IO.Schema( node_id="TopazVideoEnhance", display_name="Topaz Video Enhance (Legacy)", - category="video/partner/Topaz", + category="partner/video/Topaz", description="Breathe new life into video with powerful upscaling and recovery technology.", inputs=[ IO.Video.Input("video"), @@ -475,7 +475,7 @@ class TopazVideoEnhanceV2(IO.ComfyNode): return IO.Schema( node_id="TopazVideoEnhanceV2", display_name="Topaz Video Enhance", - category="video/partner/Topaz", + category="partner/video/Topaz", description="Breathe new life into video with powerful upscaling and recovery technology.", inputs=[ IO.Video.Input("video"), diff --git 
a/comfy_api_nodes/nodes_tripo.py b/comfy_api_nodes/nodes_tripo.py index 4820e26c1..a3f2cb053 100644 --- a/comfy_api_nodes/nodes_tripo.py +++ b/comfy_api_nodes/nodes_tripo.py @@ -83,7 +83,7 @@ class TripoTextToModelNode(IO.ComfyNode): return IO.Schema( node_id="TripoTextToModelNode", display_name="Tripo: Text to Model", - category="3d/partner/Tripo", + category="partner/3d/Tripo", inputs=[ IO.String.Input("prompt", multiline=True), IO.String.Input("negative_prompt", multiline=True, optional=True), @@ -210,7 +210,7 @@ class TripoImageToModelNode(IO.ComfyNode): return IO.Schema( node_id="TripoImageToModelNode", display_name="Tripo: Image to Model", - category="3d/partner/Tripo", + category="partner/3d/Tripo", inputs=[ IO.Image.Input("image"), IO.Combo.Input( @@ -358,7 +358,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode): return IO.Schema( node_id="TripoMultiviewToModelNode", display_name="Tripo: Multiview to Model", - category="3d/partner/Tripo", + category="partner/3d/Tripo", inputs=[ IO.Image.Input("image"), IO.Image.Input("image_left", optional=True), @@ -518,7 +518,7 @@ class TripoTextureNode(IO.ComfyNode): return IO.Schema( node_id="TripoTextureNode", display_name="Tripo: Texture model", - category="3d/partner/Tripo", + category="partner/3d/Tripo", inputs=[ IO.Custom("MODEL_TASK_ID").Input("model_task_id"), IO.Boolean.Input("texture", default=True, optional=True), @@ -595,7 +595,7 @@ class TripoRefineNode(IO.ComfyNode): return IO.Schema( node_id="TripoRefineNode", display_name="Tripo: Refine Draft model", - category="3d/partner/Tripo", + category="partner/3d/Tripo", description="Refine a draft model created by v1.4 Tripo models only.", inputs=[ IO.Custom("MODEL_TASK_ID").Input("model_task_id", tooltip="Must be a v1.4 Tripo model"), @@ -635,7 +635,7 @@ class TripoRigNode(IO.ComfyNode): return IO.Schema( node_id="TripoRigNode", display_name="Tripo: Rig model", - category="3d/partner/Tripo", + category="partner/3d/Tripo", 
inputs=[IO.Custom("MODEL_TASK_ID").Input("original_model_task_id")], outputs=[ IO.String.Output(display_name="model_file"), # for backward compatibility only @@ -672,7 +672,7 @@ class TripoRetargetNode(IO.ComfyNode): return IO.Schema( node_id="TripoRetargetNode", display_name="Tripo: Retarget rigged model", - category="3d/partner/Tripo", + category="partner/3d/Tripo", inputs=[ IO.Custom("RIG_TASK_ID").Input("original_model_task_id"), IO.Combo.Input( @@ -737,7 +737,7 @@ class TripoConversionNode(IO.ComfyNode): return IO.Schema( node_id="TripoConversionNode", display_name="Tripo: Convert model", - category="3d/partner/Tripo", + category="partner/3d/Tripo", inputs=[ IO.Custom("MODEL_TASK_ID,RIG_TASK_ID,RETARGET_TASK_ID").Input("original_model_task_id"), IO.Combo.Input("format", options=["GLTF", "USDZ", "FBX", "OBJ", "STL", "3MF"]), @@ -1051,7 +1051,7 @@ class TripoP1TextToModelNode(IO.ComfyNode): return IO.Schema( node_id="TripoP1TextToModelNode", display_name="Tripo P1: Text to Model", - category="3d/partner/Tripo", + category="partner/3d/Tripo", description="Tripo P1 text-to-3D. Optimized for low-poly, game-ready meshes with stable topology.", inputs=[ IO.String.Input("prompt", multiline=True, tooltip="Up to 1024 characters."), @@ -1122,7 +1122,7 @@ class TripoP1ImageToModelNode(IO.ComfyNode): return IO.Schema( node_id="TripoP1ImageToModelNode", display_name="Tripo P1: Image to Model", - category="3d/partner/Tripo", + category="partner/3d/Tripo", description="Tripo P1 image-to-3D. Optimized for low-poly, game-ready meshes.", inputs=[ IO.Image.Input("image"), @@ -1202,7 +1202,7 @@ class TripoP1MultiviewToModelNode(IO.ComfyNode): return IO.Schema( node_id="TripoP1MultiviewToModelNode", display_name="Tripo P1: Multiview to Model", - category="3d/partner/Tripo", + category="partner/3d/Tripo", description="Tripo P1 multiview-to-3D from 2-4 reference images in [front, left, back, right] order. 
" "Front is required; any combination of the other three may be omitted.", inputs=[ diff --git a/comfy_api_nodes/nodes_veo2.py b/comfy_api_nodes/nodes_veo2.py index 068862397..ed34e928b 100644 --- a/comfy_api_nodes/nodes_veo2.py +++ b/comfy_api_nodes/nodes_veo2.py @@ -45,7 +45,7 @@ class VeoVideoGenerationNode(IO.ComfyNode): return IO.Schema( node_id="VeoVideoGenerationNode", display_name="Google Veo 2 Video Generation", - category="video/partner/Veo", + category="partner/video/Veo", description="Generates videos from text prompts using Google's Veo 2 API", inputs=[ IO.String.Input( @@ -256,7 +256,7 @@ class Veo3VideoGenerationNode(IO.ComfyNode): return IO.Schema( node_id="Veo3VideoGenerationNode", display_name="Google Veo 3 Video Generation", - category="video/partner/Veo", + category="partner/video/Veo", description="Generates videos from text prompts using Google's Veo 3 API", inputs=[ IO.String.Input( @@ -468,7 +468,7 @@ class Veo3FirstLastFrameNode(IO.ComfyNode): return IO.Schema( node_id="Veo3FirstLastFrameNode", display_name="Google Veo 3 First-Last-Frame to Video", - category="video/partner/Veo", + category="partner/video/Veo", description="Generate video using prompt and first and last frames.", inputs=[ IO.String.Input( diff --git a/comfy_api_nodes/nodes_vidu.py b/comfy_api_nodes/nodes_vidu.py index 16f6113de..8c5a43f5b 100644 --- a/comfy_api_nodes/nodes_vidu.py +++ b/comfy_api_nodes/nodes_vidu.py @@ -71,7 +71,7 @@ class ViduTextToVideoNode(IO.ComfyNode): return IO.Schema( node_id="ViduTextToVideoNode", display_name="Vidu Text To Video Generation", - category="video/partner/Vidu", + category="partner/video/Vidu", description="Generate video from a text prompt", inputs=[ IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"), @@ -169,7 +169,7 @@ class ViduImageToVideoNode(IO.ComfyNode): return IO.Schema( node_id="ViduImageToVideoNode", display_name="Vidu Image To Video Generation", - category="video/partner/Vidu", + 
category="partner/video/Vidu", description="Generate video from image and optional prompt", inputs=[ IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"), @@ -273,7 +273,7 @@ class ViduReferenceVideoNode(IO.ComfyNode): return IO.Schema( node_id="ViduReferenceVideoNode", display_name="Vidu Reference To Video Generation", - category="video/partner/Vidu", + category="partner/video/Vidu", description="Generate video from multiple images and a prompt", inputs=[ IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"), @@ -388,7 +388,7 @@ class ViduStartEndToVideoNode(IO.ComfyNode): return IO.Schema( node_id="ViduStartEndToVideoNode", display_name="Vidu Start End To Video Generation", - category="video/partner/Vidu", + category="partner/video/Vidu", description="Generate a video from start and end frames and a prompt", inputs=[ IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"), @@ -492,7 +492,7 @@ class Vidu2TextToVideoNode(IO.ComfyNode): return IO.Schema( node_id="Vidu2TextToVideoNode", display_name="Vidu2 Text-to-Video Generation", - category="video/partner/Vidu", + category="partner/video/Vidu", description="Generate video from a text prompt", inputs=[ IO.Combo.Input("model", options=["viduq2"]), @@ -584,7 +584,7 @@ class Vidu2ImageToVideoNode(IO.ComfyNode): return IO.Schema( node_id="Vidu2ImageToVideoNode", display_name="Vidu2 Image-to-Video Generation", - category="video/partner/Vidu", + category="partner/video/Vidu", description="Generate a video from an image and an optional prompt.", inputs=[ IO.Combo.Input("model", options=["viduq2-pro-fast", "viduq2-pro", "viduq2-turbo"]), @@ -714,7 +714,7 @@ class Vidu2ReferenceVideoNode(IO.ComfyNode): return IO.Schema( node_id="Vidu2ReferenceVideoNode", display_name="Vidu2 Reference-to-Video Generation", - category="video/partner/Vidu", + category="partner/video/Vidu", description="Generate a video from multiple reference images and a prompt.", inputs=[ IO.Combo.Input("model", 
options=["viduq2"]), @@ -849,7 +849,7 @@ class Vidu2StartEndToVideoNode(IO.ComfyNode): return IO.Schema( node_id="Vidu2StartEndToVideoNode", display_name="Vidu2 Start/End Frame-to-Video Generation", - category="video/partner/Vidu", + category="partner/video/Vidu", description="Generate a video from a start frame, an end frame, and a prompt.", inputs=[ IO.Combo.Input("model", options=["viduq2-pro-fast", "viduq2-pro", "viduq2-turbo"]), @@ -969,7 +969,7 @@ class ViduExtendVideoNode(IO.ComfyNode): return IO.Schema( node_id="ViduExtendVideoNode", display_name="Vidu Video Extension", - category="video/partner/Vidu", + category="partner/video/Vidu", description="Extend an existing video by generating additional frames.", inputs=[ IO.DynamicCombo.Input( @@ -1138,7 +1138,7 @@ class ViduMultiFrameVideoNode(IO.ComfyNode): return IO.Schema( node_id="ViduMultiFrameVideoNode", display_name="Vidu Multi-Frame Video Generation", - category="video/partner/Vidu", + category="partner/video/Vidu", description="Generate a video with multiple keyframe transitions.", inputs=[ IO.Combo.Input("model", options=["viduq2-pro", "viduq2-turbo"]), @@ -1284,7 +1284,7 @@ class Vidu3TextToVideoNode(IO.ComfyNode): return IO.Schema( node_id="Vidu3TextToVideoNode", display_name="Vidu Q3 Text-to-Video Generation", - category="video/partner/Vidu", + category="partner/video/Vidu", description="Generate video from a text prompt.", inputs=[ IO.DynamicCombo.Input( @@ -1429,7 +1429,7 @@ class Vidu3ImageToVideoNode(IO.ComfyNode): return IO.Schema( node_id="Vidu3ImageToVideoNode", display_name="Vidu Q3 Image-to-Video Generation", - category="video/partner/Vidu", + category="partner/video/Vidu", description="Generate a video from an image and an optional prompt.", inputs=[ IO.DynamicCombo.Input( @@ -1571,7 +1571,7 @@ class Vidu3StartEndToVideoNode(IO.ComfyNode): return IO.Schema( node_id="Vidu3StartEndToVideoNode", display_name="Vidu Q3 Start/End Frame-to-Video Generation", - category="video/partner/Vidu", + 
category="partner/video/Vidu", description="Generate a video from a start frame, an end frame, and a prompt.", inputs=[ IO.DynamicCombo.Input( diff --git a/comfy_api_nodes/nodes_wan.py b/comfy_api_nodes/nodes_wan.py index a235dc387..b7b97d70f 100644 --- a/comfy_api_nodes/nodes_wan.py +++ b/comfy_api_nodes/nodes_wan.py @@ -61,7 +61,7 @@ class WanTextToImageApi(IO.ComfyNode): return IO.Schema( node_id="WanTextToImageApi", display_name="Wan Text to Image", - category="image/partner/Wan", + category="partner/image/Wan", description="Generates an image based on a text prompt.", inputs=[ IO.Combo.Input( @@ -184,7 +184,7 @@ class WanImageToImageApi(IO.ComfyNode): return IO.Schema( node_id="WanImageToImageApi", display_name="Wan Image to Image", - category="image/partner/Wan", + category="partner/image/Wan", description="Generates an image from one or two input images and a text prompt. " "The output image is currently fixed at 1.6 MP, and its aspect ratio matches the input image(s).", inputs=[ @@ -312,7 +312,7 @@ class WanTextToVideoApi(IO.ComfyNode): return IO.Schema( node_id="WanTextToVideoApi", display_name="Wan Text to Video", - category="video/partner/Wan", + category="partner/video/Wan", description="Generates a video based on a text prompt.", inputs=[ IO.Combo.Input( @@ -495,7 +495,7 @@ class WanImageToVideoApi(IO.ComfyNode): return IO.Schema( node_id="WanImageToVideoApi", display_name="Wan Image to Video", - category="video/partner/Wan", + category="partner/video/Wan", description="Generates a video from the first frame and a text prompt.", inputs=[ IO.Combo.Input( @@ -674,7 +674,7 @@ class WanReferenceVideoApi(IO.ComfyNode): return IO.Schema( node_id="WanReferenceVideoApi", display_name="Wan Reference to Video", - category="video/partner/Wan", + category="partner/video/Wan", description="Use the character and voice from input videos, combined with a prompt, " "to generate a new video that maintains character consistency.", inputs=[ @@ -828,7 +828,7 @@ class 
Wan2TextToVideoApi(IO.ComfyNode): return IO.Schema( node_id="Wan2TextToVideoApi", display_name="Wan 2.7 Text to Video", - category="video/partner/Wan", + category="partner/video/Wan", description="Generates a video based on a text prompt using the Wan 2.7 model.", inputs=[ IO.DynamicCombo.Input( @@ -981,7 +981,7 @@ class Wan2ImageToVideoApi(IO.ComfyNode): return IO.Schema( node_id="Wan2ImageToVideoApi", display_name="Wan 2.7 Image to Video", - category="video/partner/Wan", + category="partner/video/Wan", description="Generate a video from a first-frame image, with optional last-frame image and audio.", inputs=[ IO.DynamicCombo.Input( @@ -1152,7 +1152,7 @@ class Wan2VideoContinuationApi(IO.ComfyNode): return IO.Schema( node_id="Wan2VideoContinuationApi", display_name="Wan 2.7 Video Continuation", - category="video/partner/Wan", + category="partner/video/Wan", description="Continue a video from where it left off, with optional last-frame control.", inputs=[ IO.DynamicCombo.Input( @@ -1319,7 +1319,7 @@ class Wan2VideoEditApi(IO.ComfyNode): return IO.Schema( node_id="Wan2VideoEditApi", display_name="Wan 2.7 Video Edit", - category="video/partner/Wan", + category="partner/video/Wan", description="Edit a video using text instructions, reference images, or style transfer.", inputs=[ IO.DynamicCombo.Input( @@ -1477,7 +1477,7 @@ class Wan2ReferenceVideoApi(IO.ComfyNode): return IO.Schema( node_id="Wan2ReferenceVideoApi", display_name="Wan 2.7 Reference to Video", - category="video/partner/Wan", + category="partner/video/Wan", description="Generate a video featuring a person or object from reference materials. 
" "Supports single-character performances and multi-character interactions.", inputs=[ @@ -1651,7 +1651,7 @@ class HappyHorseTextToVideoApi(IO.ComfyNode): return IO.Schema( node_id="HappyHorseTextToVideoApi", display_name="HappyHorse Text to Video", - category="video/partner/Wan", + category="partner/video/Wan", description="Generates a video based on a text prompt using the HappyHorse model.", inputs=[ IO.DynamicCombo.Input( @@ -1775,7 +1775,7 @@ class HappyHorseImageToVideoApi(IO.ComfyNode): return IO.Schema( node_id="HappyHorseImageToVideoApi", display_name="HappyHorse Image to Video", - category="video/partner/Wan", + category="partner/video/Wan", description="Generate a video from a first-frame image using the HappyHorse model.", inputs=[ IO.DynamicCombo.Input( @@ -1905,7 +1905,7 @@ class HappyHorseVideoEditApi(IO.ComfyNode): return IO.Schema( node_id="HappyHorseVideoEditApi", display_name="HappyHorse Video Edit", - category="video/partner/Wan", + category="partner/video/Wan", description="Edit a video using text instructions or reference images with the HappyHorse model. " "Output duration is 3-15s and matches the input video; inputs longer than 15s are truncated.", inputs=[ @@ -2046,7 +2046,7 @@ class HappyHorseReferenceVideoApi(IO.ComfyNode): return IO.Schema( node_id="HappyHorseReferenceVideoApi", display_name="HappyHorse Reference to Video", - category="video/partner/Wan", + category="partner/video/Wan", description="Generate a video featuring a person or object from reference materials with the HappyHorse " "model. 
Supports single-character performances and multi-character interactions.", inputs=[ diff --git a/comfy_api_nodes/nodes_wavespeed.py b/comfy_api_nodes/nodes_wavespeed.py index a250015c3..5839f9d37 100644 --- a/comfy_api_nodes/nodes_wavespeed.py +++ b/comfy_api_nodes/nodes_wavespeed.py @@ -27,7 +27,7 @@ class WavespeedFlashVSRNode(IO.ComfyNode): return IO.Schema( node_id="WavespeedFlashVSRNode", display_name="FlashVSR Video Upscale", - category="video/partner/WaveSpeed", + category="partner/video/WaveSpeed", description="Fast, high-quality video upscaler that " "boosts resolution and restores clarity for low-resolution or blurry footage.", inputs=[ @@ -98,7 +98,7 @@ class WavespeedImageUpscaleNode(IO.ComfyNode): return IO.Schema( node_id="WavespeedImageUpscaleNode", display_name="WaveSpeed Image Upscale", - category="image/partner/WaveSpeed", + category="partner/image/WaveSpeed", description="Boost image resolution and quality, upscaling photos to 4K or 8K for sharp, detailed results.", inputs=[ IO.Combo.Input("model", options=["SeedVR2", "Ultimate"]), diff --git a/comfy_api_nodes/util/conversions.py b/comfy_api_nodes/util/conversions.py index 5738df57f..a1b5d599c 100644 --- a/comfy_api_nodes/util/conversions.py +++ b/comfy_api_nodes/util/conversions.py @@ -469,6 +469,11 @@ def _apply_video_scale(video: Input.Video, scale_dims: tuple[int, int]) -> Input input_container = None output_container = None + # get_stream_source() is untrimmed, so apply the trim window in this same pass. + # start_time is normalized (>= 0); duration == 0 means "until the end". 
+ start_time, duration = video.get_active_trim_window() + trimming = bool(start_time or duration) + try: input_source = video.get_stream_source() input_container = av.open(input_source, mode="r") @@ -487,16 +492,45 @@ def _apply_video_scale(video: Input.Video, scale_dims: tuple[int, int]) -> Input audio_stream.layout = stream.layout break + in_video = input_container.streams.video[0] + start_pts = int(start_time / in_video.time_base) if trimming else 0 + end_pts = int((start_time + duration) / in_video.time_base) if duration else None + if start_pts: + input_container.seek(start_pts, stream=in_video) + + encoded = 0 for frame in input_container.decode(video=0): + if trimming: + if frame.pts is None or frame.pts < start_pts: + continue + if end_pts is not None and frame.pts >= end_pts: + break frame = frame.reformat(width=out_w, height=out_h, format="yuv420p") + # Re-wrap as a fresh frame: dropping irregular source timestamps (VFR/AVI/GIF/...) + # lets the encoder assign clean ones and avoids mp4 muxer errors. 
+ frame = av.VideoFrame.from_ndarray(frame.to_ndarray(format="yuv420p"), format="yuv420p") for packet in video_stream.encode(frame): output_container.mux(packet) + encoded += 1 for packet in video_stream.encode(): output_container.mux(packet) + if encoded == 0: + raise ValueError( + f"resize produced no frames (start_time={start_time}, duration={duration} " + "selected nothing from the source)" + ) + if audio_stream is not None: input_container.seek(0) for audio_frame in input_container.decode(audio=0): + if trimming: + if audio_frame.time is None or audio_frame.time < start_time: + continue + if duration and audio_frame.time > start_time + duration: + break + # Carry odd audio time bases the mp4 muxer rejects; reset pts, encoder assigns clean ones (MP3-in-AVI) + audio_frame.pts = None for packet in audio_stream.encode(audio_frame): output_container.mux(packet) for packet in audio_stream.encode(): diff --git a/comfy_extras/nodes_chroma_radiance.py b/comfy_extras/nodes_chroma_radiance.py index ca427e5cb..a4f673001 100644 --- a/comfy_extras/nodes_chroma_radiance.py +++ b/comfy_extras/nodes_chroma_radiance.py @@ -65,6 +65,12 @@ class ChromaRadianceOptions(io.ComfyNode): tooltip="Allows overriding the default NeRF tile size. -1 means use the default (32). 0 means use non-tiling mode (may require a lot of VRAM).", advanced=True, ), + io.Boolean.Input( + id="force_sequential_txt_ids", + default=False, + tooltip="Force usage of sequential text token IDs instead of zeroes. 
Should be used for checkpoints from 2026-05-22 to 2026-06-01 that are trained in this way but do not contain the __sequential__ key in the state dict.", + advanced=True, + ), ], outputs=[io.Model.Output()], ) @@ -78,11 +84,15 @@ class ChromaRadianceOptions(io.ComfyNode): start_sigma: float, end_sigma: float, nerf_tile_size: int, + force_sequential_txt_ids: bool, ) -> io.NodeOutput: radiance_options = {} if nerf_tile_size >= 0: radiance_options["nerf_tile_size"] = nerf_tile_size + if force_sequential_txt_ids: + radiance_options["use_sequential_txt_ids"] = True + if not radiance_options: return io.NodeOutput(model) diff --git a/comfy_extras/nodes_gaussian_splat.py b/comfy_extras/nodes_gaussian_splat.py new file mode 100644 index 000000000..2ba3a3820 --- /dev/null +++ b/comfy_extras/nodes_gaussian_splat.py @@ -0,0 +1,1664 @@ +# Generic utility nodes for the SPLAT type (3D gaussian splats) + +import gzip +import logging +import math +import struct +from io import BytesIO + +import numpy as np +import torch +from typing_extensions import override +from scipy.ndimage import map_coordinates, minimum as _ndi_minimum, maximum as _ndi_maximum +from scipy.sparse import coo_matrix +from scipy.sparse.csgraph import connected_components + +import comfy.model_management +import comfy.utils +from comfy_api.latest import ComfyExtension, IO, Types +from comfy_extras.nodes_save_3d import pack_variable_mesh_batch +from server import PromptServer + +_C0 = 0.28209479177387814 # SH band-0 constant: DC coefficient -> base RGB + + +def _srgb_to_linear(c): + return torch.where(c <= 0.04045, c / 12.92, ((c.clamp_min(0) + 0.055) / 1.055) ** 2.4) + + +def _linear_to_srgb(c): + return torch.where(c <= 0.0031308, c * 12.92, 1.055 * c.clamp_min(0) ** (1 / 2.4) - 0.055) + + +def _real_len(g: Types.SPLAT, i: int) -> int: + # Real splat count of batch item i (honors variable-length `counts`). 
+ return int(g.counts[i].item()) if g.counts is not None else g.positions.shape[1] + + +def _hex_to_rgb(h: str) -> tuple[float, float, float]: + # "#RRGGBB" -> (r,g,b) in [0,1]; falls back to black. + h = h.lstrip("#") + if len(h) != 6: + return (0.0, 0.0, 0.0) + return tuple(int(h[i:i + 2], 16) / 255.0 for i in (0, 2, 4)) + + +def _quantile(x, q): + # torch.quantile errors above 2**24 elements; stride-subsample large inputs for the estimate. + lim = 1 << 24 + if x.numel() > lim: + x = x[:: x.numel() // lim + 1] + return torch.quantile(x, q) + + +def _gaussian_ply_bytes(positions, scales, rotations, opacities, sh) -> bytes: + """Serialize render-ready gaussian tensors as a binary 3DGS .ply. + + positions (N,3) world; scales (N,3) linear; rotations (N,4) quat wxyz; opacities (N,1) in [0,1]; + sh (N,K,3) SH coefficients. Activated values are inverted to the standard 3D gaussian splat storage convention + (log scale, logit opacity). + """ + xyz = positions.cpu().numpy().astype(np.float32) + n = xyz.shape[0] + if n == 0: + raise ValueError("SplatToFile3D: gaussian is empty") + normals = np.zeros_like(xyz) + f = sh.cpu().numpy().astype(np.float32) # (N, K, 3) + f_dc = f[:, 0, :] # (N, 3) + f_rest = f[:, 1:, :].transpose(0, 2, 1).reshape(n, -1) # (N, 3*(K-1)) channel-major + op = opacities.cpu().numpy().astype(np.float32).reshape(n, 1).clip(1e-6, 1 - 1e-6) + op = np.log(op / (1.0 - op)) # inverse sigmoid (logit) + scale = np.log(scales.cpu().numpy().astype(np.float32).clip(min=1e-8)) + rot = rotations.cpu().numpy().astype(np.float32) # (N, 4) + + attrs = (['x', 'y', 'z', 'nx', 'ny', 'nz'] + + [f'f_dc_{i}' for i in range(3)] + + [f'f_rest_{i}' for i in range(f_rest.shape[1])] + + ['opacity'] + [f'scale_{i}' for i in range(3)] + [f'rot_{i}' for i in range(4)]) + elements = np.empty(n, dtype=[(a, 'f4') for a in attrs]) + elements[:] = list(map(tuple, np.concatenate([xyz, normals, f_dc, f_rest, op, scale, rot], axis=1))) + + header = "ply\nformat binary_little_endian 1.0\n" 
+ f"element vertex {n}\n" + header += "".join(f"property float {a}\n" for a in attrs) + "end_header\n" + return header.encode('ascii') + elements.tobytes() + + +# .ksplat (mkkellogg SplatBuffer) level 0, SH degree 0: 4096-byte header, one 1024-byte section header, +# then N 44-byte records. Bucketing/quantization only exist at levels >= 1. See SplatBuffer.js. +_KSPLAT_HEADER_BYTES = 4096 +_KSPLAT_SECTION_HEADER_BYTES = 1024 +_KSPLAT_BYTES_PER_SPLAT = 44 # center 12 + scale 12 + rotation 16 + color(RGBA u8) 4 +_KSPLAT_VERSION = (0, 1) # SplatBuffer CurrentMajor/MinorVersion + + +def _gaussian_ksplat_bytes(positions, scales, rotations, opacities, sh) -> bytes: + """Serialize gaussian tensors as a level-0, SH degree-0 .ksplat (linear scale, opacity in color alpha). + + positions (N,3) world; scales (N,3) linear; rotations (N,4) wxyz; opacities (N,1) in [0,1]; sh (N,K,3). + """ + xyz = positions.cpu().numpy().astype(np.float32) + n = xyz.shape[0] + if n == 0: + raise ValueError("SplatToFile3D: gaussian is empty") + scale = scales.cpu().numpy().astype(np.float32) + rot = rotations.cpu().numpy().astype(np.float32) # wxyz, mirrors the .ply rot order + rot = rot / np.linalg.norm(rot, axis=1, keepdims=True).clip(1e-12) + rgb = np.clip(sh[:, 0, :].cpu().numpy().astype(np.float32) * _C0 + 0.5, 0, 1) + op = opacities.cpu().numpy().astype(np.float32).reshape(n, 1).clip(0, 1) + rgba = np.round(np.concatenate([rgb, op], axis=1) * 255.0).astype(np.uint8) # (N, 4) RGBA + + # 44-byte record: float center(3) + scale(3) + rot(4), then uint8 rgba(4). + floats = np.concatenate([xyz, scale, rot], axis=1).astype(' bytes: + """Serialize gaussian tensors as a gzip-compressed .spz (Niantic v2, SH degree 0, base color only). + + positions (N,3) world; scales (N,3) linear; rotations (N,4) wxyz; opacities (N,1) in [0,1]; sh (N,K,3). 
+ """ + xyz = positions.cpu().numpy().astype(np.float32) + n = xyz.shape[0] + if n == 0: + raise ValueError("SplatToFile3D: gaussian is empty") + + # Positions: fixed point, masked to 24 bits, little-endian 3-byte words. + fixed = 1 << _SPZ_FRACTIONAL_BITS + qi = np.clip(np.round(xyz * fixed), -(1 << 23), (1 << 23) - 1).astype(np.int32) + qu = (qi & 0xFFFFFF).astype(np.uint32) + pos = np.stack([qu & 0xFF, (qu >> 8) & 0xFF, (qu >> 16) & 0xFF], axis=-1).reshape(n, 9).astype(np.uint8) + + alpha = np.round(opacities.cpu().numpy().astype(np.float32).reshape(n) * 255.0).clip(0, 255).astype(np.uint8) + + rgb = sh[:, 0, :].cpu().numpy().astype(np.float32) * _C0 + 0.5 + col = np.round(((rgb - 0.5) / _SPZ_COLOR_SCALE + 0.5) * 255.0).clip(0, 255).astype(np.uint8) # (N,3) + + sln = np.log(scales.cpu().numpy().astype(np.float32).clip(min=1e-9)) + scb = np.round((sln + 10.0) * 16.0).clip(0, 255).astype(np.uint8) # (N,3) inverts exp(b/16-10) + + rot = rotations.cpu().numpy().astype(np.float32) # wxyz + rot = rot / np.linalg.norm(rot, axis=1, keepdims=True).clip(1e-12) + rot[rot[:, 0] < 0] *= -1.0 # canonical w >= 0 (w dropped on decode) + rotb = np.round((rot[:, 1:4] + 1.0) * 127.5).clip(0, 255).astype(np.uint8) # (N,3) x,y,z + + header = bytearray(16) + struct.pack_into(' (positions, scales linear, rotations wxyz, opacities [0,1], sh (N,K,3)) ---- +# Inverse of the writers above and of spark's loaders. ksplat/splat/spz carry base color only (SH degree 0 +# -> K=1); .ply round-trips full SH. None of the formats flip axes, so import is the identity of export. 
+_PLY_DTYPES = {'char': 'i1', 'uchar': 'u1', 'short': 'i2', 'ushort': 'u2', 'int': 'i4', 'uint': 'u4', + 'float': 'f4', 'double': 'f8', 'int8': 'i1', 'uint8': 'u1', 'int16': 'i2', 'uint16': 'u2', + 'int32': 'i4', 'uint32': 'u4', 'float32': 'f4', 'float64': 'f8'} +_KSPLAT_COMPRESSION = { # level -> (bytesPerCenter, scale, rotation, color, shComponent, defaultScaleRange) + 0: (12, 12, 16, 4, 4, 1), 1: (6, 6, 8, 4, 2, 32767), 2: (6, 6, 8, 4, 1, 32767)} +_KSPLAT_SH_COMPONENTS = {0: 0, 1: 9, 2: 24, 3: 45} + + +def _rgb_to_sh_dc(rgb): + return ((np.asarray(rgb, np.float32) - 0.5) / _C0)[:, None, :] # (N,3) base color -> (N,1,3) SH DC + + +def _norm_quat(q): + return q / np.linalg.norm(q, axis=1, keepdims=True).clip(1e-12) + + +def _parse_ply_gaussian(data: bytes): + end = data.find(b'end_header') + if end < 0: + raise ValueError("File3DToSplat: not a PLY (missing end_header)") + header = data[:end].decode('ascii', 'replace') + body = end + len(b'end_header') + body += 2 if data[body:body + 2] == b'\r\n' else 1 + count, props, in_vertex = 0, [], False + for line in header.splitlines(): + p = line.split() + if not p: + continue + if p[0] == 'format' and p[1] != 'binary_little_endian': + raise ValueError(f"File3DToSplat: unsupported PLY format '{p[1]}' (need binary_little_endian)") + if p[0] == 'element': + in_vertex = p[1] == 'vertex' + if in_vertex: + count = int(p[2]) + elif p[0] == 'property' and in_vertex: + if p[1] == 'list': + raise ValueError("File3DToSplat: PLY vertex has list properties (unsupported)") + props.append((p[2], '<' + _PLY_DTYPES[p[1]])) + arr = np.frombuffer(data, np.dtype(props), count=count, offset=body) + names = arr.dtype.names + c = lambda k: arr[k].astype(np.float32) + n = count + + xyz = np.stack([c('x'), c('y'), c('z')], 1) + if 'scale_0' in names: + scale = np.exp(np.stack([c('scale_0'), c('scale_1'), c('scale_2')], 1)) # 3DGS stores log scale + else: + scale = np.full((n, 3), 0.01, np.float32) + if 'rot_0' in names: + rot = 
_norm_quat(np.stack([c('rot_0'), c('rot_1'), c('rot_2'), c('rot_3')], 1)) # wxyz + else: + rot = np.tile(np.array([1, 0, 0, 0], np.float32), (n, 1)) + opacity = 1.0 / (1.0 + np.exp(-c('opacity'))) if 'opacity' in names else np.ones(n, np.float32) + + if 'f_dc_0' in names: + dc = np.stack([c('f_dc_0'), c('f_dc_1'), c('f_dc_2')], 1) # (N,3) + rest = sorted((k for k in names if k.startswith('f_rest_')), key=lambda s: int(s.split('_')[-1])) + if rest: + r = np.stack([c(k) for k in rest], 1) # (N, 3*(K-1)) channel-major + kk = r.shape[1] // 3 + 1 + r = r.reshape(n, 3, kk - 1).transpose(0, 2, 1) # -> (N, K-1, 3) + sh = np.concatenate([dc[:, None, :], r], 1) + else: + sh = dc[:, None, :] + elif 'red' in names: + sh = _rgb_to_sh_dc(np.stack([c('red'), c('green'), c('blue')], 1) / 255.0) + else: + sh = np.zeros((n, 1, 3), np.float32) + return xyz, scale, rot, opacity, sh + + +def _parse_splat_gaussian(data: bytes): + # antimatter15 .splat: 32-byte records (f32 xyz, f32 scale, u8 rgba, u8 quat as (b-128)/128 wxyz). 
+ if len(data) % 32 != 0: + raise ValueError("File3DToSplat: .splat size is not a multiple of 32 bytes") + rec = np.frombuffer(data, np.dtype([('xyz', ' 0: + ct, ft = (' full_splats: + lengths = np.frombuffer(data, '> 30) & 3 + q = np.zeros((n, 4), np.float32) # x,y,z,w + remaining, sumsq = combined.copy(), np.zeros(n, np.float64) + for comp in (3, 2, 1, 0): + active = comp != largest + value = (remaining & 0x1FF).astype(np.float64) + sign = (remaining >> 9) & 1 + remaining = np.where(active, remaining >> 10, remaining) + val = (1.0 / math.sqrt(2)) * (value / 0x1FF) + val = np.where(sign == 1, -val, val) + q[active, comp] = val[active] + sumsq += np.where(active, val * val, 0.0) + q[np.arange(n), largest] = np.sqrt(np.clip(1.0 - sumsq, 0, None)) + rot = _norm_quat(np.stack([q[:, 3], q[:, 0], q[:, 1], q[:, 2]], 1)) # xyzw -> wxyz + else: + qb = np.frombuffer(raw, np.uint8, count=n * 3, offset=off).reshape(n, 3).astype(np.float32) + xq = qb / 127.5 - 1.0 + w = np.sqrt(np.clip(1.0 - (xq ** 2).sum(1), 0, None)) + rot = _norm_quat(np.concatenate([w[:, None], xq], 1)) # wxyz + return xyz, scale, rot, alpha, _rgb_to_sh_dc(rgb) + + +_GAUSSIAN_PARSERS = {"ply": _parse_ply_gaussian, "splat": _parse_splat_gaussian, + "ksplat": _parse_ksplat_gaussian, "spz": _parse_spz_gaussian} + + +def _detect_splat_format(data: bytes) -> str: + if data[:3] == b'ply': + return "ply" + if data[:2] == b'\x1f\x8b': # gzip -> spz + return "spz" + if len(data) >= 2 and data[0] == 0 and data[1] >= 1: # ksplat version 0.x header + return "ksplat" + if len(data) % 32 == 0: + return "splat" + raise ValueError("File3DToSplat: could not determine splat format from contents") + + +def _gaussian_item(g: Types.SPLAT, i: int, device): + # Slice batch item i to its real length, as float32 torch tensors on `device` (SH DC -> base RGB). 
+ end = _real_len(g, i) + to = lambda a: a.to(device=device, dtype=torch.float32) + xyz = to(g.positions[i, :end]) + rgb = (to(g.sh[i, :end, 0, :]) * _C0 + 0.5).clamp(0, 1) + opacity = to(g.opacities[i, :end]).reshape(-1) + scale = to(g.scales[i, :end]) + rot = to(g.rotations[i, :end]) + return xyz, rgb, opacity, scale, rot + + +def _quat_to_mat(q): + # q: (N, 4) wxyz, normalized -> (N, 3, 3) + q = q / q.norm(dim=-1, keepdim=True).clamp_min(1e-12) + w, x, y, z = q.unbind(-1) + return torch.stack([ + 1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y), + 2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x), + 2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y), + ], dim=-1).reshape(-1, 3, 3) + + +def _quat_mul(a, b): + # Hamilton product a (x) b, wxyz. + aw, ax, ay, az = a.unbind(-1) + bw, bx, by, bz = b.unbind(-1) + return torch.stack([ + aw * bw - ax * bx - ay * by - az * bz, + aw * bx + ax * bw + ay * bz - az * by, + aw * by - ax * bz + ay * bw + az * bx, + aw * bz + ax * by - ay * bx + az * bw, + ], dim=-1) + + +def _euler_to_quat(rx, ry, rz): + # Degrees, applied as Rz @ Ry @ Rx (rotate about X, then Y, then Z in world). Returns wxyz. + c, s = np.cos(np.radians([rx, ry, rz]) / 2.0), np.sin(np.radians([rx, ry, rz]) / 2.0) + qx = torch.tensor([c[0], s[0], 0.0, 0.0], dtype=torch.float32) + qy = torch.tensor([c[1], 0.0, s[1], 0.0], dtype=torch.float32) + qz = torch.tensor([c[2], 0.0, 0.0, s[2]], dtype=torch.float32) + return _quat_mul(_quat_mul(qz, qy), qx) + + +def _mat_to_quat(m): + # Rotation matrix (..., 3, 3) -> quaternion (..., 4) wxyz. Batched; builds the four candidate quaternions + # and keeps the one with the largest component (numerically stable across all rotations). 
+ m00, m11, m22 = m[..., 0, 0], m[..., 1, 1], m[..., 2, 2] + m21, m12 = m[..., 2, 1], m[..., 1, 2] + m02, m20 = m[..., 0, 2], m[..., 2, 0] + m10, m01 = m[..., 1, 0], m[..., 0, 1] + q2 = torch.stack([1 + m00 + m11 + m22, 1 + m00 - m11 - m22, + 1 - m00 + m11 - m22, 1 - m00 - m11 + m22], -1) # 4 * (w^2, x^2, y^2, z^2) + cand = torch.stack([ + torch.stack([q2[..., 0], m21 - m12, m02 - m20, m10 - m01], -1), + torch.stack([m21 - m12, q2[..., 1], m10 + m01, m02 + m20], -1), + torch.stack([m02 - m20, m10 + m01, q2[..., 2], m12 + m21], -1), + torch.stack([m10 - m01, m02 + m20, m12 + m21, q2[..., 3]], -1), + ], -2) # (...,4,4) candidates, rows = wxyz + sel = q2.argmax(-1) + q = torch.gather(cand, -2, sel[..., None, None].expand(sel.shape + (1, 4)))[..., 0, :] + return q / q.norm(dim=-1, keepdim=True).clamp_min(1e-12) + + +class SplatToFile3D(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="SplatToFile3D", + display_name="Create 3D File (from Splat)", + search_aliases=["gaussian to ply", "splat to file", "export gaussian"], + category="3d/splat", + description="Serialize a gaussian splat to a File3D object for Save / Preview 3D nodes. " + "Supports one item per batch only.", + inputs=[ + IO.Splat.Input("splat"), + IO.Combo.Input("format", options=["ply", "ksplat", "spz"], # TODO: add "splat" when we have a writer for it + tooltip="ply: standard 3D Gaussian Splat with full spherical harmonics. " + "ksplat: mkkellogg SplatBuffer (level 0, uncompressed), base color only " + "spz: Niantic gzip-compressed (~10x smaller), base color only " + ), + ], + outputs=[IO.File3DAny.Output(display_name="model_3d")], + ) + + @classmethod + def execute(cls, splat, format="ply") -> IO.NodeOutput: + if splat.positions.shape[0] > 1: + logging.warning("SplatToFile3D supports one item per batch only. 
Got %d; using first.", splat.positions.shape[0]) + end = _real_len(splat, 0) + writer = {"ksplat": _gaussian_ksplat_bytes, "spz": _gaussian_spz_bytes}.get(format, _gaussian_ply_bytes) + data = writer(splat.positions[0, :end], splat.scales[0, :end], + splat.rotations[0, :end], splat.opacities[0, :end], splat.sh[0, :end]) + return IO.NodeOutput(Types.File3D(BytesIO(data), file_format=format)) + + +class File3DToSplat(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="File3DToSplat", + display_name="Get Splat", + search_aliases=["load splat", "ply to splat", "import splat", "file to splat"], + category="3d/splat", + description="Parse a splat File3D into a gaussian splat. Inverse of Create 3D File (from Splat). " + "Supported format: PLY, SPLAT, KSPLAT, SPZ. PLY carries full spherical harmonics, " + "the other formats are base color only. Format is auto-detected from the file contents.", + inputs=[ + IO.MultiType.Input( + IO.File3DAny.Input("model_3d"), + types=[IO.File3DPLY, IO.File3DSPLAT, IO.File3DKSPLAT, IO.File3DSPZ], + tooltip="A gaussian splat 3D file", + ), + ], + outputs=[IO.Splat.Output(display_name="splat")], + ) + + @classmethod + def execute(cls, model_3d: Types.File3D) -> IO.NodeOutput: + data = model_3d.get_bytes() + fmt = (model_3d.format or "").lower() + parser = _GAUSSIAN_PARSERS.get(fmt) or _GAUSSIAN_PARSERS[_detect_splat_format(data)] + xyz, scale, rot, opacity, sh = parser(data) + + t = lambda a: torch.from_numpy(np.ascontiguousarray(a)).float() + splat = Types.SPLAT( + t(xyz)[None], # (1, N, 3) + t(scale)[None], # (1, N, 3) linear + t(rot)[None], # (1, N, 4) wxyz + t(opacity).reshape(1, -1, 1), # (1, N, 1) + t(sh)[None], # (1, N, K, 3) + ) + return IO.NodeOutput(splat) + + +def _view_matrix_t(yaw_deg, pitch_deg, device): + y, p = math.radians(yaw_deg), math.radians(pitch_deg) + cy, sy, cp, sp = math.cos(y), math.sin(y), math.cos(p), math.sin(p) + Ry = torch.tensor([[cy, 0, sy], [0, 1, 0], [-sy, 0, cy]], 
device=device) + Rx = torch.tensor([[1, 0, 0], [0, cp, -sp], [0, sp, cp]], device=device) + return Rx @ Ry + + +def _camera_basis(camera_info, dev): + # Look-at basis in the splat frame, named by their projection rows: right = image +x, up = image +y + # (down, since yflip=1), fwd = view/depth axis (eye -> scene). Load3D is three.js (right-handed, Y-up, + # camera looks down -Z); the splat is 3DGS (Y-down, Z-forward). World -> splat is a 180 deg rotation + # about X: (x, y, z) -> (x, -y, -z) (det +1, no mirror, no axis swap). + pos, tgt = camera_info.get("position", {}), camera_info.get("target", {}) + m = lambda d: torch.tensor([float(d.get("x", 0.0)), -float(d.get("y", 0.0)), -float(d.get("z", 0.0))], device=dev) + eye, target = m(pos), m(tgt) + mv = lambda v: torch.stack([v[0], -v[1], -v[2]]) # same world->splat map, for direction vectors + n = lambda v: v / v.norm().clamp_min(1e-8) + q = camera_info.get("quaternion") + if q: # exact camera world rotation (incl. roll) + qwxyz = torch.tensor([float(q.get("w", 1.0)), float(q.get("x", 0.0)), + float(q.get("y", 0.0)), float(q.get("z", 0.0))], device=dev) + R = _quat_to_mat(qwxyz[None])[0] # columns = camera world axes; looks down local -Z + right = n(mv(R[:, 0])) # camera +X -> image right + up = n(mv(-R[:, 1])) # camera +Y is image up; image-down row is its negative + fwd = n(mv(-R[:, 2])) # camera looks down local -Z -> view direction + return eye, target, right, up, fwd + fwd = n(target - eye) # no quaternion: orbit-consistent, roll-free + yaw = math.degrees(math.atan2(-float(fwd[0]), float(fwd[2]))) + pitch = math.degrees(math.asin(max(-1.0, min(1.0, float(fwd[1]))))) + W = _view_matrix_t(yaw, pitch, dev) + return eye, target, W[0], W[1], W[2] + + +def _lookat_quat_wxyz(position, target, dev): + # three.js lookAt in world frame: camera local +Z = (eye - target), up = world +Y. Returns wxyz. 
def _lookat_camera_info(position, target, fov, dev, zoom=1.0, camera_type="perspective", roll=0.0):
    # Assemble a camera_info dict from a world-space (right-handed, Y-up) eye position and look-at
    # target, with world +Y as the up reference. A non-zero `roll` (degrees) is applied afterwards as a
    # rotation about the camera's local Z (the view axis).
    eye = torch.as_tensor(position, dtype=torch.float32, device=dev)
    aim = torch.as_tensor(target, dtype=torch.float32, device=dev)
    quat = _lookat_quat_wxyz(eye, aim, dev)  # wxyz
    if roll:
        half = math.radians(roll) / 2
        spin = torch.tensor([math.cos(half), 0.0, 0.0, math.sin(half)], device=dev)
        quat = _quat_mul(quat[None], spin[None])[0]  # post-multiply: rotate in the camera's own frame

    def as_xyz(vec):
        return {"x": float(vec[0]), "y": float(vec[1]), "z": float(vec[2])}

    # camera_info stores the quaternion as x, y, z, w, whereas `quat` is w, x, y, z.
    quat_dict = dict(zip(("x", "y", "z", "w"), (float(quat[i]) for i in (1, 2, 3, 0))))
    return {
        "position": as_xyz(eye),
        "target": as_xyz(aim),
        "quaternion": quat_dict,
        "fov": float(fov),
        "cameraType": str(camera_type),
        "zoom": float(zoom),
    }
+ pos = torch.as_tensor(position, dtype=torch.float32, device=dev) + qx, qy, qz, qw = (float(c) for c in quat_xyzw) + qwxyz = torch.tensor([qw, qx, qy, qz], dtype=torch.float32, device=dev) + qwxyz = qwxyz / qwxyz.norm().clamp_min(1e-8) + R = _quat_to_mat(qwxyz[None])[0] + tgt = pos - R[:, 2] # look one unit down local -Z + xyz = lambda v: {"x": float(v[0]), "y": float(v[1]), "z": float(v[2])} + return {"position": xyz(pos), "target": xyz(tgt), + "quaternion": {"x": float(qwxyz[1]), "y": float(qwxyz[2]), "z": float(qwxyz[3]), "w": float(qwxyz[0])}, + "fov": float(fov), "cameraType": str(camera_type), "zoom": float(zoom)} + + +def _orbit_camera_info(yaw, pitch, distance, fov, pivot_splat, dev): + # Orbit helper for RenderSplat's default camera: yaw/pitch about `pivot_splat` (splat frame) at `distance`. + # World<->splat is the (x,-y,-z) map, so _camera_basis recovers exactly _view_matrix_t(yaw, pitch). + y, p = math.radians(yaw), math.radians(pitch) + cy, sy, cp, sp = math.cos(y), math.sin(y), math.cos(p), math.sin(p) + fwd_splat = torch.tensor([-cp * sy, sp, cp * cy], device=dev) # == _view_matrix_t(yaw, pitch)[2] + m = lambda v: torch.stack([v[0], -v[1], -v[2]]) # splat<->world (its own inverse) + return _lookat_camera_info(m(pivot_splat - distance * fwd_splat), m(pivot_splat), fov, dev) + + +def _orbit_camera_info_yaw(camera_info, angle_deg, dev): + # Turntable: rigidly rotate a camera_info about world +Y around its target by angle_deg. Returns a new dict. 
+ a = math.radians(angle_deg) + ca, sa = math.cos(a), math.sin(a) + v = lambda d: torch.tensor([float(d.get("x", 0.0)), float(d.get("y", 0.0)), float(d.get("z", 0.0))], device=dev) + pos, tgt = v(camera_info.get("position", {})), v(camera_info.get("target", {})) + Ry = torch.tensor([[ca, 0.0, sa], [0.0, 1.0, 0.0], [-sa, 0.0, ca]], device=dev) + new_pos = tgt + Ry @ (pos - tgt) + q = camera_info.get("quaternion") or {} + qcur = torch.tensor([float(q.get("w", 1.0)), float(q.get("x", 0.0)), + float(q.get("y", 0.0)), float(q.get("z", 0.0))], device=dev) + qy = torch.tensor([math.cos(a / 2), 0.0, math.sin(a / 2), 0.0], device=dev) # world +Y rotation + qn = _quat_mul(qy[None], qcur[None])[0] + xyz = lambda t: {"x": float(t[0]), "y": float(t[1]), "z": float(t[2])} + return {**camera_info, "position": xyz(new_pos), + "quaternion": {"x": float(qn[1]), "y": float(qn[2]), "z": float(qn[3]), "w": float(qn[0])}} + + +def _gauss_blur(x, sigma, dev): + # Separable Gaussian blur of (1, C, H, W). Used to denoise the screen-space normal map. + r = max(1, int(round(3 * sigma))) + k = torch.exp(-0.5 * (torch.arange(-r, r + 1, device=dev, dtype=torch.float32) / sigma) ** 2) + k = k / k.sum() + c = x.shape[1] + x = torch.nn.functional.conv2d(x, k.view(1, 1, 1, -1).expand(c, 1, 1, -1), padding=(0, r), groups=c) + x = torch.nn.functional.conv2d(x, k.view(1, 1, -1, 1).expand(c, 1, -1, 1), padding=(r, 0), groups=c) + return x + + +def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, splat_scale, bg, camera_info, + sharpen=1.0, headlight_shading=0.0, render_style="color"): + # Perspective-correct anisotropic gaussian splat rasterizer. Each splat is weighted by its 3D Gaussian's + # peak along each pixel's ray (AAA / Hahlbohm), composited front-to-back across depth slabs. `render_style` + # selects the image: color / clay / depth / normal. Returns (image HxWx3, coverage mask HxW) on CPU. 
+ dev = comfy.model_management.get_torch_device() + t = lambda a: torch.as_tensor(a, dtype=torch.float32, device=dev) + idev, idtype = comfy.model_management.intermediate_device(), comfy.model_management.intermediate_dtype() + xyz, rgb, opacity = t(xyz), t(rgb).clamp(0, 1), t(opacity).reshape(-1) + scale, rot = t(scale) * float(splat_scale), t(rot) + do_linear = render_style == "color" # colour blends in linear light, re-encoded at the end + if do_linear: + rgb = _srgb_to_linear(rgb) + flat = width * height + bg_t = t(bg) + bg_comp = _srgb_to_linear(bg_t) if do_linear else bg_t # background blended in the same space as the splats + need_depth = render_style == "depth" + need_normal = render_style in ("normal", "clay") or headlight_shading > 0 + + def background_only(): # no splats to rasterize -> just the background + empty mask + img = bg_t.expand(height, width, 3) if render_style == "color" else torch.zeros(height, width, 3, device=dev) + return img.to(idev, idtype), torch.zeros(height, width, device=idev, dtype=idtype) + + if xyz.shape[0] == 0: # empty input (e.g. 
all culled by opacity_threshold) + return background_only() + + eye, target, right, up, fwd = _camera_basis(camera_info, dev) # all camera state comes from camera_info + W = torch.stack([right, up, fwd], 0) # rows = camera axes (world -> camera) + cam = (xyz - eye) @ W.T + fov = float(camera_info.get("fov", 0) or 0) or 35.0 + zoom = float(camera_info.get("zoom", 1.0) or 1.0) # three.js digital zoom: scales the focal length + is_ortho = str(camera_info.get("cameraType", "")).lower().startswith("ortho") + xc, yc, zc = cam.unbind(-1) + + keep = zc > 1e-2 + xc, yc, zc, rgb, opacity, scale, rot = (a[keep] for a in (xc, yc, zc, rgb, opacity, scale, rot)) + if xc.shape[0] == 0: # nothing in front of the camera -> background only + return background_only() + if render_style == "clay": + rgb = torch.full_like(rgb, 0.75) # neutral albedo -> shading shows pure geometry + + f = (min(width, height) / 2) / math.tan(math.radians(fov) / 2) * zoom # fov over the smaller axis, x camera zoom + cx0, cy0 = width / 2, height / 2 + + # Camera-space 3D covariance per splat: Sigma = (W Rq) diag(scale^2) (W Rq)^T, plus a tiny relative + # regularizer for a stable inverse (a pixel-size Mip low-pass would over-thicken flat surfels and blur). + Mw = W[None] @ _quat_to_mat(rot) # (N,3,3) world -> camera + cam_cov = (Mw * scale.square()[:, None, :]) @ Mw.transpose(1, 2) + cam_cov = cam_cov + (cam_cov.diagonal(dim1=-2, dim2=-1).mean(-1) * 1e-3)[:, None, None] * torch.eye(3, device=dev) + + # Perspective-correct weighting: peak of the 3D Gaussian along each pixel ray. Precompute Si, Si@mu, mu^T Si mu. 
+ mu = torch.stack([xc, yc, zc], -1) + si = torch.linalg.inv(cam_cov) + simu = (si @ mu[:, :, None])[:, :, 0] # (N,3) + musimu = (mu * simu).sum(-1) # (N,) + s00, s01, s02 = si[:, 0, 0], si[:, 0, 1], si[:, 0, 2] + s11, s12, s22 = si[:, 1, 1], si[:, 1, 2], si[:, 2, 2] + simu0, simu1, simu2 = simu.unbind(-1) + if need_normal: # surfel normal = thinnest axis, oriented toward camera + nrm = Mw[torch.arange(Mw.shape[0], device=dev), :, scale.argmin(-1)] # (N,3) camera-space normal + nrm = nrm * torch.where(nrm[:, 2:3] > 0, -1.0, 1.0) # flip so nz <= 0 (faces camera) + + # Screen centre (exact) + footprint radius from the affine 2D projection (used only to size the kernel). + # The image is +y-down, so the projection's y row is unflipped - it matches the splat frame's +Y. + jm = torch.zeros(xc.shape[0], 2, 3, device=dev) + if is_ortho: # parallel projection: screen = s * (xc, yc) + s = f / float((target - eye).norm().clamp_min(1e-6)) # pixels per world unit at the target plane + cx, cy = cx0 + s * xc, cy0 + s * yc + jm[:, 0, 0] = s + jm[:, 1, 1] = s + else: # perspective: screen = f * (xc, yc) / zc + invz = 1.0 / zc + cx, cy = cx0 + f * xc * invz, cy0 + f * yc * invz + jm[:, 0, 0], jm[:, 0, 2] = f * invz, -f * xc * invz.square() + jm[:, 1, 1], jm[:, 1, 2] = f * invz, -f * yc * invz.square() + cov2 = jm @ cam_cov @ jm.transpose(1, 2) + a, b, c = cov2[:, 0, 0], cov2[:, 0, 1], cov2[:, 1, 1] + max_eig = (a + c) * 0.5 + (((a - c) * 0.5).square() + b * b).clamp_min(0).sqrt() + radius = 3.0 * max_eig.clamp_min(1e-8).sqrt() + K = int(min(max(24, min(width, height) // 16), max(2, math.ceil(_quantile(radius, 0.995).item())))) + + # Per-splat kernel size: bucket splats by radius into a coarse ladder of window sizes (global K stays the cap) so + # small splats (the bulk of it) use a small window. 
+ levels = [L for L in (16, 64, 256) if L < K] + [K] + levels_t = torch.tensor(levels, device=dev, dtype=torch.float32) + grids = [] + for L in levels: + rng = torch.arange(-L, L + 1, device=dev, dtype=torch.float32) + gy, gx = torch.meshgrid(rng, rng, indexing="ij") + grids.append((gx.reshape(-1), gy.reshape(-1))) + blevel = torch.bucketize(radius * (4.0 / 3.0), levels_t).clamp_(max=len(levels) - 1) # window >= ~4 sigma + + n = zc.shape[0] + ns = int(min(256, max(1, n // 1000))) # depth slabs: 1 per ~1000 splats, capped + nl = len(levels) + order = torch.argsort(zc) # front (small zc) -> back -> defines the slabs + bounds = torch.linspace(0, n, ns + 1, device=dev).round().long() + rank = torch.empty(n, dtype=torch.long, device=dev) + rank[order] = torch.arange(n, device=dev) # depth rank of each splat + slab_id = (torch.searchsorted(bounds, rank, right=True) - 1).clamp_(0, ns - 1) + key = slab_id * nl + blevel # group by slab, then kernel level (order-free within) + order = torch.argsort(key) + key = key[order] + + cxr, cyr = cx[order].round(), cy[order].round() + s00, s01, s02 = s00[order], s01[order], s02[order] + s11, s12, s22 = s11[order], s12[order], s22[order] + s01b, s02b, s12b = s01 * 2, s02 * 2, s12 * 2 # doubled cross terms for the fused quadratic forms + simu0, simu1, simu2, musimu = simu0[order], simu1[order], simu2[order], musimu[order] + opacity, rgb = opacity[order], rgb[order] + zc_o = zc[order] if need_depth else None + nrm_o = nrm[order] if need_normal else None + mux_o, muy_o, muz_o = (xc[order], yc[order], zc[order]) if is_ortho else (None, None, None) + + # Pack the per-splat scalars into one tensor so each chunk slices once + common = [cxr, cyr, s00, s11, s22, s01b, s02b, s12b, opacity] + pstack = torch.stack(common + ([s02, s12, mux_o, muy_o, muz_o] if is_ortho else [simu0, simu1, simu2, musimu])) + + # Precompute the (slab, level) run table on-GPU and pull it to the CPU once + starts = torch.cat([torch.zeros(1, dtype=torch.long, 
device=dev), (key[1:] != key[:-1]).nonzero().flatten() + 1]) + ks = key[starts] + run_lo = starts.tolist() + [n] + run_lev = (ks % nl).tolist() + run_slab = torch.div(ks, nl, rounding_mode="floor").tolist() + slab_runs = [[] for _ in range(ns)] + for r in range(len(run_lev)): + slab_runs[run_slab[r]].append((run_lo[r], run_lo[r + 1], run_lev[r])) + + def splat(lo, hi, ox, oy): # -> pixel idx (m,M), alpha (m,M); weight = 3D Gaussian peak along each pixel's ray + cols = pstack[:, lo:hi, None].unbind(0) + cxr_, cyr_, a00, a11, a22, b01, b02, b12, opa = cols[:9] # a* = Si components; b* = 2 * cross terms + px = cxr_ + ox[None, :] + py = cyr_ + oy[None, :] + valid = (px >= 0) & (px < width) & (py >= 0) & (py < height) + if is_ortho: # parallel ray (0,0,1) from screen point (X, Y, 0); rz constant per splat + c02, c12, mx, my, mz = cols[9:] + rx = (px - cx0) / s - mx + ry = (py - cy0) / s - my + rz = -mz + a22rz = a22 * rz + inx = torch.addcmul(b02 * rz, a00, rx).addcmul_(b01, ry) # a00 rx + b01 ry + b02 rz + rSr = torch.addcmul(a22rz * rz, rx, inx).addcmul_(ry, torch.addcmul(b12 * rz, a11, ry)) + dsr = torch.addcmul(a22rz, c02, rx).addcmul_(c12, ry) + q = torch.addcdiv(rSr, dsr * dsr, a22.clamp_min(1e-12), value=-1).clamp_min_(0) + else: # perspective ray (dx,dy,1) through the camera origin + su0, su1, su2, mus = cols[9:] + dx, dy = (px - cx0) / f, (py - cy0) / f + dsid = torch.addcmul(a22, dx, torch.addcmul(b02, a00, dx)) # a22 + dx*(a00 dx + b02) + dsid = dsid.addcmul_(dy, torch.addcmul(b12, a11, dy)) # + dy*(a11 dy + b12) + dsid = dsid.addcmul_(b01 * dx, dy) # + (2 s01) dx dy + dsimu = torch.addcmul(su2, dx, su0).addcmul_(dy, su1) + q = torch.addcdiv(mus, dsimu * dsimu, dsid.clamp_min(1e-12), value=-1).clamp_min_(0) + alpha = (opa * torch.exp(-0.5 * q) * valid).clamp_(0, 0.999) + idx = py.long().clamp(0, height - 1) * width + px.long().clamp(0, width - 1) + return idx, alpha + + # Front-to-back compositing over the depth slabs set up above. 
Within a slab the accumulation is a pure + # sum (order-independent), so splats are grouped by kernel level and each level uses its own tight window. + sharp = sharpen != 1.0 # winner-take-more colour blend: dominant splat shows more + cacc = torch.zeros((flat, 3), device=dev) + trans = torch.ones((flat,), device=dev) + a_buf = torch.zeros((flat,), device=dev) # sum alpha -> colour/depth/normal weight (alpha-weighted mean) + tau_buf = torch.zeros((flat,), device=dev) # sum -ln(1-alpha) -> slab opacity = 1-prod(1-alpha) + crgb = torch.zeros((flat, 3), device=dev) # sum alpha^p * rgb -> slab colour + wbuf = torch.zeros((flat,), device=dev) if sharp else None # sum alpha^p -> colour normalizer (sharp only) + dacc = torch.zeros((flat,), device=dev) if need_depth else None # front-weighted depth + nacc = torch.zeros((flat, 3), device=dev) if need_normal else None # front-weighted camera-space normal + zslab = torch.zeros((flat,), device=dev) if need_depth else None + nslab = torch.zeros((flat, 3), device=dev) if need_normal else None + stale = 0 # consecutive fully-occluded slabs -> early-out + for si in range(ns): + runs = slab_runs[si] + if not runs: + continue + a_buf.zero_() + tau_buf.zero_() + crgb.zero_() + if sharp: + wbuf.zero_() + if need_depth: + zslab.zero_() + if need_normal: + nslab.zero_() + for r_lo, r_hi, li in runs: # contiguous same-kernel-level runs in this slab + ox, oy = grids[li] + ch = max(2048, 10_000_000 // ox.shape[0]) # splats/chunk, bounded by this level's kernel size + for lo in range(r_lo, r_hi, ch): + hi = min(lo + ch, r_hi) + idx, alpha = splat(lo, hi, ox, oy) + idx, af = idx.reshape(-1), alpha.reshape(-1) + a_buf.index_add_(0, idx, af) + tau_buf.index_add_(0, idx, (-torch.log1p(-alpha)).reshape(-1)) # -ln(1-alpha), correct opacity merge + apw = alpha.pow(sharpen) if sharp else alpha # bias colour toward the highest-alpha splat + crgb.index_add_(0, idx, (apw[:, :, None] * rgb[lo:hi, None, :]).reshape(-1, 3)) + if sharp: + 
wbuf.index_add_(0, idx, apw.reshape(-1)) + if need_depth: + zslab.index_add_(0, idx, (alpha * zc_o[lo:hi, None]).reshape(-1)) + if need_normal: + nslab.index_add_(0, idx, (alpha[:, :, None] * nrm_o[lo:hi, None, :]).reshape(-1, 3)) + slab_a = 1 - torch.exp(-tau_buf) # 1 - prod(1-alpha): true opacity of the slab's splats + front = trans * slab_a + denom = wbuf if sharp else a_buf + cacc.addcmul_(front[:, None], crgb / denom.clamp_min(1e-8)[:, None]) # cacc += front * (crgb/denom) + if need_depth or need_normal: + ainv = a_buf.clamp_min(1e-8) # alpha-weighted-mean normalizer (depth/normal only) + if need_depth: + dacc.addcmul_(front, zslab / ainv) + if need_normal: + nacc.addcmul_(front[:, None], nslab / ainv[:, None]) + trans.mul_(1 - slab_a) + if si % 8 == 7: # checkpoint every 8 slabs (a per-slab GPU sync would cost more) + if float(front.max()) < 1e-3: # this checkpoint slab is fully occluded by what is in front + stale += 1 + if stale >= 2: # two occluded checkpoints running -> the rest are too -> stop + break + else: + stale = 0 + + cov = 1 - trans + covg = cov.reshape(height, width) + covm = covg > 0.5 if render_style in ("depth", "normal") else None # silhouette mask (depth/normal styles only) + depth_map = (dacc / cov.clamp_min(1e-6)).reshape(height, width) if need_depth else None + nrm_map = None + if need_normal: + # Per-splat surfel normals are jittery, so do a masked blur + nb = nacc.reshape(height, width, 3).permute(2, 0, 1)[None] + cb = cov.reshape(1, 1, height, width) + nb, cb = _gauss_blur(nb, 1.2, dev), _gauss_blur(cb, 1.2, dev) + normal = (nb / cb.clamp_min(1e-6))[0].permute(1, 2, 0) + nrm_map = normal / normal.norm(dim=-1, keepdim=True).clamp_min(1e-6) + + if render_style == "depth": # near = bright, far = dark, 0 off-object + d = torch.zeros(height, width, device=dev) + if bool(covm.any()): + lo, hi = depth_map[covm].min(), depth_map[covm].max() + d = torch.where(covm, ((hi - depth_map) / (hi - lo).clamp_min(1e-6)).clamp(0, 1), d) + img = d[:, :, 
None].expand(height, width, 3) + elif render_style == "normal": # OpenGL normal map: +X right, +Y up, +Z to viewer + enc = (nrm_map * t([1.0, -1.0, -1.0]) * 0.5 + 0.5).clamp(0, 1) + img = enc * covm[:, :, None] + else: # color / clay + img = cacc.reshape(height, width, 3) + if render_style == "clay": # studio key light + ambient -> sculpted matte look + kl = t([-0.4, -0.7, -0.6]) # key from screen upper-left, angled toward the viewer + kl = kl / kl.norm() + hl = (0.5 * (nrm_map * kl).sum(-1) + 0.5).clamp(0, 1) # half-Lambert: soft terminator, no harsh dark side + img = img * (0.35 + 0.65 * hl * hl)[:, :, None] # ambient floor + diffuse key + elif headlight_shading > 0: # camera headlight: darken faces turned from view + k = float(headlight_shading) + ndotl = (-nrm_map[:, :, 2]).clamp(0, 1) + img = img * (1 - 0.6 * k + 0.6 * k * ndotl)[:, :, None] + img = img.addcmul_(trans.reshape(height, width, 1), bg_comp) + if do_linear: # back to display space after linear compositing + img = _linear_to_srgb(img) + return img.clamp(0, 1).to(idev, idtype), covg.clamp(0, 1).to(idev, idtype) + + +class RenderSplat(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RenderSplat", + display_name="Render Splat", + search_aliases=["splat to image", "render splat", "gaussian turntable"], + category="3d/splat", + description="Render a gaussian splat as an image with an anisotropic EWA rasterizer (oriented " + "elliptical splats, antialiased, depth-sorted front-to-back). The camera comes from a " + "camera_info input (Load / Preview 3D, or a Create Camera Info node); leave it empty to " + "auto-frame the splat. 
Set frames greater than 1 for a turntable batch of images to feed a Video node.", + inputs=[ + IO.Splat.Input("splat"), + IO.Int.Input("width", default=1024, min=64, max=2048, step=8), + IO.Int.Input("height", default=1024, min=64, max=2048, step=8), + IO.Int.Input("frames", default=1, min=-240, max=240, + tooltip="-1, 0, 1 = single still image; >1 = turntable, the camera orbits over a full " + "360 turn (works with any camera_info). Negative value orbits the other way."), + IO.Float.Input("splat_scale", default=1.0, min=0.1, max=5.0, step=0.05, advanced=True, + tooltip="Multiplier on each splat's projected footprint (lower = crisper points, " + "higher = softer/fuller surface)."), + IO.Float.Input("sharpen", default=2.0, min=1.0, max=8.0, step=0.5, + tooltip="Sharpen overlapping splats: 1.0 = physically-correct blend; higher biases " + "each pixel toward its dominant (nearest) splat for crisper texture, without " + "shrinking splats or opening gaps. Non-physical above 1."), + IO.Float.Input("headlight_shading", default=0.0, min=0.0, max=3.0, step=0.05, advanced=True, + tooltip="Diffuse shading from a light at the camera (headlight), using the splat surfel " + "normals: darkens surfaces that turn away from view to reveal form/curvature. " + "0 = flat albedo, 1 = strongest shading."), + IO.Float.Input("opacity_threshold", default=0.0, min=0.0, max=1.0, step=0.01, advanced=True, + tooltip="Cull gaussians with opacity below this (removes faint floaters)."), + IO.Combo.Input("render_style", options=["color", "clay", "depth", "normal"], + tooltip="What the image output shows: color, clay (neutral-albedo shaded), " + "depth (near=bright), normal (OpenGL normal map)."), + IO.Color.Input("background", default="#000000"), + IO.Image.Input("bg_image", optional=True, + tooltip="Optional background plate composited behind the splat (overrides the solid " + "background colour). Resized to the render size; a batch is used per frame, " + "a single image for all. 
color/clay only."), + IO.Load3DCamera.Input("camera_info", optional=True, + tooltip="Camera to render from - a Load3D / Preview3D camera or a Create Camera " + "Info node. If empty, the splat is auto-framed from a default 3/4 view."), + ], + outputs=[IO.Image.Output(display_name="image"), IO.Mask.Output(display_name="mask")], + ) + + @classmethod + def execute(cls, splat, width, height, frames, splat_scale, sharpen, headlight_shading, + opacity_threshold, background, render_style, camera_info=None, bg_image=None) -> IO.NodeOutput: + bg = _hex_to_rgb(background) + bg_imgs = None + if bg_image is not None: # resize the plate(s) to the render size: (B,H,W,3) + bi = bg_image[... , :3].movedim(-1, 1) # (B,3,H,W) + bi = comfy.utils.common_upscale(bi, width, height, "bicubic", "disabled") + bg_imgs = bi.movedim(1, -1).clamp(0, 1) + n_frames = abs(int(frames)) or 1 # magnitude = frame count (0 -> single still) + orbit_dir = -1.0 if frames < 0 else 1.0 # sign = orbit direction + imgs, masks = [], [] + device = comfy.model_management.get_torch_device() + total = splat.positions.shape[0] * n_frames + pbar = comfy.utils.ProgressBar(total) if total > 1 else None + k = 0 + for i in range(splat.positions.shape[0]): + xyz, rgb, opacity, scale, rot = _gaussian_item(splat, i, device) + if opacity_threshold > 0: + keep = opacity >= opacity_threshold + xyz, rgb, opacity, scale, rot = xyz[keep], rgb[keep], opacity[keep], scale[keep], rot[keep] + base_cam = camera_info + if base_cam is None: # no camera -> default 3/4 view, auto-framed on the splat + center = xyz.mean(0) if xyz.shape[0] else torch.zeros(3, device=device) + extent = (_quantile((xyz - center).norm(dim=-1), 0.99).clamp_min(1e-4) if xyz.shape[0] + else torch.tensor(1.0, device=device)) + dist = float(extent / (math.tan(math.radians(35.0) / 2) * 0.9)) + base_cam = _orbit_camera_info(35.0, 30.0, dist, 35.0, center, device) + for fr in range(n_frames): + cam_fr = (base_cam if n_frames == 1 + else 
class CreateCameraInfo(IO.ComfyNode):  # TODO: move to better file
    # Node: build a camera_info dict from orbit angles, an explicit position aimed at a target, or an
    # explicit position plus rotation quaternion.

    @classmethod
    def define_schema(cls):
        # `mode` is a dynamic combo: each option exposes its own set of float inputs. The remaining
        # inputs (target, roll, fov, zoom, camera_type) are shared by every mode.
        return IO.Schema(
            node_id="CreateCameraInfo",
            display_name="Create Camera Info",
            search_aliases=["camera position", "make camera info", "orbit camera", "look at camera"],
            category="3d",
            # The first sentence ends with ". " so the concatenated literals no longer run together.
            description="Build a camera_info. "
                        "Mode 'orbit' aims with yaw/pitch/distance around the target; "
                        "'look_at' places the camera at a world position aimed at the target; "
                        "'quaternion' takes a world position plus an explicit camera rotation. "
                        "Coordinates are the viewer's world space (right-handed, Y-up).",
            inputs=[
                IO.DynamicCombo.Input("mode", options=[
                    IO.DynamicCombo.Option("orbit", [
                        IO.Float.Input("yaw", default=35.0, min=-360.0, max=360.0, step=1.0),
                        IO.Float.Input("pitch", default=30.0, min=-89.0, max=89.0, step=1.0),
                        IO.Float.Input("distance", default=4.0, min=0.01, max=1000.0, step=0.01,
                                       tooltip="Camera distance from the target."),
                    ]),
                    IO.DynamicCombo.Option("look_at", [
                        IO.Float.Input("position_x", default=4.0, min=-1000.0, max=1000.0, step=0.01,
                                       tooltip="Camera position in world space (right-handed, Y-up)."),
                        IO.Float.Input("position_y", default=4.0, min=-1000.0, max=1000.0, step=0.01),
                        IO.Float.Input("position_z", default=4.0, min=-1000.0, max=1000.0, step=0.01),
                    ]),
                    IO.DynamicCombo.Option("quaternion", [
                        IO.Float.Input("position_x", default=4.0, min=-1000.0, max=1000.0, step=0.01,
                                       tooltip="Camera position in world space (right-handed, Y-up)."),
                        IO.Float.Input("position_y", default=4.0, min=-1000.0, max=1000.0, step=0.01),
                        IO.Float.Input("position_z", default=4.0, min=-1000.0, max=1000.0, step=0.01),
                        IO.Float.Input("quat_x", default=0.0, min=-1.0, max=1.0, step=0.001),
                        IO.Float.Input("quat_y", default=0.0, min=-1.0, max=1.0, step=0.001),
                        IO.Float.Input("quat_z", default=0.0, min=-1.0, max=1.0, step=0.001),
                        IO.Float.Input("quat_w", default=1.0, min=-1.0, max=1.0, step=0.001,
                                       tooltip="Camera world-rotation quaternion (three.js: looks down local -Z). Normalized for you."),
                    ]),
                ], tooltip="How to define the camera: orbit angles, an explicit position, or a position + quaternion."),
                IO.Float.Input("target_x", default=0.0, min=-1000.0, max=1000.0, step=0.01, advanced=True,
                               tooltip="Look-at point (orbit pivot / aim). In orbit mode, move it to pan/translate the "
                                       "whole camera. Ignored in quaternion mode. Defaults to the origin."),
                IO.Float.Input("target_y", default=0.0, min=-1000.0, max=1000.0, step=0.01, advanced=True),
                IO.Float.Input("target_z", default=0.0, min=-1000.0, max=1000.0, step=0.01, advanced=True),
                IO.Float.Input("roll", default=0.0, min=-180.0, max=180.0, step=1.0,
                               tooltip="Camera roll about the view axis, degrees. Ignored in quaternion mode."),
                IO.Float.Input("fov", default=35.0, min=1.0, max=120.0, step=1.0,
                               tooltip="Vertical field of view in degrees."),
                IO.Float.Input("zoom", default=1.0, min=0.01, max=100.0, step=0.01,
                               tooltip="Digital zoom (focal-length multiplier). >1 zooms in without moving the camera."),
                IO.Combo.Input("camera_type", options=["perspective", "orthographic"],
                               tooltip="Projection used by Render Splat: perspective (foreshortening) or orthographic (parallel)."),
            ],
            outputs=[IO.Load3DCamera.Output(display_name="camera_info")],
        )

    @classmethod
    def execute(cls, mode, target_x, target_y, target_z, roll, fov, zoom=1.0, camera_type="perspective") -> IO.NodeOutput:
        # `mode` is a dict: mode["mode"] names the selected option and the option's own inputs sit
        # beside it under their input ids.
        dev = comfy.model_management.get_torch_device()
        kind = mode["mode"]
        if kind == "quaternion":  # explicit world position + camera rotation; target and roll are not used
            position = [mode["position_x"], mode["position_y"], mode["position_z"]]
            quat = [mode["quat_x"], mode["quat_y"], mode["quat_z"], mode["quat_w"]]
            return IO.NodeOutput(_quat_camera_info(position, quat, fov, dev, zoom=zoom, camera_type=camera_type))
        target = [target_x, target_y, target_z]  # orbit pivot / aim; move it to pan the whole camera
        if kind == "orbit":  # yaw/pitch/distance about the target (world Y-up)
            y, p = math.radians(mode["yaw"]), math.radians(mode["pitch"])
            cy, sy, cp, sp = math.cos(y), math.sin(y), math.cos(p), math.sin(p)
            d = mode["distance"]
            # Spherical offset from the target: yaw 0 / pitch 0 puts the camera on the target's +Z side.
            position = [target_x + d * cp * sy, target_y + d * sp, target_z + d * cp * cy]
        else:  # look_at: explicit world-space camera position
            position = [mode["position_x"], mode["position_y"], mode["position_z"]]
        return IO.NodeOutput(_lookat_camera_info(position, target, fov, dev, zoom=zoom, camera_type=camera_type, roll=roll))
" + "Non-uniform scale also reshapes every individual splat, slower process.", + inputs=[ + IO.Splat.Input("splat"), + IO.Float.Input("translate_x", default=0.0, min=-100.0, max=100.0, step=0.01), + IO.Float.Input("translate_y", default=0.0, min=-100.0, max=100.0, step=0.01), + IO.Float.Input("translate_z", default=0.0, min=-100.0, max=100.0, step=0.01), + IO.Float.Input("rotate_x", default=0.0, min=-360.0, max=360.0, step=1.0), + IO.Float.Input("rotate_y", default=0.0, min=-360.0, max=360.0, step=1.0), + IO.Float.Input("rotate_z", default=0.0, min=-360.0, max=360.0, step=1.0), + IO.Float.Input("scale_x", default=1.0, min=0.01, max=100.0, step=0.01), + IO.Float.Input("scale_y", default=1.0, min=0.01, max=100.0, step=0.01), + IO.Float.Input("scale_z", default=1.0, min=0.01, max=100.0, step=0.01), + ], + outputs=[IO.Splat.Output(display_name="splat")], + ) + + @classmethod + def execute(cls, splat, translate_x, translate_y, translate_z, + rotate_x, rotate_y, rotate_z, scale_x, scale_y, scale_z) -> IO.NodeOutput: + pos = splat.positions + dev, dt = pos.device, pos.dtype + q_rot = _euler_to_quat(rotate_x, rotate_y, rotate_z).to(device=dev, dtype=dt) + R = _quat_to_mat(q_rot[None])[0] # (3, 3) node rotation + D = torch.tensor([scale_x, scale_y, scale_z], dtype=dt, device=dev) + A = D[:, None] * R # diag(D) @ R: per-axis scale after rotation + t = torch.tensor([translate_x, translate_y, translate_z], dtype=dt, device=dev) + + positions = pos @ A.T + t # rotate, scale per-axis, then translate + if scale_x == scale_y == scale_z: # uniform: rotation/scale factor out cleanly + scales = splat.scales * scale_x + rotations = _quat_mul(q_rot.expand_as(splat.rotations), splat.rotations) + rotations = rotations / rotations.norm(dim=-1, keepdim=True).clamp_min(1e-12) + else: # non-uniform: transform Sigma = A R s^2 R^T A^T, re-extract + rg = _quat_to_mat(splat.rotations.reshape(-1, 4)) # (M,3,3) per-splat rotation + s2 = splat.scales.reshape(-1, 3).square() + cov = (rg * s2[:, 
def _pad_stack(items, n):
    # Batch variable-length tensors: every (L_i, *tail) entry of `items` becomes one row of a
    # (len(items), n, *tail) tensor. Rows shorter than n are left zero beyond their own length.
    # dtype/device follow items[0] (the buffer is allocated with its new_zeros).
    first = items[0]
    stacked = first.new_zeros((len(items), n) + tuple(first.shape[1:]))
    for row, src in zip(stacked, items):
        # `row` is a view into `stacked`, so the slice assignment fills the batch in place.
        row[:src.shape[0]] = src
    return stacked
+ gs = [g for g in gaussians if g is not None] + if not gs: + raise ValueError("MergeSplat: no gaussians to merge") + b = gs[0].positions.shape[0] + for g in gs: + if g.positions.shape[0] != b: + raise ValueError(f"MergeSplat: batch size mismatch ({b} vs {g.positions.shape[0]}).") + max_k = max(g.sh.shape[2] for g in gs) + + pos_b, scl_b, rot_b, op_b, sh_b, lengths = [], [], [], [], [], [] + for i in range(b): + pos_i, scl_i, rot_i, op_i, sh_i = [], [], [], [], [] + for g in gs: + end = _real_len(g, i) + pos_i.append(g.positions[i, :end]) + scl_i.append(g.scales[i, :end]) + rot_i.append(g.rotations[i, :end]) + op_i.append(g.opacities[i, :end]) + sh = g.sh[i, :end] # (end, K, 3) + if sh.shape[1] < max_k: # zero-pad lower-degree SH + sh = torch.cat([sh, sh.new_zeros(sh.shape[0], max_k - sh.shape[1], sh.shape[2])], dim=1) + sh_i.append(sh) + pos_b.append(torch.cat(pos_i)) + scl_b.append(torch.cat(scl_i)) + rot_b.append(torch.cat(rot_i)) + op_b.append(torch.cat(op_i)) + sh_b.append(torch.cat(sh_i)) + lengths.append(pos_b[-1].shape[0]) + + n = max(lengths) + counts = None + if len(set(lengths)) > 1: + counts = torch.tensor(lengths, device=gs[0].positions.device, dtype=torch.int64) + return Types.SPLAT(_pad_stack(pos_b, n), _pad_stack(scl_b, n), _pad_stack(rot_b, n), + _pad_stack(op_b, n), _pad_stack(sh_b, n), counts=counts) + + +class MergeSplat(IO.ComfyNode): + @classmethod + def define_schema(cls): + # Autogrow: a splat0/splat1/... input list that grows a fresh slot as you connect splats. + splats = IO.Autogrow.TemplatePrefix(IO.Splat.Input("splat"), prefix="splat", min=2, max=32) + return IO.Schema( + node_id="MergeSplat", + display_name="Merge Splats", + search_aliases=["union splat", "densify gaussian", "combine splat", "merge gaussian"], + category="3d/splat", + description="Concatenate any number of gaussian splats into one. 
def _inverse_covariance(scale, quat):
    # Per-splat Sigma^-1 = R diag(1/s^2) R^T. scale (N,3) linear std, quat (N,4) wxyz -> (N,3,3).
    # The quaternion is normalised first (norm clamped at 1e-12) and scales are clamped at 1e-8
    # so neither division can blow up on degenerate input.
    q = quat / quat.norm(dim=1, keepdim=True).clamp_min(1e-12)
    w, x, y, z = q.unbind(-1)
    R = torch.stack([
        1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y),
        2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x),
        2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y),
    ], dim=1).reshape(-1, 3, 3)
    inv_s2 = 1.0 / scale.clamp_min(1e-8) ** 2  # (N, 3)
    return torch.einsum("nij,nj,nkj->nik", R, inv_s2, R)


def _splat_density(xyz, opacity, scale, quat, rgb, res, kernel, device, color_sharpen=1.0, chunk=4096, progress=None,
                   col_dtype=torch.float16):
    # Splat each gaussian as its oriented-covariance disk (3-sigma, opacity-weighted) into a density grid,
    # plus a colour volume. Each gaussian uses a voxel window sized to its OWN 3-sigma (capped at `kernel`).
    # Colour is weighted by w^color_sharpen: >1 biases each voxel toward its dominant gaussian (crisper
    # texture).
    #
    # Args: xyz (N,3), opacity (N,), scale (N,3), quat (N,4) wxyz, rgb (N,3); `res` is the voxel count
    # along the longest padded axis; `chunk` bounds how many gaussians are rasterised at once;
    # `progress`, when given, is called with a fraction in [0, 1] after every chunk.
    #
    # Returns (density (dx,dy,dz), colour numerator (dx,dy,dz,3), colour normaliser (dx,dy,dz),
    # grid origin as a numpy array, voxel size as a float). With color_sharpen == 1 the normaliser
    # is the density itself.
    pad = 4.0 * scale.median()
    lo = xyz.amin(0) - pad
    hi = xyz.amax(0) + pad
    voxel = ((hi - lo).max() / res).clamp_min(1e-8)
    dx, dy, dz = (torch.ceil((hi - lo) / voxel).long() + 1).tolist()
    dims = torch.tensor([dx, dy, dz], device=device)  # grid extent, for the in-bounds test below

    sinv = _inverse_covariance(scale, quat)
    kreq = torch.ceil(3.0 * scale.amax(-1) / voxel).long().clamp(1, int(kernel))  # per-gaussian half-width
    sharp = color_sharpen != 1.0
    vol = torch.zeros(dx * dy * dz, device=device)  # Sum(w) density (surface)
    colvol = torch.zeros(dx * dy * dz, 3, device=device, dtype=col_dtype)  # Sum(w^p * rgb) colour numerator
    wcol = torch.zeros(dx * dy * dz, device=device, dtype=col_dtype) if sharp else None  # Sum(w^p) normaliser (p>1)
    n, done = xyz.shape[0], 0
    for k in range(1, int(kernel) + 1):
        sel = (kreq == k).nonzero(as_tuple=True)[0]
        if sel.numel() == 0:
            continue
        rng = torch.arange(-k, k + 1, device=device, dtype=torch.float32)
        off = torch.stack(torch.meshgrid(rng, rng, rng, indexing="ij"), -1).reshape(-1, 3)  # (M, 3)
        for st in range(0, sel.numel(), chunk):
            gi = sel[st:st + chunk]
            cc = xyz[gi]
            idx = ((cc - lo) / voxel).round()[:, None, :] + off[None]  # (b, M, 3) voxel coords
            d = (lo + idx * voxel) - cc[:, None, :]  # world offset to voxel center
            quad = torch.einsum("bmi,bij,bmj->bm", d, sinv[gi], d)
            ii = idx.long()
            # The padding is 4x the MEDIAN scale, so the window of a gaussian near the bounding box
            # can reach past the grid. Those entries were evaluated at out-of-grid positions; if they
            # were merely clamped they would be deposited on the border voxels and inflate the
            # density/colour there. Drop them instead (the clamp below is only for index safety).
            in_grid = ((ii >= 0) & (ii < dims)).all(-1)
            wgt = opacity[gi, None] * torch.exp(-0.5 * quad)
            wgt = torch.where((quad < 9.0) & in_grid, wgt, torch.zeros_like(wgt))  # clip beyond 3 sigma / grid
            ix = ii[..., 0].clamp(0, dx - 1)
            iy = ii[..., 1].clamp(0, dy - 1)
            iz = ii[..., 2].clamp(0, dz - 1)
            flat = (ix * (dy * dz) + iy * dz + iz).reshape(-1)
            vol.index_add_(0, flat, wgt.reshape(-1))
            wp = wgt.pow(color_sharpen) if sharp else wgt  # winner-take-more colour weight
            colvol.index_add_(0, flat, (wp[..., None] * rgb[gi, None, :]).reshape(-1, 3).to(col_dtype))
            if sharp:
                wcol.index_add_(0, flat, wp.reshape(-1).to(col_dtype))
            done += gi.numel()
            if progress is not None:
                progress(min(1.0, done / max(1, n)))
    colnorm = (wcol if sharp else vol).reshape(dx, dy, dz)  # p==1 -> Sum(w) == density
    return vol.reshape(dx, dy, dz), colvol.reshape(dx, dy, dz, 3), colnorm, lo.cpu().numpy(), float(voxel)
def _clean_components(verts, faces, min_verts, device=None):
    # Remove floaters (connected components with fewer than min_verts vertices) and inner shells
    # (components whose bbox lies inside the largest kept component's bbox and whose signed volume
    # has the opposite sign). Returns (verts, faces) restricted to the surviving faces, with face
    # indices remapped to the compacted vertex array.
    # Dispatches to _clean_components_gpu when `device` is a non-CPU device reporting enough free
    # memory; otherwise runs the numpy/scipy implementation below.
    use_gpu = device is not None and not comfy.model_management.is_device_cpu(device)
    if use_gpu and comfy.model_management.get_free_memory(device) > 10 * faces.size * 8:
        return _clean_components_gpu(verts, faces, min_verts, device)

    n_verts = len(verts)
    pairs = np.concatenate([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [0, 2]]], 0)
    graph = coo_matrix((np.ones(len(pairs)), (pairs[:, 0], pairs[:, 1])), shape=(n_verts, n_verts))
    n_comp, vert_comp = connected_components(graph, directed=False)
    face_comp = vert_comp[faces[:, 0]]  # a face belongs to the component of its first vertex
    keep = np.bincount(vert_comp, minlength=n_comp) >= min_verts  # vertex-count gate per component

    if keep.sum() > 1:
        faces_per_comp = np.bincount(face_comp, minlength=n_comp)
        main = np.where(keep, faces_per_comp, -1).argmax()  # kept component with the most faces
        p0, p1, p2 = verts[faces[:, 0]], verts[faces[:, 1]], verts[faces[:, 2]]
        # Sum of v0 . (v1 x v2) per component = 6 * signed volume; only its sign is used.
        signed_vol = np.bincount(face_comp, weights=np.einsum("ij,ij->i", p0, np.cross(p1, p2)),
                                 minlength=n_comp)
        comp_ids = np.arange(n_comp)
        bb_min = np.stack([_ndi_minimum(verts[:, axis], vert_comp, comp_ids) for axis in range(3)], 1)
        bb_max = np.stack([_ndi_maximum(verts[:, axis], vert_comp, comp_ids) for axis in range(3)], 1)
        tol = 1e-4 * (bb_max[main] - bb_min[main]).max()
        within = (bb_min >= bb_min[main] - tol).all(1) & (bb_max <= bb_max[main] + tol).all(1)
        flipped = np.sign(signed_vol) != np.sign(signed_vol[main])
        keep &= ~(within & flipped & (comp_ids != main))

    kept = faces[keep[face_comp]]
    if len(kept) == 0:
        return verts[:0], kept
    used = np.unique(kept)
    remap = np.full(n_verts, -1, np.int64)
    remap[used] = np.arange(len(used))
    return verts[used], remap[kept]
+ vol = vol.to(device=device, dtype=torch.float32) + dx, dy, dz = vol.shape + origin_t = torch.as_tensor(origin, device=device, dtype=torch.float32) + empty = (np.zeros((0, 3), np.float32), np.zeros((0, 3), np.int64)) + if dx < 2 or dy < 2 or dz < 2: + return empty + + # Active = cells whose 8 corners aren't all in/all out. + inside = vol >= level # (dx,dy,dz) bool + cs8 = [inside[ox:ox + dx - 1, oy:oy + dy - 1, oz:oz + dz - 1] + for ox, oy, oz in ((0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0), + (0, 0, 1), (1, 0, 1), (0, 1, 1), (1, 1, 1))] + any_in = cs8[0] | cs8[1] | cs8[2] | cs8[3] | cs8[4] | cs8[5] | cs8[6] | cs8[7] + all_in = cs8[0] & cs8[1] & cs8[2] & cs8[3] & cs8[4] & cs8[5] & cs8[6] & cs8[7] + active = any_in & ~all_in # (cx,cy,cz) straddling cells + nv = int(active.sum()) + if nv == 0: + return empty + + # Active cells only (a thin shell): each dual vertex = mean of its 12 edges' zero-crossings. + del any_in, all_in, cs8 # corner bool grids no longer needed + ac = active.nonzero(as_tuple=False) # (nv,3) cell min-corner indices + offs = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], + [0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1]], device=device) + offf = offs.to(torch.float32) + edges = torch.tensor([[0, 1], [0, 2], [0, 4], [1, 3], [1, 5], [2, 3], + [2, 6], [3, 7], [4, 5], [4, 6], [5, 7], [6, 7]], device=device) + e0, e1 = edges[:, 0], edges[:, 1] + oe0, oe1 = offf[e0], offf[e1] # (12,3) edge endpoints + + cstep = 1 << 18 # chunk to bound peak memory (CPU RAM too) + loc = [] + for st in range(0, nv, cstep): + ci = ac[st:st + cstep, None, :] + offs[None] # (m,8,3) + cval = vol[ci[..., 0], ci[..., 1], ci[..., 2]] # (m,8) corner values + csl = cval >= level + v0, v1 = cval[:, e0], cval[:, e1] # (m,12) + cross = (csl[:, e0] != csl[:, e1])[..., None].to(torch.float32) + denom = v1 - v0 + t = torch.where(denom.abs() > 1e-12, (level - v0) / denom, torch.full_like(denom, 0.5)).clamp(0, 1) + pts = torch.lerp(oe0, oe1, t[..., None]) # (m,12,3) local crossings 
def _otsu_level(values, bins=256):
    # Otsu threshold over a `bins`-bin histogram of `values`: returns (as a float) the centre of the
    # bin at which splitting the histogram maximises the between-class variance. Splits that leave
    # either class empty are excluded; on ties the lowest bin wins (np.argmax).
    counts, bin_edges = np.histogram(values, bins=bins)
    counts = counts.astype(np.float64)
    mids = 0.5 * (bin_edges[:-1] + bin_edges[1:])

    w_bg = np.cumsum(counts)  # weight of the lower class for a split after each bin
    w_fg = w_bg[-1] - w_bg  # weight of the upper class
    moment = np.cumsum(counts * mids)
    mean_bg = moment / np.where(w_bg > 0, w_bg, 1.0)
    mean_fg = (moment[-1] - moment) / np.where(w_fg > 0, w_fg, 1.0)

    between = w_bg * w_fg * (mean_bg - mean_fg) ** 2
    one_sided = (w_bg <= 0) | (w_fg <= 0)
    between[one_sided] = -1.0  # never pick a split with an empty class
    best = int(np.argmax(between))
    return float(mids[best])
+ dx, dy, dz = colnorm.shape + vt = torch.as_tensor(verts, device=device, dtype=torch.float32) + org = torch.as_tensor(origin, device=device, dtype=torch.float32) + gi = (vt - org) / voxel # (V,3) grid-index coords (x,y,z) + size = torch.tensor([dx, dy, dz], device=device, dtype=torch.float32) + g = 2.0 * gi / (size - 1).clamp_min(1.0) - 1.0 # -> [-1,1] (align_corners) + grid = torch.stack([g[:, 2], g[:, 1], g[:, 0]], -1)[None, None, None] # (1,1,1,V,3): grid_sample order (W=z,H=y,D=x) + + def samp(v): # (dx,dy,dz,C) cpu fp16 -> (C,V) fp32 on device + inp = v.to(device).permute(3, 0, 1, 2)[None].float() + o = torch.nn.functional.grid_sample(inp, grid, mode="bilinear", padding_mode="border", align_corners=True) + return o[0, :, 0, 0, :] + num = samp(colvol) # (3,V) + den = samp(colnorm[..., None]) # (1,V) + return (num / den.clamp_min(1e-8)).T.cpu().numpy() # (V,3) + + +def _gaussian_to_mesh(g: Types.SPLAT, i, res, kernel, taubin, level_bias, min_component, min_opacity, color_sharpen, device, progress=None): + # Mesh one splat: density + colour grids -> Surface Nets -> floater removal -> Taubin smoothing -> + # volume-sampled colours. Returns (verts, faces int64, colors in [0,1]), or None if no surface. 
+ rep = progress if progress is not None else (lambda *_: None) + + end = _real_len(g, i) + xyz = g.positions[i, :end].to(device=device, dtype=torch.float32) + scale = g.scales[i, :end].to(device=device, dtype=torch.float32) + quat = g.rotations[i, :end].to(device=device, dtype=torch.float32) + opacity = g.opacities[i, :end].reshape(-1).to(device=device, dtype=torch.float32) + rgb = (g.sh[i, :end, 0, :].to(device=device, dtype=torch.float32) * _C0 + 0.5).clamp(0, 1) + + keep = opacity >= min_opacity + xyz, scale, quat, opacity, rgb = xyz[keep], scale[keep], quat[keep], opacity[keep], rgb[keep] + if xyz.shape[0] == 0: + return None + + vol, colvol, colnorm, origin, voxel = _splat_density(xyz, opacity, scale, quat, rgb, res, kernel, device, + color_sharpen=color_sharpen, + progress=lambda f: rep(0.25 * f)) # density build: 0 -> 25% + # Colour: sample on the GPU (grid_sample) when there's headroom + colour_gpu = not comfy.model_management.is_device_cpu(device) and comfy.model_management.get_free_memory(device) > 6 * vol.numel() * 4 + if colour_gpu: + colvol_cpu, colnorm_cpu = colvol.cpu(), colnorm.half().cpu() # park colours (fp16) off-GPU during meshing + colvol_np = colnorm_np = None + else: + colvol_np = colvol.cpu().numpy().astype(np.float32) # Sum(w^p * rgb) colour numerator (fp16 grid -> fp32) + colnorm_np = colnorm.cpu().numpy().astype(np.float32) # Sum(w^p) colour normaliser + del colvol, colnorm # free the colour grids before iso-surfacing + rep(0.40) + + vmin, vmax = float(vol.min()), float(vol.max()) + occ = vol[vol > vmax * 1e-3] # occupied voxels (skip the empty-space peak) + if occ.numel() == 0: + return None + # Otsu picks the inside/outside split principledly; `level_bias` nudges it (1.0 = auto). Clamp strictly + # inside the data range so a bias can't push the iso off the histogram. 
+ level = min(max(_otsu_level(occ.cpu().numpy()) * level_bias, vmin + 1e-6 * (vmax - vmin)), + vmax - 1e-6 * (vmax - vmin)) + + # Iso-surface on the accelerator when there's headroom: ~15x faster than CPU, identical output. Chunked + # Surface Nets peaks at ~3-3.5x the density grid, so fall back to CPU for large grids / tight VRAM. + sn_dev = device + if not comfy.model_management.is_device_cpu(device) and comfy.model_management.get_free_memory(device) < 6 * vol.numel() * 4: + sn_dev = torch.device("cpu") + vol = vol.cpu() + verts, faces = _surface_nets(vol, level, voxel, origin, sn_dev) + del vol + rep(0.55) + if min_component > 0 and len(faces) > 0: + verts, faces = _clean_components(verts, faces, min_component, device) + if len(verts) == 0 or len(faces) == 0: + return None + + # Taubin smooths the blocky iso without shrinking it (unlike blurring the density, which rounds features). + verts = _taubin_smooth(verts, faces, taubin) + rep(0.7) + + # Colour each vertex from the co-splatted colour volume: trilinearly sample the numerator Sum(w^p*rgb) + # and normaliser Sum(w^p) separately, then divide. Normalising AFTER interpolation keeps zero-density + # edge voxels from pulling colours toward black, and matches the gaussians that formed the surface. + if colour_gpu: + col = _sample_vertex_colours_gpu(colvol_cpu, colnorm_cpu, verts, origin, voxel, device) + else: + coords = ((verts - origin) / voxel).T # (3, V) grid-index coords, matching volume axes + num = np.stack([map_coordinates(colvol_np[..., c], coords, order=1, mode="nearest") for c in range(3)], -1) + den = map_coordinates(colnorm_np, coords, order=1, mode="nearest") + col = num / np.clip(den, 1e-8, None)[:, None] + rep(1.0) + + # The unlit material's COLOR_0 is linear and the viewer sRGB-encodes it on output; the splat colours + # are display (sRGB) values, so convert sRGB -> linear here to land at the same brightness as the splat. 
+ col = np.clip(col, 0, 1) + col = np.where(col <= 0.04045, col / 12.92, ((col + 0.055) / 1.055) ** 2.4).astype(np.float32) + + # Splat +Y is glTF's -Y: rotate 180 deg about X (negate Y,Z) to land upright. Proper rotation, so + # winding is kept; done after colouring (which works in the splat frame). + verts = np.ascontiguousarray(verts * np.array([1.0, -1.0, -1.0], dtype=np.float32)) + return (torch.from_numpy(verts), torch.from_numpy(faces), torch.from_numpy(col)) + + +class SplatToMesh(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="SplatToMesh", + display_name="Extract Mesh from Splat", + search_aliases=["splat to mesh", "gaussian surface nets", "splat surface", "mesh splat"], + category="3d/splat", + description="Extract a coloured mesh from a gaussian splat.", + inputs=[ + IO.Splat.Input("splat"), + IO.Int.Input("resolution", default=384, min=64, max=768, step=16, + tooltip="Density-grid resolution along the longest axis. Higher = finer surface, " + "more VRAM/time (grows with resolution^3)."), + IO.Int.Input("kernel", default=5, min=1, max=8, + tooltip="Max splat half-width in voxels. Each gaussian is rasterized over a window " + "sized to its own 3-sigma, capped here - small surfels stay cheap, large ones " + "aren't truncated. Raise if sparse splats leave gaps."), + IO.Int.Input("smooth", default=0, min=0, max=60, advanced = True, + tooltip="Taubin mesh-smoothing iterations. Smooths the surface without shrinking it " + "(volume-preserving), unlike blurring the density. 0 = raw surface."), + IO.Float.Input("level", default=0.4, min=0.0, max=2.0, step=0.01, + tooltip="Iso-surface level. Auto-picked by Otsu; this biases it (1.0 = auto, lower = " + "fatter/more-connected surface, higher = thinner/tighter)."), + IO.Int.Input("min_component", default=500, min=0, max=100000, step=50, advanced=True, + tooltip="Drop connected components smaller than this many vertices (0 = keep all). 
" + "Removes detached floater blobs and the inner shell of the double wall."), + IO.Float.Input("min_opacity", default=0.02, min=0.0, max=1.0, step=0.01, advanced=True, + tooltip="Ignore gaussians fainter than this before meshing."), + IO.Float.Input("color_sharpen", default=2.0, min=1.0, max=8.0, step=0.5, + tooltip="Crisp up the vertex texture: 1.0 = physically-correct blend; higher biases " + "each voxel's colour toward its dominant gaussian instead of averaging " + "neighbours (de-smears the texture). Colour only - geometry is unchanged."), + ], + outputs=[IO.Mesh.Output(display_name="mesh")], + ) + + @classmethod + def execute(cls, splat, resolution, kernel, smooth, level, min_component, min_opacity, color_sharpen) -> IO.NodeOutput: + device = comfy.model_management.get_torch_device() + b = splat.positions.shape[0] + prec = 1000 # each splat owns a 0..prec block of the bar; its callback advances within that block + pbar = comfy.utils.ProgressBar(b * prec) + + verts_l, faces_l, colors_l = [], [], [] + for i in range(b): + cb = lambda f, base=i * prec: pbar.update_absolute(base + int(min(max(f, 0.0), 1.0) * prec)) + res = _gaussian_to_mesh(splat, i, resolution, kernel, smooth, level, min_component, min_opacity, color_sharpen, device, cb) + if res is None: + logging.warning("SplatToMesh: splat %d produced no surface; emitting an empty mesh.", i) + v, f, c = torch.zeros((0, 3)), torch.zeros((0, 3), dtype=torch.int64), torch.zeros((0, 3)) + else: + v, f, c = res + verts_l.append(v) + faces_l.append(f) + colors_l.append(c) + pbar.update_absolute((i + 1) * prec) # snap to block end (covers empty / early-out splats) + # unlit: render flat (emissive-like) so SaveGLB matches the splat instead of lighting/washing it. 
+ return IO.NodeOutput(pack_variable_mesh_batch(verts_l, faces_l, colors=colors_l, unlit=True)) + + +class GaussianExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [SplatToFile3D, File3DToSplat, RenderSplat, CreateCameraInfo, TransformSplat, + GetSplatCount, MergeSplat, SplatToMesh] + + +async def comfy_entrypoint() -> GaussianExtension: + return GaussianExtension() diff --git a/comfy_extras/nodes_math.py b/comfy_extras/nodes_math.py index 873ee7b51..0883c65ac 100644 --- a/comfy_extras/nodes_math.py +++ b/comfy_extras/nodes_math.py @@ -102,11 +102,18 @@ class MathExpressionNode(io.ComfyNode): f"Math Expression '{expression}' must evaluate to a numeric result, " f"got {type(result).__name__}: {result!r}" ) - if not math.isfinite(result): + try: + float_result = float(result) + except OverflowError: + raise ValueError( + f"Math Expression '{expression}' produced a result too large to " + f"represent as a float: {result}" + ) from None + if not math.isfinite(float_result): raise ValueError( f"Math Expression '{expression}' produced a non-finite result: {result}" ) - return io.NodeOutput(float(result), int(result), bool(result)) + return io.NodeOutput(float_result, int(result), bool(result)) class MathExtension(ComfyExtension): diff --git a/comfy_extras/nodes_save_3d.py b/comfy_extras/nodes_save_3d.py index c03524246..a91549e7f 100644 --- a/comfy_extras/nodes_save_3d.py +++ b/comfy_extras/nodes_save_3d.py @@ -16,7 +16,7 @@ from comfy.cli_args import args from comfy_api.latest import ComfyExtension, IO, Types -def pack_variable_mesh_batch(vertices, faces, colors=None, uvs=None, texture=None): +def pack_variable_mesh_batch(vertices, faces, colors=None, uvs=None, texture=None, unlit=False): # Pack lists of (Nᵢ, *) vertex/face/color/uv tensors into padded batched tensors, # stashing per-item lengths as runtime attrs so consumers can recover the real slice. 
# colors and uvs are 1:1 with vertices, so they're padded to max_vertices and read with vertex_counts. @@ -54,7 +54,7 @@ def pack_variable_mesh_batch(vertices, faces, colors=None, uvs=None, texture=Non return Types.MESH(packed_vertices, packed_faces, uvs=packed_uvs, vertex_colors=packed_colors, texture=texture, - vertex_counts=vertex_counts, face_counts=face_counts) + vertex_counts=vertex_counts, face_counts=face_counts, unlit=unlit) def get_mesh_batch_item(mesh, index): @@ -77,7 +77,7 @@ def get_mesh_batch_item(mesh, index): def save_glb(vertices, faces, filepath, metadata=None, - uvs=None, vertex_colors=None, texture_image=None): + uvs=None, vertex_colors=None, texture_image=None, unlit=False): """ Save PyTorch tensor vertices and faces as a GLB file without external dependencies. @@ -234,6 +234,17 @@ def save_glb(vertices, faces, filepath, metadata=None, textures = [] samplers = [] materials = [] + extensions_used = [] + if unlit and texture_png_bytes is None: + # Flat, light-independent shading (KHR_materials_unlit): COLOR_0 is shown as-is, matching how a + # gaussian splat renders (emissive). Without this the viewer lights the mesh and washes the colours. 
def _preprocess(image: torch.Tensor, mask: torch.Tensor, erode_radius: int, size: int) -> torch.Tensor:
    # Match original preprocessing:
    # resize min side to `size` -> erode alpha -> alpha bbox -> 1.2x square crop -> resize -> composite on black.
    # `image` is indexed as (H, W, C) and `mask` as (H, W). Returns a (1, size, size, 3) tensor.
    # Raises ValueError when no alpha pixel is > 0 after erosion.
    rgb = image[..., :3].clamp(0, 1).movedim(-1, 0)  # (3, H, W)
    alpha = mask.clamp(0, 1)[None]  # (1, H, W)
    rgba = torch.cat([rgb, alpha], 0)[None]  # (1, 4, H, W)

    h, w = rgba.shape[-2:]
    s = size / min(w, h)
    rgba = comfy.utils.common_upscale(rgba, max(1, round(w * s)), max(1, round(h * s)), "lanczos", "disabled").clamp(0, 1)

    a = rgba[:, 3:4]
    if erode_radius > 0:
        # min filter over a (2r+1) window == morphological erosion of the alpha matte.
        a = -F.max_pool2d(-a, 2 * erode_radius + 1, stride=1, padding=erode_radius)
        rgba = torch.cat([rgba[:, :3], a], 1)

    ys, xs = torch.nonzero(a[0, 0] > 0, as_tuple=True)
    if xs.numel() == 0:
        raise ValueError("TripoSplatPreprocessImage: mask is empty (no foreground pixels).")
    x0, x1 = int(xs.min()), int(xs.max())
    y0, y1 = int(ys.min()), int(ys.max())
    cx, cy = (x0 + x1) / 2, (y0 + y1) / 2
    half = max(x1 - x0, y1 - y0) / 2 * 1.2
    left, upper, right, lower = int(cx - half), int(cy - half), int(cx + half), int(cy + half)
    # A one- or two-pixel foreground (half == 0, or int() truncating both edges to 0 at the image
    # origin) would give a zero-width/height crop that the resize below cannot handle; keep >= 1 px.
    right, lower = max(right, left + 1), max(lower, upper + 1)

    H, W = rgba.shape[-2:]
    crop = rgba.new_zeros((1, 4, lower - upper, right - left))  # out-of-bounds stays 0, matching PIL.crop
    sx0, sy0, sx1, sy1 = max(left, 0), max(upper, 0), min(right, W), min(lower, H)
    if sx1 > sx0 and sy1 > sy0:
        crop[:, :, sy0 - upper:sy1 - upper, sx0 - left:sx1 - left] = rgba[:, :, sy0:sy1, sx0:sx1]

    crop = comfy.utils.common_upscale(crop, size, size, "lanczos", "disabled").clamp(0, 1)
    out = (crop[:, :3] * crop[:, 3:4])[0].movedim(0, -1)  # composite over black == rgb * alpha
    return out.unsqueeze(0)  # (1, size, size, 3)
max=16, + tooltip="Erode the alpha matte by this pixel radius before cropping (avoids border bleed)."), + IO.Int.Input("size", default=1024, min=256, max=4096, step=16, + tooltip="Square image size. The model is trained at 1024; other sizes run but are off-distribution."), + ], + outputs=[IO.Image.Output(display_name="image")], + ) + + @classmethod + def execute(cls, image, mask, erode_radius, size) -> IO.NodeOutput: + size = max(16, (int(size) // 16) * 16) # DINOv3 patch / Flux2 VAE stride is 16 + if mask.shape[0] != image.shape[0]: + mask = comfy.utils.repeat_to_batch_size(mask, image.shape[0]) + if tuple(mask.shape[1:]) != tuple(image.shape[1:3]): + mask = F.interpolate(mask[:, None].float(), size=tuple(image.shape[1:3]), mode="bilinear", align_corners=False)[:, 0] + prepared = torch.cat([_preprocess(image[i], mask[i], erode_radius, size) for i in range(image.shape[0])], dim=0) + return IO.NodeOutput(prepared) + + +class TripoSplatConditioning(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="TripoSplatConditioning", + display_name="TripoSplat Conditioning", + category="3d/conditioning", + description="Encode the image with DINOv3 and the Flux2 VAE into TripoSplat positive/negative " + "conditioning, and create the fixed size noise target (latent + camera) for the KSampler", + inputs=[ + IO.ClipVision.Input("clip_vision", tooltip="DINOv3 ViT-H/16+ image encoder"), + IO.Vae.Input("vae", tooltip="Flux2 VAE"), + IO.Image.Input("image"), + ], + outputs=[ + IO.Conditioning.Output(display_name="positive"), + IO.Conditioning.Output(display_name="negative"), + IO.Latent.Output(display_name="latent", tooltip="The fixed size noise target (latent +camera)."), + ], + ) + + @classmethod + def execute(cls, clip_vision, vae, image) -> IO.NodeOutput: + # feature1: DINOv3 token sequence (cls + registers + patches), ImageNet-normalized, with a final non-affine layer norm on top + comfy.model_management.load_model_gpu(clip_vision.patcher) + 
device = clip_vision.load_device + img = image.movedim(-1, 1).to(device) # (B,3,H,W) in [0,1] + mean = torch.tensor(_DINOV3_MEAN, device=device).view(1, 3, 1, 1) + std = torch.tensor(_DINOV3_STD, device=device).view(1, 3, 1, 1) + img = (img - mean) / std + seq = clip_vision.model(pixel_values=img.float())[0] + feature1 = F.layer_norm(seq.float(), seq.shape[-1:]).to(comfy.model_management.intermediate_device()) + + # Second conditioning: the Flux2 VAE latent of the image, carried as a standard reference_latents entry + ref = vae.encode(image).to(comfy.model_management.intermediate_device()) # (B, 128, H, W) + b = ref.shape[0] + + positive = [[feature1, {"reference_latents": [ref]}]] + negative = [[torch.zeros_like(feature1), {"reference_latents": [torch.zeros_like(ref)]}]] + + # Fixed noise target: the latent is a constant-shape (8192, 16) shape-code + a (1, 5) camera token + dev = comfy.model_management.intermediate_device() + latent_seq = torch.zeros([b, _Q_TOKEN_LENGTH, _LATENT_CHANNELS], device=dev) + camera = torch.zeros([b, 1, _CAM_CHANNELS], device=dev) + samples = comfy.nested_tensor.NestedTensor((latent_seq, camera)) + return IO.NodeOutput(positive, negative, {"samples": samples}) + + +class VAEDecodeTripoSplat(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="VAEDecodeTripoSplat", + display_name="TripoSplat Decode", + category="3d/latent", + description="Decode the sampled TripoSplat latent into a 3D gaussian splat. " + "Modify the number of gaussians to vary the density.", + inputs=[ + IO.Latent.Input("samples"), + IO.Vae.Input("vae", tooltip="TripoSplat VAE decoder"), + IO.Int.Input("num_gaussians", default=262144, min=_NUM_GAUSSIANS_MIN, max=_NUM_GAUSSIANS_MAX, step=32, + tooltip="Number of gaussians to produce (rounded to a multiple of 32). 
" + "262144 matches the octree's point density; higher oversamples the same points " + "(denser, but no new detail) and costs proportionally more VRAM/time."), + IO.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff, + tooltip="Seeds the octree point sampler (global RNG) for deterministic decodes."), + ], + outputs=[IO.Splat.Output(display_name="splat")], + ) + + @classmethod + def execute(cls, samples, vae, num_gaussians, seed) -> IO.NodeOutput: + s = samples["samples"] + latent = s.unbind()[0] if getattr(s, "is_nested", False) else s # take the latent stream, drop camera + + decoder = vae.first_stage_model + gpp = decoder.gaussians_per_point + n = max(_NUM_GAUSSIANS_MIN, min(_NUM_GAUSSIANS_MAX, int(num_gaussians))) + if n % gpp != 0: + n = round(n / gpp) * gpp + + dtype_size = comfy.model_management.dtype_size(vae.vae_dtype) + hidden = decoder.gs.model_channels + cond_tokens = latent.shape[1] + memory_required = (cond_tokens * 4 + (n // gpp) * 10) * hidden * dtype_size + comfy.model_management.load_models_gpu([vae.patcher], memory_required=memory_required) + latent = latent.to(device=vae.device, dtype=vae.vae_dtype) + generator = torch.Generator(device="cpu").manual_seed(seed) + parts = [g.render_tensors() for g in decoder.decode(latent, num_gaussians=n, generator=generator)] + positions, scales, rotations, opacities, sh = (torch.stack(t) for t in zip(*parts)) + return IO.NodeOutput(Types.SPLAT(positions, scales, rotations, opacities, sh)) + + +class TripoSplatSamplingPreview(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="TripoSplatSamplingPreview", + display_name="TripoSplat Sampling Preview", + category="3d/latent", + description="Patch the TripoSplat model for the standard Ksampler node to show a live decoded " + "gaussian splat preview at each step.", + inputs=[ + IO.Model.Input("model"), + IO.Vae.Input("vae", tooltip="TripoSplat VAE decoder"), + IO.Int.Input("octree_level", default=5, min=2, max=8, 
advanced=True, + tooltip="Octree depth for the preview decode (lower = cheaper/coarser)."), + IO.Int.Input("num_gaussians", default=16384, min=1024, max=262144, step=32, + tooltip="Number of gaussians to produce for the preview (rounded to a multiple of 32)."), + IO.Float.Input("yaw", default=90.0, min=-360.0, max=360.0, step=1.0, tooltip="Preview camera yaw in degrees.", advanced=True,), + IO.Float.Input("pitch", default=15.0, min=-89.0, max=89.0, step=1.0, tooltip="Preview camera pitch in degrees.", advanced=True,), + IO.Int.Input("point_size", default=3, min=1, max=16, + tooltip="Maximum splat radius in pixels. Each gaussian is sized from its scale and capped here; " + "lower = finer/pointier, higher = chunkier."), + ], + outputs=[IO.Model.Output()], + ) + + @classmethod + def execute(cls, model, vae, octree_level, num_gaussians, yaw, pitch, point_size) -> IO.NodeOutput: + from comfy.ldm.triposplat.preview import decode_x0_to_image + cfg = {"gaussians": num_gaussians, "level": octree_level, "yaw": yaw, "pitch": pitch, + "point_size": point_size} + + fsm = vae.first_stage_model + cond_tokens = model.model.diffusion_model.q_token_length + memory_required = (cond_tokens * 4 + (num_gaussians // fsm.gaussians_per_point) * 10) * fsm.gs.model_channels * comfy.model_management.dtype_size(vae.vae_dtype) + + # Live preview via WrappersMP.OUTER_SAMPLE + ProgressBar + # The wrapper augments the sampler's own callback to decode x0 -> gaussian splat -> preview image each step + def outer_sample_wrapper(executor, *args, **kwargs): + args = list(args) + cb_idx = 5 # outer_sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) + orig_cb = args[cb_idx] if len(args) > cb_idx else kwargs.get("callback") + state = {"ok": True, "pbar": None, "loaded": False} + + def callback(step, x0, x, total_steps): + if orig_cb is not None: + orig_cb(step, x0, x, total_steps) + if not state["ok"]: + return + try: + if not state["loaded"]: + loaded_models = 
comfy.model_management.loaded_models(only_currently_used=True) + loaded_models.append(vae.patcher) + comfy.model_management.load_models_gpu(loaded_models, memory_required=memory_required) + state["loaded"] = True + img = decode_x0_to_image(vae, x0, cfg) + if state["pbar"] is None: + state["pbar"] = comfy.utils.ProgressBar(total_steps) + state["pbar"].update_absolute(step + 1, total_steps, ("JPEG", img, 512)) + except Exception as e: + logging.warning("TripoSplatSamplingPreview: preview failed, disabling ({})".format(e)) + state["ok"] = False + + if len(args) > cb_idx: + args[cb_idx] = callback + else: + kwargs["callback"] = callback + return executor(*args, **kwargs) + + m = model.clone() + m.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "triposplat_sampling_preview", outer_sample_wrapper) + return IO.NodeOutput(m) + + +class TripoSplatExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + TripoSplatPreprocessImage, + TripoSplatConditioning, + VAEDecodeTripoSplat, + TripoSplatSamplingPreview, + ] + + +async def comfy_entrypoint() -> TripoSplatExtension: + return TripoSplatExtension() diff --git a/comfyui_version.py b/comfyui_version.py index 0bb0f780c..19e8f8cfc 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. 
-__version__ = "0.22.0" +__version__ = "0.23.0" diff --git a/main.py b/main.py index bce451a83..239a52013 100644 --- a/main.py +++ b/main.py @@ -464,13 +464,6 @@ def start_comfyui(asyncio_loop=None): folder_paths.set_temp_directory(temp_dir) cleanup_temp() - if args.windows_standalone_build: - try: - import new_updater - new_updater.update_windows_updater() - except: - pass - if not asyncio_loop: asyncio_loop = asyncio.new_event_loop() asyncio.set_event_loop(asyncio_loop) diff --git a/new_updater.py b/new_updater.py deleted file mode 100644 index 9a203acdd..000000000 --- a/new_updater.py +++ /dev/null @@ -1,35 +0,0 @@ -import os -import shutil - -base_path = os.path.dirname(os.path.realpath(__file__)) - - -def update_windows_updater(): - top_path = os.path.dirname(base_path) - updater_path = os.path.join(base_path, ".ci/update_windows/update.py") - bat_path = os.path.join(base_path, ".ci/update_windows/update_comfyui.bat") - - dest_updater_path = os.path.join(top_path, "update/update.py") - dest_bat_path = os.path.join(top_path, "update/update_comfyui.bat") - dest_bat_deps_path = os.path.join(top_path, "update/update_comfyui_and_python_dependencies.bat") - - try: - with open(dest_bat_path, 'rb') as f: - contents = f.read() - except: - return - - if not contents.startswith(b"..\\python_embeded\\python.exe .\\update.py"): - return - - shutil.copy(updater_path, dest_updater_path) - try: - with open(dest_bat_deps_path, 'rb') as f: - contents = f.read() - contents = contents.replace(b'..\\python_embeded\\python.exe .\\update.py ..\\ComfyUI\\', b'call update_comfyui.bat nopause') - with open(dest_bat_deps_path, 'wb') as f: - f.write(contents) - except: - pass - shutil.copy(bat_path, dest_bat_path) - print("Updated the windows standalone package updater.") # noqa: T201 diff --git a/nodes.py b/nodes.py index 528bf316f..331425b87 100644 --- a/nodes.py +++ b/nodes.py @@ -2455,6 +2455,8 @@ async def init_builtin_extra_nodes(): "nodes_save_3d.py", "nodes_moge.py", 
"nodes_mediapipe.py", + "nodes_gaussian_splat.py", + "nodes_triposplat.py" ] import_failed = [] diff --git a/pyproject.toml b/pyproject.toml index 1e449b4a3..e118800e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.22.0" +version = "0.23.0" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" diff --git a/requirements.txt b/requirements.txt index 14bba1437..b09d31a8b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.44.19 -comfyui-workflow-templates==0.9.91 +comfyui-workflow-templates==0.9.92 comfyui-embedded-docs==0.5.2 torch torchsde diff --git a/tests-unit/comfy_extras_test/nodes_math_test.py b/tests-unit/comfy_extras_test/nodes_math_test.py index 714e37c32..030accc5e 100644 --- a/tests-unit/comfy_extras_test/nodes_math_test.py +++ b/tests-unit/comfy_extras_test/nodes_math_test.py @@ -197,3 +197,10 @@ class TestMathExpressionExecute: def test_pow_huge_exponent_raises(self): with pytest.raises(ValueError, match="Exponent .* exceeds maximum"): self._exec("pow(a, b)", a=10, b=10000000) + + def test_huge_int_result_raises_value_error(self): + # Exponent is within the allowed MAX_EXPONENT range, so the result is a + # finite Python int that is nonetheless too large to convert to float. + # This must raise a clean ValueError, not an uncaught OverflowError. + with pytest.raises(ValueError, match="too large to represent as a float"): + self._exec("2 ** 3999")