From 462c27fdb2b84e612cdd4b3c7fac8875b04eda43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Mon, 1 Jun 2026 17:01:50 +0300 Subject: [PATCH] feat: Add TripoSplat support (#14210) --- comfy/clip_vision.py | 16 +- comfy/image_encoders/dino3.py | 260 ++++++++++++++++++ comfy/latent_formats.py | 10 + comfy/ldm/triposplat/gaussian.py | 199 ++++++++++++++ comfy/ldm/triposplat/model.py | 326 +++++++++++++++++++++++ comfy/ldm/triposplat/preview.py | 91 +++++++ comfy/ldm/triposplat/vae.py | 382 +++++++++++++++++++++++++++ comfy/model_base.py | 19 ++ comfy/model_detection.py | 3 + comfy/sd.py | 11 + comfy/supported_models.py | 25 ++ comfy_extras/nodes_gaussian_splat.py | 3 +- comfy_extras/nodes_triposplat.py | 269 +++++++++++++++++++ nodes.py | 1 + 14 files changed, 1612 insertions(+), 3 deletions(-) create mode 100644 comfy/image_encoders/dino3.py create mode 100644 comfy/ldm/triposplat/gaussian.py create mode 100644 comfy/ldm/triposplat/model.py create mode 100644 comfy/ldm/triposplat/preview.py create mode 100644 comfy/ldm/triposplat/vae.py create mode 100644 comfy_extras/nodes_triposplat.py diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 1691fca81..337575191 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -2,6 +2,7 @@ from .utils import load_torch_file, transformers_convert, state_dict_prefix_repl import os import json import logging +import torch import comfy.ops import comfy.model_patcher @@ -9,6 +10,7 @@ import comfy.model_management import comfy.utils import comfy.clip_model import comfy.image_encoders.dino2 +import comfy.image_encoders.dino3 class Output: def __getitem__(self, key): @@ -23,12 +25,16 @@ IMAGE_ENCODERS = { "siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection, "siglip2_vision_model": comfy.clip_model.CLIPVisionModelProjection, "dinov2": comfy.image_encoders.dino2.Dinov2Model, + "dinov3": comfy.image_encoders.dino3.DINOv3ViTModel, } class 
ClipVisionModel(): def __init__(self, json_config): - with open(json_config) as f: - config = json.load(f) + if isinstance(json_config, dict): + config = json_config + else: + with open(json_config) as f: + config = json.load(f) self.image_size = config.get("image_size", 224) self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073]) @@ -44,6 +50,10 @@ class ClipVisionModel(): self.load_device = comfy.model_management.text_encoder_device() offload_device = comfy.model_management.text_encoder_offload_device() self.dtype = comfy.model_management.text_encoder_dtype(self.load_device) + if self.model_type == "dinov3" and self.dtype == torch.float16: + # DINOv3's activations borderline fits fp16, preferring bf16 if available for better stability #TODO: further fp16 tests in practice + if comfy.model_management.should_use_bf16(self.load_device, prioritize_performance=True): + self.dtype = torch.bfloat16 self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast) self.model.eval() @@ -134,6 +144,8 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False): json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json") elif 'encoder.layer.23.layer_scale2.lambda1' in sd: json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_large.json") + elif 'layer.0.mlp.gate_proj.weight' in sd and 'layer.31.norm1.weight' in sd: # Dinov3 ViT-H/16+ (SwiGLU gated MLP, 32 layers) + json_config = comfy.image_encoders.dino3.DINOV3_VITH_CONFIG else: return None diff --git a/comfy/image_encoders/dino3.py b/comfy/image_encoders/dino3.py new file mode 100644 index 000000000..9bd42a66b --- /dev/null +++ b/comfy/image_encoders/dino3.py @@ -0,0 +1,260 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +from comfy.ldm.modules.attention import optimized_attention_for_device +from 
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the second half.

    For a last dimension split as ``(a, b)`` the result is ``(-b, a)``.
    """
    split_at = x.shape[-1] // 2
    front, back = x[..., :split_at], x[..., split_at:]
    return torch.cat([back.neg(), front], dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, **kwargs):
    """Apply rotary position embeddings to the patch tokens of ``q`` and ``k``.

    The number of patch tokens is taken from ``sin.shape[-2]``; every token in
    front of them along dim -2 is treated as a prefix token and passed through
    unrotated. Extra keyword arguments are accepted and ignored.

    Returns:
        The ``(q, k)`` pair with the trailing patch tokens rotated.
    """
    patch_count = sin.shape[-2]
    # The prefix length is derived from q and reused for k, as both are split identically.
    prefix_count = q.shape[-2] - patch_count

    def _rotate_patches(tensor):
        prefix, patches = tensor.split((prefix_count, patch_count), dim=-2)
        rotated = (patches * cos) + (rotate_half(patches) * sin)
        return torch.cat((prefix, rotated), dim=-2)

    return _rotate_patches(q), _rotate_patches(k)
hidden_size + self.num_heads = num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + + self.k_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=False, device=device, dtype=dtype) # key_bias = False + self.v_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype) + self.q_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype) + self.o_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype) + + def forward(self, hidden_states, attention_mask=None, position_embeddings=None, **kwargs): + batch_size, patches, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is not None: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + attn = optimized_attention_for_device(query_states.device, mask=False) + attn_output = attn( + query_states, key_states, value_states, self.num_heads, attention_mask, + skip_reshape=True, skip_output_reshape=True, low_precision_attention=False, + ) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(batch_size, patches, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output + + +class DINOv3ViTGatedMLP(nn.Module): + def __init__(self, hidden_size, intermediate_size, mlp_bias, device, dtype, operations, act="silu"): + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.gate_proj = 
def get_patches_center_coordinates(num_patches_h, num_patches_w, dtype, device):
    """Return patch-centre coordinates of an H x W patch grid, scaled to [-1, 1].

    Centres are ``(i + 0.5) / count`` along each axis, then mapped with
    ``2 * c - 1``. The result has shape ``(num_patches_h * num_patches_w, 2)``
    in row-major order, with the last dimension holding ``(h, w)``.
    """
    def _unit_centers(count):
        # Patch centres along one axis, expressed as fractions of the axis length.
        return torch.arange(0.5, count, dtype=dtype, device=device) / count

    rows, cols = torch.meshgrid(
        _unit_centers(num_patches_h), _unit_centers(num_patches_w), indexing="ij"
    )
    grid = torch.stack((rows, cols), dim=-1).flatten(0, 1)
    return 2.0 * grid - 1.0
class DINOv3ViTEmbeddings(nn.Module):
    """Patch embedding plus learned prefix tokens (CLS and register tokens).

    The output sequence is ordered ``[cls, registers..., patches...]``.
    Token parameters are allocated with ``torch.empty`` and therefore hold
    arbitrary values until weights are loaded.
    """

    def __init__(self, hidden_size, num_register_tokens, num_channels, patch_size, dtype, device, operations):
        super().__init__()
        tensor_kwargs = {"device": device, "dtype": dtype}

        def _token(count):
            return nn.Parameter(torch.empty(1, count, hidden_size, **tensor_kwargs))

        # Registration order is kept: cls_token, mask_token, register_tokens, patch_embeddings.
        self.cls_token = _token(1)
        self.mask_token = _token(1)
        self.register_tokens = _token(num_register_tokens)
        self.patch_embeddings = operations.Conv2d(
            num_channels, hidden_size, kernel_size=patch_size, stride=patch_size, **tensor_kwargs
        )

    def forward(self, pixel_values, bool_masked_pos=None):
        """Embed ``pixel_values`` (B, C, H, W) into a (B, prefix + patches, hidden) sequence.

        Where ``bool_masked_pos`` is given and true, the patch embedding is
        replaced by the mask token.
        """
        conv = self.patch_embeddings
        # Non-overlapping conv, then (B, hidden, h, w) -> (B, h*w, hidden).
        patches = conv(pixel_values.to(dtype=conv.weight.dtype)).flatten(2).transpose(1, 2)

        if bool_masked_pos is not None:
            replacement = self.mask_token.to(patches.dtype)
            patches = torch.where(bool_masked_pos.unsqueeze(-1), replacement, patches)

        batch = pixel_values.shape[0]
        prefix = [
            token.expand(batch, -1, -1).to(patches.device)
            for token in (self.cls_token, self.register_tokens)
        ]
        return torch.cat([*prefix, patches], dim=1)
class DINOv3ViTModel(nn.Module):
    """DINOv3 ViT backbone: token embeddings, rotary position embeddings, a stack
    of transformer layers and a final LayerNorm, all sized from a config dict."""

    def __init__(self, config, dtype, device, operations):
        super().__init__()
        # Required keys are read up front, in this order, so a missing key raises KeyError
        # before any submodule is built.
        required = (
            "num_hidden_layers", "hidden_size", "num_attention_heads", "num_register_tokens",
            "intermediate_size", "layer_norm_eps", "num_channels", "patch_size", "rope_theta",
        )
        (depth, width, heads, registers, mlp_width,
         eps, in_chans, patch, theta) = (config[key] for key in required)
        gated = config.get("use_gated_mlp", False)
        gate_act = config.get("gated_mlp_act", "silu")
        placement = dict(dtype=dtype, device=device)

        self.embeddings = DINOv3ViTEmbeddings(
            width, registers, num_channels=in_chans, patch_size=patch,
            operations=operations, **placement
        )
        self.rope_embeddings = DINOv3ViTRopePositionEmbedding(
            theta, width, heads, patch_size=patch, **placement
        )
        blocks = []
        for _ in range(depth):
            blocks.append(DINOv3ViTLayer(
                width, eps, use_gated_mlp=gated, mlp_bias=True,
                intermediate_size=mlp_width, num_attention_heads=heads,
                operations=operations, gated_mlp_act=gate_act, **placement))
        self.layer = nn.ModuleList(blocks)
        self.norm = operations.LayerNorm(width, eps=eps, **placement)

    def get_input_embeddings(self):
        """Return the patch-embedding convolution."""
        return self.embeddings.patch_embeddings

    def forward(self, pixel_values, bool_masked_pos=None, **kwargs):
        """Encode an image batch.

        Keyword ``skip_norm_elementwise=True`` replaces the learned final norm
        with a parameter-free ``F.layer_norm`` over the last dimension.

        Returns:
            ``(sequence_output, None, pooled_output, None)`` where
            ``pooled_output`` is token 0 of ``sequence_output``.
        """
        compute_dtype = self.embeddings.patch_embeddings.weight.dtype
        pixel_values = pixel_values.to(compute_dtype)
        tokens = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
        rope = self.rope_embeddings(pixel_values)

        for block in self.layer:
            tokens = block(tokens, position_embeddings=rope)

        if not kwargs.get("skip_norm_elementwise", False):
            # The norm module is moved to the activations' device before use.
            final_norm = self.norm.to(tokens.device)
            sequence_output = final_norm(tokens)
        else:
            sequence_output = F.layer_norm(tokens, tokens.shape[-1:])
        return sequence_output, None, sequence_output[:, 0, :], None
class GaussianModel:
    """Container for a set of decoded 3D gaussians.

    Raw (pre-activation) tensors are assigned through the underscore-prefixed
    properties ``_xyz``, ``_opacity``, ``_scaling``, ``_rotation`` and
    ``_features_dc``. The ``_xyz``, ``_opacity`` and ``_scaling`` setters also
    cache the activated value, readable via ``get_xyz``, ``get_opacity`` and
    ``get_scaling``. Assigning ``None`` to one of those three clears both the
    raw and the activated entry.

    Args:
        aabb: six numbers; raw positions map to ``raw * aabb[3:] + aabb[:3]``.
        sh_degree: stored on the instance.
        mininum_kernel_size: added in quadrature to the activated scale
            (the parameter name is kept as spelled for caller compatibility).
        scaling_bias: activated scale a raw value of zero maps to, before
            the minimum kernel size is applied.
        opacity_bias: opacity a raw value of zero maps to.
        scaling_activation: ``"exp"`` or ``"softplus"``.
        device: device for the bias tensors and ``aabb``.

    Raises:
        ValueError: if ``scaling_activation`` is not a supported name.
    """

    def __init__(self, aabb: list, sh_degree: int = 0, mininum_kernel_size: float = 0.0,
                 scaling_bias: float = 0.01, opacity_bias: float = 0.1,
                 scaling_activation: str = "exp", device=None):
        self.sh_degree = sh_degree
        self.mininum_kernel_size = mininum_kernel_size
        self.scaling_bias = scaling_bias
        self.opacity_bias = opacity_bias
        self.device = device
        self.aabb = torch.tensor(aabb, dtype=torch.float32, device=device)

        if scaling_activation == "exp":
            self._scaling_activation = torch.exp
            self._inverse_scaling_activation = torch.log
        elif scaling_activation == "softplus":
            self._scaling_activation = F.softplus
            # Inverse of softplus: x + log(1 - exp(-x)).
            self._inverse_scaling_activation = lambda x: x + torch.log(-torch.expm1(-x))
        else:
            # Previously an unknown name fell through and surfaced as an AttributeError
            # on the missing activation attributes a few lines below.
            raise ValueError(
                f"Unsupported scaling_activation {scaling_activation!r}; expected 'exp' or 'softplus'"
            )

        self._opacity_activation = torch.sigmoid
        self._inverse_opacity_activation = lambda x: torch.log(x / (1 - x))

        # Biases live in pre-activation space so that a raw value of zero activates to the bias.
        self.scale_bias = self._inverse_scaling_activation(torch.tensor(self.scaling_bias)).to(self.device)
        self.rots_bias = torch.zeros(4, device=self.device)
        self.rots_bias[0] = 1  # identity quaternion in wxyz order
        self.opacity_bias_val = self._inverse_opacity_activation(torch.tensor(self.opacity_bias)).to(self.device)

        self._storage = {}

    def _get_store(self, name):
        """Return the stored tensor for ``name``, or ``None`` if unset."""
        return self._storage.get(name)

    def _set_store(self, name, value):
        """Store ``value`` under ``name``."""
        self._storage[name] = value

    @property
    def _xyz(self):
        return self._get_store("_xyz")

    @_xyz.setter
    def _xyz(self, value):
        if value is None:
            self._set_store("_xyz", None)
            self._set_store("xyz", None)
            return
        self._set_store("_xyz", value)
        # Scale by aabb[3:] and offset by aabb[:3].
        self._set_store("xyz", value * self.aabb[None, 3:] + self.aabb[None, :3])

    @property
    def get_xyz(self):
        return self._get_store("xyz")

    @property
    def _features_dc(self):
        return self._get_store("_features_dc")

    @_features_dc.setter
    def _features_dc(self, value):
        self._set_store("_features_dc", value)

    @property
    def _opacity(self):
        return self._get_store("_opacity")

    @_opacity.setter
    def _opacity(self, value):
        if value is None:
            self._set_store("_opacity", None)
            self._set_store("opacity", None)
            return
        self._set_store("_opacity", value)
        self._set_store("opacity", self._opacity_activation(value + self.opacity_bias_val))

    @property
    def get_opacity(self):
        return self._get_store("opacity")

    @property
    def _scaling(self):
        return self._get_store("_scaling")

    @_scaling.setter
    def _scaling(self, value):
        if value is None:
            self._set_store("_scaling", None)
            self._set_store("scaling", None)
            return
        self._set_store("_scaling", value)
        s = self._scaling_activation(value + self.scale_bias)
        # Combine with the minimum kernel size in quadrature: sqrt(s^2 + k^2).
        s = torch.square(s) + self.mininum_kernel_size ** 2
        self._set_store("scaling", torch.sqrt(s))

    @property
    def get_scaling(self):
        return self._get_store("scaling")

    @property
    def _rotation(self):
        return self._get_store("_rotation")

    @_rotation.setter
    def _rotation(self, value):
        self._set_store("_rotation", value)

    _DEFAULT_TRANSFORM = [[1, 0, 0], [0, 0, -1], [0, 1, 0]]

    def render_tensors(self):
        """Return render-ready tensors for the generic SPLAT type.

        The axis transform (a 3x3 rotation, object frame -> viewer Y-up) is
        baked into positions and rotations.

        Returns:
            ``(positions, scales, rotations, opacities, sh)`` as contiguous
            float tensors on the intermediate device: positions (N,3),
            scales (N,3) linear, rotations (N,4) wxyz, opacities (N,1) in
            [0,1], sh (N,K,3) coefficients.
        """
        xyz = self.get_xyz.float()
        scaling = self.get_scaling.float()
        opacity = self.get_opacity.float()
        rotation = (self._rotation + self.rots_bias[None, :]).float()
        sh = self._features_dc.float()  # (N, K, 3)
        T = torch.as_tensor(self._DEFAULT_TRANSFORM, dtype=torch.float32, device=xyz.device)
        xyz = xyz @ T.T
        rotation = _matrix_to_quat(torch.matmul(T, _quat_to_matrix(rotation)))
        rotation = rotation / torch.linalg.norm(rotation, dim=-1, keepdim=True)
        out_device = comfy.model_management.intermediate_device()
        return (
            xyz.to(out_device).contiguous(), scaling.to(out_device).contiguous(),
            rotation.to(out_device).contiguous(), opacity.to(out_device).contiguous(),
            sh.to(out_device).contiguous(),
        )
def build_gaussian_models(decoder, points_pred: dict, pred: dict):
    """Assemble one GaussianModel per batch item from the decoder's layout.

    ``decoder`` must provide ``layout``, ``rep_config`` and ``_get_offset``.
    For every layout entry: ``_xyz`` is set to the predicted offsets added to
    the anchor points, ``_xyz_center`` and ``_offset_scale`` are skipped, and
    any other entry is sliced out of ``pred["features"]`` by its ``range``,
    reshaped by its ``shape`` and multiplied by ``rep_config['lr'][name]``.

    Returns:
        A list with one GaussianModel per item along dim 0 of the features.
    """
    features = pred["features"]
    offsets = decoder._get_offset(features)
    skipped = ('_xyz_center', '_offset_scale')
    models = []
    for batch_idx in range(features.shape[0]):
        cfg = decoder.rep_config
        model = GaussianModel(
            sh_degree=0,
            aabb=[-0.5, -0.5, -0.5, 1.0, 1.0, 1.0],
            mininum_kernel_size=cfg['filter_kernel_size_3d'],
            scaling_bias=cfg['scaling_bias'],
            opacity_bias=cfg['opacity_bias'],
            scaling_activation=cfg['scaling_activation'],
            device=features.device,
        )
        # Insert a singleton axis so the anchors broadcast against the offsets.
        anchors = points_pred["points"][batch_idx, :, None, :]
        for attr, spec in decoder.layout.items():
            if attr in skipped:
                continue
            if attr == '_xyz':
                value = (offsets[batch_idx] + anchors).flatten(0, 1)
            else:
                start, stop = spec['range'][0], spec['range'][1]
                chunk = features[batch_idx][:, start:stop]
                value = chunk.reshape(-1, *spec['shape']).flatten(0, 1) * decoder.rep_config['lr'][attr]
            setattr(model, attr, value)
        models.append(model)
    return models
class MultiHeadRMSNorm(nn.Module):
    """RMS norm followed by a learned per-head gain ``gamma`` of shape (heads, dim).

    ``gamma`` is allocated with ``torch.empty`` and holds arbitrary values
    until weights are loaded.
    """

    def __init__(self, dim, heads, dtype=None, device=None):
        super().__init__()
        gain = torch.empty(heads, dim, dtype=dtype, device=device)
        self.gamma = nn.Parameter(gain)

    def forward(self, x):
        normed = comfy.rmsnorm.rms_norm(x)
        # Cast the gain to the activation's dtype and device before scaling.
        gain = comfy.model_management.cast_to(self.gamma, x.dtype, x.device)
        return normed * gain
class PcdAbsolutePositionEmbedder(nn.Module):
    """Parameter-free sinusoidal absolute position embedding.

    Two fixed frequency schedules are used in TripoSplat: "pow2" (flow-model
    latent anchors) and "log2" (octree / gaussian decoders). Any schedule
    name other than "pow2" selects the log2 schedule.
    """

    def __init__(self, channels: int, in_channels: int = 3, max_res: int = 16, schedule: str = "pow2"):
        super().__init__()
        self.channels = channels
        self.in_channels = in_channels
        self.max_res = max_res
        self.schedule = schedule
        # Frequencies per input coordinate; each one yields a sin and a cos channel.
        self.freq_dim = channels // in_channels // 2

    def _freqs(self, device):
        """Return the ``freq_dim`` frequencies of the configured schedule."""
        if self.schedule != "pow2":
            exponents = torch.linspace(0.0, float(self.max_res), steps=self.freq_dim,
                                       dtype=torch.float32, device=device)
            return torch.pow(2.0, exponents)

        octave_exps = torch.arange(self.max_res, dtype=torch.float32, device=device)
        extra = max(0, self.freq_dim - self.max_res)
        if extra > 0:
            # Frequencies beyond max_res octaves are spread evenly over [0, max_res).
            extra_exps = torch.arange(extra, dtype=torch.float32, device=device) / extra * self.max_res
        else:
            extra_exps = torch.empty(0, device=device)
        exponents = torch.cat([octave_exps, extra_exps], dim=0)[:self.freq_dim]
        return torch.pow(2.0, exponents) * 2.0  # *2 folds this schedule's 2*pi into the shared *pi below

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed coordinates ``x`` of shape (..., D) into (..., channels).

        The computation runs in float32 and the result is cast back to
        ``x.dtype``. Zero padding fills the tail when ``D * 2 * freq_dim`` is
        smaller than ``channels``.
        """
        in_dtype = x.dtype
        coords = x.float()
        *lead_shape, _last = coords.shape
        phases = torch.outer(coords.reshape(-1), self._freqs(coords.device)) * torch.pi
        emb = torch.cat([phases.sin(), phases.cos()], dim=-1).reshape(*lead_shape, -1)
        missing = self.channels - emb.shape[-1]
        if missing > 0:
            padding = torch.zeros(*lead_shape, missing, device=emb.device, dtype=emb.dtype)
            emb = torch.cat([emb, padding], dim=-1)
        return emb.to(in_dtype)
class MLP(nn.Module):
    """Two-layer feed-forward network: Linear -> GELU (tanh approximation) -> Linear."""

    def __init__(self, in_channels, hidden_channels, out_channels, dtype=None, device=None, operations=None):
        super().__init__()
        factory = dict(dtype=dtype, device=device)
        # Kept in an nn.Sequential so parameter names stay mlp.0.* and mlp.2.*.
        stages = [
            operations.Linear(in_channels, hidden_channels, **factory),
            nn.GELU(approximate="tanh"),
            operations.Linear(hidden_channels, out_channels, **factory),
        ]
        self.mlp = nn.Sequential(*stages)

    def forward(self, x):
        return self.mlp(x)
class TimestepEmbedder(nn.Module):
    """Embed scalar timesteps: sinusoidal features followed by a Linear-SiLU-Linear MLP."""

    def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
        super().__init__()
        factory = dict(dtype=dtype, device=device)
        self.mlp = nn.Sequential(
            operations.Linear(frequency_embedding_size, hidden_size, bias=True, **factory),
            nn.SiLU(),
            operations.Linear(hidden_size, hidden_size, bias=True, **factory),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """Return sinusoidal embeddings of shape (len(t), dim) for 1-D timesteps ``t``.

        The first ``dim // 2`` channels are cosines and the next ``dim // 2``
        are sines, at frequencies ``max_period ** (-i / (dim // 2))``. An odd
        ``dim`` gets one trailing zero channel.
        """
        half = dim // 2
        # Frequencies are computed on the CPU in float32, then moved to t's device.
        steps = torch.arange(start=0, end=half, dtype=torch.float32)
        freqs = torch.exp(-np.log(max_period) * steps / half).to(device=t.device)
        phases = t[:, None].float() * freqs[None]
        parts = [torch.cos(phases), torch.sin(phases)]
        if dim % 2:
            parts.append(torch.zeros_like(parts[0][:, :1]))
        return torch.cat(parts, dim=-1)

    def forward(self, t):
        features = self.timestep_embedding(t, self.frequency_embedding_size)
        # Match the first Linear's weight dtype before running the MLP.
        return self.mlp(features.to(self.mlp[0].weight.dtype))
class LatentSeqMMFlowModel(nn.Module):
    # Flow denoiser over a nested latent: a sequence of latent tokens plus a camera token.
    # Latent tokens and conditioning tokens each pass through a small stack of refiner blocks, are
    # concatenated with the embedded camera token, and then run through the main transformer blocks.
    # Returns one prediction for the latent tokens and one for the camera token.
    def __init__(self, image_model=None, q_token_length=8192, in_channels=16, model_channels=1024,
                 cond_channels=1280, out_channels=16, num_blocks=24, num_refiner_blocks=2,
                 num_heads=None, num_head_channels=64, cam_channels=5, cond2_channels=128,
                 mlp_ratio=4, share_mod=True, qk_rms_norm=True,
                 dtype=None, device=None, operations=None, **kwargs):
        super().__init__()
        self.dtype = dtype
        self.q_token_length = q_token_length
        self.in_channels = in_channels
        self.cam_channels = cam_channels
        self.model_channels = model_channels
        self.cond_channels = cond_channels
        self.cond2_channels = cond2_channels
        self.out_channels = out_channels
        self.num_blocks = num_blocks
        self.num_refiner_blocks = num_refiner_blocks
        # Head count falls back to model_channels // num_head_channels when not given explicitly.
        self.num_heads = num_heads or model_channels // num_head_channels
        self.mlp_ratio = mlp_ratio
        self.share_mod = share_mod
        self.qk_rms_norm = qk_rms_norm

        factory_kwargs = dict(dtype=dtype, device=device)
        op_kwargs = dict(operations=operations, **factory_kwargs)

        self.t_embedder = TimestepEmbedder(model_channels, **op_kwargs)
        if share_mod:
            # One shared modulation projection (6 * model_channels) computed once per step in _forward
            # and handed to every modulated block as `mod`.
            self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(model_channels, 6 * model_channels, bias=True, **factory_kwargs))

        self.input_layer = operations.Linear(in_channels, model_channels, **factory_kwargs)
        self.cond_embedder = operations.Linear(cond_channels, model_channels, **factory_kwargs)
        # Optional second conditioning embedder; disabled by passing cond2_channels=None.
        self.cond_embedder2 = operations.Linear(cond2_channels, model_channels, **factory_kwargs) if cond2_channels is not None else None

        # Fixed Sobol (low-discrepancy) 3D anchor positions for the latent tokens, used as positional encoding.
        # The embedder is parameter-free and the anchors are fixed, precompute once.
        sobol_seq = torch.quasirandom.SobolEngine(dimension=3, scramble=True, seed=123).draw(q_token_length)
        pos_emb = PcdAbsolutePositionEmbedder(model_channels)(sobol_seq.unsqueeze(0))
        # persistent=False: the buffer is rebuilt here and is not part of the state dict.
        self.register_buffer("pos_emb", pos_emb, persistent=False)

        # RePo3DRotaryEmbedding layers for the refiner and main blocks
        repo_kwargs = dict(num_heads=self.num_heads, head_dim=num_head_channels, **op_kwargs)
        self.noise_repo_layers = nn.ModuleList(
            [RePo3DRotaryEmbedding(model_channels, **repo_kwargs) for _ in range(num_refiner_blocks)])
        self.context_repo_layers = nn.ModuleList(
            [RePo3DRotaryEmbedding(model_channels, **repo_kwargs) for _ in range(num_refiner_blocks)])
        self.repo_layers = nn.ModuleList(
            [RePo3DRotaryEmbedding(model_channels, **repo_kwargs) for _ in range(num_blocks)])

        # Refiner blocks
        block_kwargs = dict(num_heads=self.num_heads, mlp_ratio=self.mlp_ratio, use_rope=True, qk_rms_norm=self.qk_rms_norm, **op_kwargs)
        # Latent-token refiner is timestep-modulated; the context refiner is not (modulation=False).
        self.noise_refiner = nn.ModuleList(
            [UnifiedTransformerBlock(model_channels, modulation=True, share_mod=self.share_mod, **block_kwargs) for _ in range(num_refiner_blocks)])
        self.context_refiner = nn.ModuleList(
            [UnifiedTransformerBlock(model_channels, modulation=False, **block_kwargs) for _ in range(num_refiner_blocks)])

        # Embeds the camera vector (cam_channels) into model_channels.
        self.cam_refiner = MLP(self.cam_channels, model_channels, model_channels, **op_kwargs)

        self.blocks = nn.ModuleList(
            [UnifiedTransformerBlock(model_channels, modulation=True, share_mod=self.share_mod, **block_kwargs) for _ in range(num_blocks)])

        # Learned (shift, scale) table added to the timestep embedding for the final output modulation.
        self.shift_table = nn.Parameter(torch.empty(1, 2, model_channels, **factory_kwargs))
        self.out_layer = operations.Linear(model_channels, out_channels, **factory_kwargs)
        self.cam_out_layer = operations.Linear(model_channels, cam_channels, **factory_kwargs)

    def forward(self, x, t, context=None, ref_latents=None, transformer_options={}, **kwargs):
        # Runs _forward through any DIFFUSION_MODEL wrappers registered in transformer_options.
        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
            self._forward,
            self,
            comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
        ).execute(x, t, context, ref_latents, transformer_options, **kwargs)

    def _forward(self, x, t, context=None, ref_latents=None, transformer_options={}, **kwargs):
        # x is the unpacked nested latent: [latent (B,8192,in_channels), camera (B,1,cam_channels)].
        # context == feature1.
        z, camera = x[0], x[1]
        feat1 = context

        h_x = self.input_layer(z)
        h_cond = self.cond_embedder(feat1)
        if ref_latents is not None and self.cond_embedder2 is not None:
            # Flatten the Flux2 VAE latent (B,128,h,w) to a token sequence and front-pad to feat1's length
            # (the pad count = feat1's prefix tokens: DINOv3 cls + registers), then add to the context.
            feat2 = ref_latents[0].flatten(2).transpose(1, 2)
            # F.pad pairs run from the last dim backwards: (0, 0) leaves channels alone, the second
            # pair pads only the front of the token dimension.
            feat2 = F.pad(feat2, (0, 0, feat1.shape[1] - feat2.shape[1], 0))
            h_cond = h_cond + self.cond_embedder2(feat2.to(h_cond.dtype))
        t_emb = self.t_embedder(t)
        # With share_mod the blocks receive the already-projected modulation; otherwise the raw embedding.
        t_mod = self.adaLN_modulation(t_emb) if self.share_mod else t_emb

        h_x = h_x + self.pos_emb.to(z)

        for i, block in enumerate(self.noise_refiner):
            h_x = block(h_x, mod=t_mod, rotary_emb=self.noise_repo_layers[i](h_x), transformer_options=transformer_options)

        for i, block in enumerate(self.context_refiner):
            h_cond = block(h_cond, mod=None, rotary_emb=self.context_repo_layers[i](h_cond), transformer_options=transformer_options)

        cam = camera.to(z)
        h_cam = self.cam_refiner(cam)
        # Joint sequence order is [latent tokens | context tokens | camera tokens]; the slices below rely on it.
        h = torch.cat([h_x, h_cond, h_cam], dim=1)

        for i, block in enumerate(self.blocks):
            h = block(h, mod=t_mod, rotary_emb=self.repo_layers[i](h), transformer_options=transformer_options)

        # Parameter-free layer norm in float32, then cast back to the latent's dtype/device.
        h_x = F.layer_norm(h[:, :z.shape[1]].float(), h.shape[-1:]).to(z)
        h_cam = F.layer_norm(h[:, -cam.shape[1]:].float(), h.shape[-1:]).to(z)

        # Final modulation: (shift_table + t_emb) split into shift and scale, applied as shift + h * (1 + scale).
        shift, scale = (comfy.model_management.cast_to(self.shift_table, t_emb.dtype, t_emb.device) + t_emb.unsqueeze(1)).chunk(2, dim=1)
        scale = 1 + scale
        h_x = torch.addcmul(shift, h_x, scale)
        h_cam = torch.addcmul(shift, h_cam, scale)

        return self.out_layer(h_x), self.cam_out_layer(h_cam)
+import numpy as np +from PIL import Image + +_C0 = 0.28209479177387814 +_LATENT_TOKENS = 8192 # q_token_length +_LATENT_CH = 16 # in_channels +_OBJECT_TO_VIEWER = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]], np.float32) # object frame -> viewer Y-up frame + + +def _view_matrix(yaw_deg, pitch_deg): + y, p = np.radians(yaw_deg), np.radians(pitch_deg) + Ry = np.array([[np.cos(y), 0, np.sin(y)], [0, 1, 0], [-np.sin(y), 0, np.cos(y)]], np.float32) + Rx = np.array([[1, 0, 0], [0, np.cos(p), -np.sin(p)], [0, np.sin(p), np.cos(p)]], np.float32) + return Rx @ Ry + + +def render_splat(xyz, rgb, scale, opacity=None, yaw=35.0, pitch=30.0, size=320, min_px=2, gain=1.0, + max_px=9, min_opacity=0.0, fov=35.0, dist=2.2): + # Project gaussian centers with a perspective camera and paint each as a filled disk whose screen + # radius follows the gaussian's world-space scale, composited with a nearest-wins z-buffer. + # gain scales the footprint (≈ std spanned), `min_px`/`max_px` clamp the on-screen radius. + + pts = xyz.astype(np.float32) @ _OBJECT_TO_VIEWER.T + v = pts @ _view_matrix(yaw, pitch).T + zc = v[:, 2] + dist + keep = zc > 1e-2 + if opacity is not None and min_opacity > 0.0: # culls gaussians with very low opacity + keep = keep & (opacity > min_opacity) + v, zc, scale = v[keep], zc[keep], scale[keep] + col = (np.clip(rgb, 0, 1)[:, :3] * 255).astype(np.uint8)[keep] + if v.shape[0] == 0: + return Image.fromarray(np.zeros((size, size, 3), np.uint8)) + f = (size / 2) / np.tan(np.radians(fov) / 2) + cx = size / 2 + f * v[:, 0] / zc + cy = size / 2 + f * v[:, 1] / zc + radius = np.clip(np.round(f * scale / zc * gain), min_px, max_px).astype(np.int32) + + # Expand each splat to its disk pixels, bucketed by integer radius so it stays vectorized. 
+ px, py, pz, pc = [], [], [], [] + for r in range(int(radius.min()), int(radius.max()) + 1): + m = radius == r + if not m.any(): + continue + dy, dx = np.mgrid[-r:r + 1, -r:r + 1] + disk = (dx * dx + dy * dy) <= r * r + ox, oy = dx[disk], dy[disk] + px.append((cx[m, None] + ox).ravel()) + py.append((cy[m, None] + oy).ravel()) + pz.append(np.repeat(zc[m], ox.size)) + pc.append(np.repeat(col[m], ox.size, axis=0)) + px, py = np.concatenate(px), np.concatenate(py) + pz, pc = np.concatenate(pz), np.concatenate(pc) + xi = np.clip(px, 0, size - 1).astype(np.int64) + yi = np.clip(py, 0, size - 1).astype(np.int64) + + # Nearest-wins z-buffer: pack (quantized depth, source index), per-pixel min picks the closest + # splat, then decode the winning index back to its color. + pid = yi * size + xi + q = np.clip((pz * 1024.0).astype(np.int64), 0, (1 << 20) - 1) # near = small + key = (q << 32) | np.arange(pid.size, dtype=np.int64) + buf = np.full(size * size, 1 << 62, np.int64) + np.minimum.at(buf, pid, key) + img = np.zeros((size * size, 3), np.uint8) + hit = buf < (1 << 62) + img[hit] = pc[buf[hit] & 0xFFFFFFFF] + return Image.fromarray(img.reshape(size, size, 3)) + + +def _extract_latent(x0): + # x0 from the sampler callback is the nested latent packed to (B, 1, TOKENS*CH + 1*5); + # the plain single-latent case is (B, TOKENS, CH). Return the (B, TOKENS, CH) latent stream. + if x0.ndim == 3 and x0.shape[1] == _LATENT_TOKENS and x0.shape[2] == _LATENT_CH: + return x0 + flat = x0.reshape(x0.shape[0], -1) + return flat[:, :_LATENT_TOKENS * _LATENT_CH].reshape(x0.shape[0], _LATENT_TOKENS, _LATENT_CH) + + +def decode_x0_to_image(decoder, x0, cfg): + # Decode x0 at a coarse octree level / few gaussians and render a preview image. 
def decode_x0_to_image(decoder, x0, cfg):
    # Decode x0 at a coarse octree level / few gaussians and render a preview image.
    # `cfg` is a dict of optional overrides: "gaussians", "level", "yaw", "pitch", "size", "point_size".
    import torch  # local import: the rest of this module only needs numpy / PIL

    latent = _extract_latent(x0)
    fsm = decoder.first_stage_model
    # Previews run inside the sampling loop. A private fixed-seed CPU generator keeps the octree
    # sampling from consuming the global torch RNG, and makes successive previews use the same draws.
    generator = torch.Generator().manual_seed(0)
    gaussian = fsm.decode(latent.to(decoder.device, decoder.vae_dtype),
                          num_gaussians=cfg.get("gaussians", 16384), level=cfg.get("level", 5),
                          generator=generator)[0]  # preview only the first batch item
    xyz = gaussian.get_xyz.float().cpu().numpy()
    rgb = gaussian._features_dc.float().cpu().numpy()[:, 0, :] * _C0 + 0.5
    scale = gaussian.get_scaling.float().cpu().numpy().max(axis=1)  # per-splat world radius (largest axis)
    opacity = gaussian.get_opacity.float().cpu().numpy()[:, 0]
    return render_splat(xyz, rgb, scale, opacity=opacity, yaw=cfg.get("yaw", 35.0), pitch=cfg.get("pitch", 30.0),
                        size=cfg.get("size", 320), min_px=1, gain=1.0, max_px=cfg.get("point_size", 3),
                        min_opacity=0.01)
# Quasi-random sampling utilities (pure functions, dtype/device-agnostic)

PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53]


def radical_inverse(base, n):
    # Mirror the base-`base` digits of the non-negative integer `n` around the radix point,
    # e.g. radical_inverse(2, 3): 0b11 -> 0.11 (binary) = 0.75.
    val = 0
    inv_base = 1.0 / base
    inv_base_n = inv_base
    while n > 0:
        digit = n % base
        val += digit * inv_base_n
        n //= base
        inv_base_n *= inv_base
    return val


def halton_sequence(dim, n):
    # n-th Halton point in `dim` dimensions, one prime base per dimension (dim <= len(PRIMES)).
    return [radical_inverse(PRIMES[i], n) for i in range(dim)]


def hammersley_sequence(dim, n, num_samples):
    # n-th point of a `num_samples`-point Hammersley set: n / num_samples followed by a (dim - 1)-D Halton point.
    return [n / num_samples] + halton_sequence(dim - 1, n)


def sample_probs(probs, counts, generator=None):
    # Systematic resampling: distribute counts[r] draws across the P bins of row r.
    # probs: (..., P) non-negative weights; rows are normalised here, all-zero rows become uniform.
    # counts: integer tensor with one entry per row of `probs`.
    # generator: optional CPU torch.Generator for the single random offset drawn per row.
    # Returns a long tensor of shape (*counts.shape, P) whose rows sum to counts.
    batch_shape = counts.shape
    R = counts.numel()
    P = probs.size(-1)
    device = probs.device
    probs = probs.reshape(R, P).to(torch.float32).clamp_min(0)
    counts = counts.reshape(R).to(device=device, dtype=torch.long)

    # Nothing to draw: no rows at all, or every row asks for zero samples.
    Nmax = int(counts.max()) if R > 0 else 0
    if Nmax <= 0:
        return counts.new_zeros(*batch_shape, P)

    # Normalise every row with a positive sum. Clamping the divisor to 1 would leave rows summing
    # to less than 1 unnormalised, so their CDF would stop short of 1 and the excess draws would
    # all land in the last bin.
    row_sums = probs.sum(1, keepdim=True)
    tiny = torch.finfo(probs.dtype).tiny
    probs = torch.where(row_sums > 0, probs / row_sums.clamp_min(tiny), probs.new_tensor(1.0 / P))
    cdf = probs.cumsum(dim=1).clamp(max=1.0 - 1e-12)

    cnt = counts.clamp_min(1).float().unsqueeze(1)  # (R, 1)
    grid = torch.arange(Nmax, device=device, dtype=torch.float32).unsqueeze(0)  # (1, Nmax)
    u = (torch.rand(R, 1, generator=generator).to(device) + grid) / cnt  # (R, Nmax) systematic samples (CPU-seeded)
    idx = torch.searchsorted(cdf, u.clamp(max=1.0 - 1e-12)).clamp_max(P - 1)
    weight = (grid < counts.unsqueeze(1)).to(cdf.dtype)  # mask out j >= counts[r]
    out = torch.zeros(R, P, dtype=torch.float32, device=device)
    out.scatter_add_(1, idx, weight)
    return out.to(torch.long).view(*batch_shape, P)
class MultiHeadAttention(nn.Module):
    # Multi-head attention in two flavours selected by `type`: "self" uses one fused qkv projection,
    # anything else projects q from x and k/v from `context` (cross-attention).
    def __init__(self, channels, num_heads, ctx_channels=None, type="self", qkv_bias=True, qk_rms_norm=False,
                 dtype=None, device=None, operations=None):
        super().__init__()
        assert channels % num_heads == 0
        factory = dict(dtype=dtype, device=device)
        self.channels = channels
        self.head_dim = channels // num_heads
        self.ctx_channels = channels if ctx_channels is None else ctx_channels
        self.num_heads = num_heads
        self._type = type
        self.qk_rms_norm = qk_rms_norm
        if self._type == "self":
            self.to_qkv = operations.Linear(channels, channels * 3, bias=qkv_bias, **factory)
        else:
            self.to_q = operations.Linear(channels, channels, bias=qkv_bias, **factory)
            self.to_kv = operations.Linear(self.ctx_channels, channels * 2, bias=qkv_bias, **factory)
        if self.qk_rms_norm:
            self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, **factory)
            self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, **factory)
        self.to_out = operations.Linear(channels, channels, **factory)

    def forward(self, x, context=None):
        # x: (B, L, C); context is only read in the cross-attention configuration.
        batch, q_len, _ = x.shape
        heads = self.num_heads
        if self._type == "self":
            fused = self.to_qkv(x).reshape(batch, q_len, 3, heads, -1)
            q, k, v = fused.unbind(dim=2)
        else:
            kv_len = context.shape[1]
            q = self.to_q(x).reshape(batch, q_len, heads, -1)
            fused_kv = self.to_kv(context).reshape(batch, kv_len, 2, heads, -1)
            k, v = fused_kv.unbind(dim=2)
        if self.qk_rms_norm:
            q, k = self.q_rms_norm(q), self.k_rms_norm(k)
        attended = attention(q, k, v)
        # Merge the heads back into the channel dimension before the output projection.
        return self.to_out(attended.reshape(batch, q_len, -1))
class LevelEmbedder(nn.Module):
    # Sinusoidal embedding of the octree level value followed by a Linear -> SiLU -> Linear projection.
    def __init__(self, hidden_size, frequency_embedding_size=256, max_period=1024,
                 dtype=None, device=None, operations=None):
        super().__init__()
        linear_kwargs = dict(bias=True, dtype=dtype, device=device)
        layers = [
            operations.Linear(frequency_embedding_size, hidden_size, **linear_kwargs),
            nn.SiLU(),
            operations.Linear(hidden_size, hidden_size, **linear_kwargs),
        ]
        self.mlp = nn.Sequential(*layers)
        self.frequency_embedding_size = frequency_embedding_size
        self.max_period = max_period

    @staticmethod
    def level_embedding(t, dim, max_period=1024):
        # Returns (len(t), dim): [cos | sin] of t * freq * 2*pi over dim // 2 geometrically spaced
        # frequencies, with one trailing zero column when `dim` is odd.
        half = dim // 2
        exponent = -np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        freqs = torch.exp(exponent).to(device=t.device)
        phase = t[:, None].float() * freqs[None] * 2 * torch.pi
        parts = [torch.cos(phase), torch.sin(phase)]
        if dim % 2:
            parts.append(torch.zeros_like(parts[0][:, :1]))
        return torch.cat(parts, dim=-1)

    def forward(self, t):
        features = self.level_embedding(t, self.frequency_embedding_size, self.max_period)
        # Cast the float32 features to the dtype of the first linear layer before projecting.
        target_dtype = self.mlp[0].weight.dtype
        return self.mlp(features.to(target_dtype))
class ModulatedTransformerCrossOnlyBlock(nn.Module):
    # Transformer block with cross-attention only (no self-attention) plus an MLP, both wrapped in
    # shift / scale / gate modulation derived from `mod`.
    def __init__(self, channels, ctx_channels, num_heads, mlp_ratio=4.0, share_mod=False,
                 qk_rms_norm_cross=True, qkv_bias=True, dtype=None, device=None, operations=None):
        super().__init__()
        factory = dict(dtype=dtype, device=device)
        self.share_mod = share_mod
        self.norm1 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, **factory)
        self.norm2 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, **factory)
        self.cross_attn = MultiHeadAttention(channels, ctx_channels=ctx_channels, num_heads=num_heads,
                                             type="cross", qkv_bias=qkv_bias,
                                             qk_rms_norm=qk_rms_norm_cross, operations=operations, **factory)
        self.mlp = MLP(channels, int(channels * mlp_ratio), channels, operations=operations, **factory)
        if not share_mod:
            # Per-block modulation projection; with share_mod the caller supplies the projected tensor.
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(), operations.Linear(channels, 6 * channels, bias=True, **factory))

    def forward(self, x, mod, context):
        projected = mod if self.share_mod else self.adaLN_modulation(mod)
        # Six equal chunks along dim 1, each unsqueezed so it broadcasts over the token dimension.
        shift_attn, scale_attn, gate_attn, shift_mlp, scale_mlp, gate_mlp = (
            part.unsqueeze(1) for part in projected.chunk(6, dim=1))
        attn_in = torch.addcmul(shift_attn, self.norm1(x), 1 + scale_attn)
        x = torch.addcmul(x, self.cross_attn(attn_in, context), gate_attn)
        mlp_in = torch.addcmul(shift_mlp, self.norm2(x), 1 + scale_mlp)
        return torch.addcmul(x, self.mlp(mlp_in), gate_mlp)
class OctreeProbabilityFixedlenDecoder(nn.Module):
    # Cross-attention transformer over octree coords -> per-node 8-way child occupancy logits.
    def __init__(self, model_channels=1024, cond_channels=16, num_blocks=4, num_heads=16,
                 num_head_channels=64, mlp_ratio=4.0, share_mod=True,
                 qk_rms_norm_cross=True, dtype=None, device=None, operations=None):
        super().__init__()
        self.model_channels = model_channels
        self.cond_channels = cond_channels
        self.num_blocks = num_blocks
        # Head count falls back to model_channels // num_head_channels when num_heads is falsy.
        self.num_heads = num_heads or model_channels // num_head_channels
        self.mlp_ratio = mlp_ratio
        self.share_mod = share_mod
        self.qk_rms_norm_cross = qk_rms_norm_cross
        self.input_layer = operations.Linear(model_channels, model_channels, dtype=dtype, device=device)
        self.l_embedder = LevelEmbedder(model_channels, dtype=dtype, device=device, operations=operations)
        if share_mod:
            # Shared modulation projection, applied once in forward() and passed to every block.
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(), operations.Linear(model_channels, 6 * model_channels, bias=True, dtype=dtype, device=device))
        # NOTE(review): self.blocks is only created when cond_channels is not None, but forward()
        # iterates it unconditionally.
        if cond_channels is not None:
            self.blocks = nn.ModuleList([
                ModulatedTransformerCrossOnlyBlock(
                    model_channels, ctx_channels=cond_channels, num_heads=self.num_heads,
                    mlp_ratio=self.mlp_ratio, qk_rms_norm_cross=self.qk_rms_norm_cross,
                    share_mod=self.share_mod, dtype=dtype, device=device, operations=operations)
                for _ in range(num_blocks)
            ])
        self.out_proj = operations.Linear(model_channels, 8, dtype=dtype, device=device)
        self.in_proj = operations.Linear(3, model_channels, dtype=dtype, device=device)
        self.pos_embedder = PcdAbsolutePositionEmbedder(channels=model_channels, in_channels=3, max_res=10, schedule="log2")

    def forward(self, x, l, cond):
        # x: (B, L, 3) node coordinates, l: (B,) level value, cond: (B, T, cond_channels) context.
        # Returns a dict with raw "logits" (B, L, 8) and their softmax "probs".
        d = next(self.parameters()).dtype  # compute dtype = dtype of the first registered parameter
        B, L, _ = x.shape
        h = self.in_proj(x.to(d)) + self.pos_embedder(x.reshape(-1, 3)).reshape(B, L, -1).to(d)
        h = self.input_layer(h)
        l_emb = self.l_embedder(l)
        if self.share_mod:
            l_emb = self.adaLN_modulation(l_emb)
        cond = cond.to(d)
        for block in self.blocks:
            h = block(h, l_emb, cond)
        # Parameter-free layer norm in float32, then back to the compute dtype.
        h = F.layer_norm(h.float(), h.shape[-1:]).to(d)
        logits = self.out_proj(h)
        return {"logits": logits, "probs": torch.softmax(logits, dim=-1)}

    @staticmethod
    def sample(model, cond, num_points, level, temperature=1.0, generator=None):
        # Top-down octree sampling: start with all `num_points` in the root cell and, level by level,
        # split each occupied cell's point count across its 8 children using the model's predicted
        # child distribution. Returns {"points": (B, num_points, 3) coords, jittered inside their leaf
        # cell and divided by the leaf resolution, "log_probs": (B, num_points)}.
        B = cond.shape[0]
        device = cond.device
        # The 8 child offsets, with i (first coordinate) varying fastest.
        child_offset = torch.tensor([[i, j, k] for k in [0, 1] for j in [0, 1] for i in [0, 1]],
                                    dtype=torch.long, device=device)
        prev_coords_int = torch.zeros(B, 1, 3, dtype=torch.long, device=device)
        prev_counts = torch.full((B, 1), num_points, dtype=torch.long, device=device)
        prev_log_probs = torch.zeros(B, 1, dtype=torch.float32, device=device)
        batch_indices_range = torch.arange(B, device=device).unsqueeze(1)

        for lv in range(1, level + 1):
            res_p = 1 << (lv - 1)  # parent grid resolution
            res = 1 << lv  # child grid resolution, fed to the model as the level value
            parent_coords_norm = (prev_coords_int.to(torch.float32) + 0.5) / res_p  # parent cell centers
            res_tensor = torch.full((B,), res, dtype=torch.long, device=device)
            pred_logits = model(parent_coords_norm, res_tensor, cond)["logits"] / temperature
            pred_probs = torch.softmax(pred_logits, dim=-1)
            pred_log_probs = torch.log_softmax(pred_logits, dim=-1)
            # Per parent: how many of its points go to each of the 8 children; flattened to (B, parents * 8).
            sampled = sample_probs(pred_probs, prev_counts, generator=generator).flatten(1, 2)
            pred_log_probs = pred_log_probs.flatten(1, 2)
            prev_log_probs_expanded = prev_log_probs.repeat_interleave(8, dim=1)
            child_coords_int = (prev_coords_int[:, :, None, :] * 2 + child_offset[None, None, :, :]).flatten(1, 2)
            # Compact the occupied children to the front of each batch row; rows with fewer occupied
            # children than max_valid are padded with zero-count entries.
            mask = sampled > 0
            max_valid = mask.sum(dim=1).max().item()
            scatter_indices = mask.cumsum(dim=1) - 1
            valid_scatter_indices = scatter_indices[mask]
            valid_batch_indices = batch_indices_range.expand_as(mask)[mask]
            next_prev_coords_int = torch.zeros(B, max_valid, 3, dtype=child_coords_int.dtype, device=device)
            next_prev_coords_int[valid_batch_indices, valid_scatter_indices] = child_coords_int[mask]
            next_prev_counts = torch.zeros(B, max_valid, dtype=sampled.dtype, device=device)
            next_prev_counts[valid_batch_indices, valid_scatter_indices] = sampled[mask]
            next_prev_log_probs = torch.zeros(B, max_valid, dtype=prev_log_probs.dtype, device=device)
            next_prev_log_probs[valid_batch_indices, valid_scatter_indices] = (prev_log_probs_expanded + pred_log_probs)[mask]
            prev_coords_int = next_prev_coords_int
            prev_counts = next_prev_counts
            prev_log_probs = next_prev_log_probs

        res = 1 << level
        # Repeat every leaf cell by its count; the reshape relies on each batch row totalling num_points.
        prev_log_probs = torch.repeat_interleave(prev_log_probs.flatten(0, 1), prev_counts.flatten(0, 1), dim=0).reshape(B, num_points)
        coords_int = torch.repeat_interleave(prev_coords_int.flatten(0, 1), prev_counts.flatten(0, 1), dim=0).reshape(B, num_points, -1)
        # Uniform jitter inside each leaf cell; drawn without a device argument, then moved to `device`.
        rand = torch.rand(coords_int.shape, dtype=torch.float32, generator=generator).to(device)
        coords_norm = (coords_int.to(torch.float32) + rand) / res
        return {"points": coords_norm, "log_probs": prev_log_probs}
# Elastic gaussian decoder

class TransformerCrossBlock(nn.Module):
    # Pre-norm transformer block: self-attention, then cross-attention to `context`, then an MLP,
    # each added back residually. Only norm2 (before the cross-attention) has learnable affine params.
    def __init__(self, channels, ctx_channels, num_heads, mlp_ratio=4.0,
                 qk_rms_norm=True, qk_rms_norm_cross=True, qkv_bias=True,
                 dtype=None, device=None, operations=None):
        super().__init__()
        factory = dict(dtype=dtype, device=device)
        self.norm1 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, **factory)
        self.norm2 = operations.LayerNorm(channels, elementwise_affine=True, eps=1e-6, **factory)
        self.norm3 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, **factory)
        self.self_attn = MultiHeadAttention(channels, num_heads=num_heads, type="self", qkv_bias=qkv_bias,
                                            qk_rms_norm=qk_rms_norm, operations=operations, **factory)
        self.cross_attn = MultiHeadAttention(channels, ctx_channels=ctx_channels, num_heads=num_heads, type="cross",
                                             qkv_bias=qkv_bias, qk_rms_norm=qk_rms_norm_cross,
                                             operations=operations, **factory)
        self.mlp = MLP(channels, int(channels * mlp_ratio), channels, operations=operations, **factory)

    def forward(self, x, context):
        self_out = self.self_attn(self.norm1(x))
        x = x + self_out
        cross_out = self.cross_attn(self.norm2(x), context)
        x = x + cross_out
        mlp_out = self.mlp(self.norm3(x))
        return x + mlp_out
class ElasticGaussianFixedlenDecoder(nn.Module):
    # Cross-attention transformer over sampled octree points -> per-point gaussian params.
    def __init__(self, in_channels=3, model_channels=1024, cond_channels=16, num_blocks=16, num_heads=16,
                 num_head_channels=64, mlp_ratio=4.0, *, representation_config=None,
                 qk_rms_norm=True, qk_rms_norm_cross=True, dtype=None, device=None, operations=None):
        super().__init__()
        # Default representation config; this class reads `num_gaussians`, `lr['_xyz']`,
        # `perturb_offset`, `perturbe_size` and `offset_scale` from it.
        self.rep_config = representation_config or dict(
            lr=dict(_xyz=1.0, _features_dc=1.0, _opacity=1.0, _scaling=1.0, _rotation=0.1),
            perturb_offset=True, perturbe_size=1.5, offset_scale=0.05, num_gaussians=32,
            filter_kernel_size_3d=0.0009, scaling_bias=0.004, opacity_bias=0.1,
            scaling_activation="softplus",
        )
        # _calc_layout() also populates self.layout; its return value is the total feature width.
        self.out_channels = self._calc_layout()
        self.model_channels = model_channels
        self.cond_channels = cond_channels
        self.num_blocks = num_blocks
        # Head count falls back to model_channels // num_head_channels when num_heads is falsy.
        self.num_heads = num_heads or model_channels // num_head_channels
        self.mlp_ratio = mlp_ratio
        self.input_layer = operations.Linear(model_channels, model_channels, dtype=dtype, device=device)
        # NOTE(review): self.blocks is only created when cond_channels is not None, but forward()
        # iterates it unconditionally.
        if cond_channels is not None:
            self.blocks = nn.ModuleList([
                TransformerCrossBlock(model_channels, ctx_channels=cond_channels,
                                      num_heads=self.num_heads, mlp_ratio=self.mlp_ratio,
                                      qk_rms_norm=qk_rms_norm, qk_rms_norm_cross=qk_rms_norm_cross,
                                      dtype=dtype, device=device, operations=operations)
                for _ in range(num_blocks)
            ])
        self.in_proj = operations.Linear(in_channels, model_channels, dtype=dtype, device=device)
        self.pos_embedder = PcdAbsolutePositionEmbedder(channels=model_channels, in_channels=3, max_res=10, schedule="log2")
        self.out_proj = operations.Linear(model_channels, self.out_channels, dtype=dtype, device=device)
        self._build_perturbation()

    def _calc_layout(self):
        # Describe how the flat per-point feature vector is split into gaussian attributes.
        # Each entry gets a 'shape', a flat 'size' and a 'range' (start, end) into the feature vector;
        # ranges are assigned in dict insertion order. Returns the total width.
        ng = self.rep_config['num_gaussians']
        self.layout = {
            '_xyz': {'shape': (ng, 3), 'size': ng * 3},
            '_features_dc': {'shape': (ng, 1, 3), 'size': ng * 3},
            '_scaling': {'shape': (ng, 3), 'size': ng * 3},
            '_rotation': {'shape': (ng, 4), 'size': ng * 4},
            '_opacity': {'shape': (ng, 1), 'size': ng},
        }
        self.layout['_offset_scale'] = {'shape': (ng, 1), 'size': ng}
        start = 0
        for k, v in self.layout.items():
            v['range'] = (start, start + v['size'])
            start += v['size']
        return start

    def _build_perturbation(self):
        # Register the fixed per-gaussian offset pattern and the offset-scale bias as buffers.
        ng = self.rep_config['num_gaussians']
        # (ng, 3) Hammersley points; mapped through atanh so that _get_offset's
        # tanh(0 + perturbation) * 0.5 * perturbe_size gives back (point - 0.5).
        perturbation = torch.tensor([hammersley_sequence(3, i, ng) for i in range(ng)]).float()
        perturbation = torch.atanh((perturbation * 2 - 1) / self.rep_config['perturbe_size'])
        self.register_buffer('points_offset_perturbation', perturbation)
        base = torch.tensor(self.rep_config['offset_scale'])
        # log(exp(base) - 1) is the inverse of softplus, so softplus(0 + bias) == offset_scale.
        self.register_buffer('base_offset_scale', torch.log(torch.exp(base) - 1.0))

    def _get_offset(self, h):
        # h: (B, L, out_channels) decoder features. Returns the per-gaussian xyz offsets,
        # shaped (B, L, num_gaussians, 3) by the reshape below.
        B = h.shape[0]
        r = self.layout['_offset_scale']['range']
        _offset_scale = F.softplus(
            h[:, :, r[0]:r[1]].reshape(B, -1, *self.layout['_offset_scale']['shape'])
            + comfy.model_management.cast_to(self.base_offset_scale, h.dtype, h.device))

        r = self.layout['_xyz']['range']
        offset = h[:, :, r[0]:r[1]].reshape(B, -1, *self.layout['_xyz']['shape'])
        offset = offset * self.rep_config['lr']['_xyz']
        if self.rep_config['perturb_offset']:
            offset = offset + comfy.model_management.cast_to(self.points_offset_perturbation, offset.dtype, offset.device)
        # tanh bounds each component to (-0.5, 0.5) * perturbe_size before the learned scale is applied.
        offset = torch.tanh(offset) * 0.5 * self.rep_config['perturbe_size']
        offset = offset * _offset_scale
        return offset

    def forward(self, x=None, cond=None):
        # x: dict holding "points" (B, L, 3); cond: (B, T, cond_channels) context.
        # Returns {"features": (B, L, out_channels)}, laid out as described by self.layout.
        pcd = x["points"]
        d = next(self.parameters()).dtype  # compute dtype = dtype of the first registered parameter
        B, L, _ = pcd.shape
        h = self.in_proj(pcd.to(d)) + self.pos_embedder(pcd.reshape(-1, 3)).reshape(B, L, -1).to(d)
        h = self.input_layer(h)
        cond = cond.to(d)
        for block in self.blocks:
            h = block(h, cond)
        # Parameter-free layer norm in float32, then back to h's dtype.
        h = F.layer_norm(h.float(), h.shape[-1:]).to(h.dtype)
        return {"features": self.out_proj(h)}
# Combined octree gaussian decoder (comfy first-stage model)

class OctreeGaussianDecoder(nn.Module):
    # Two-stage decoder: the octree model samples point positions from the latent, then the
    # gaussian model predicts per-point gaussian features for those positions.
    _MAX_VOXEL_LEVEL = 8

    def __init__(self, dtype=None, device=None, operations=None):
        super().__init__()
        ops = comfy.ops.disable_weight_init if operations is None else operations
        shared = dict(dtype=dtype, device=device, operations=ops)
        self.octree = OctreeProbabilityFixedlenDecoder(**shared)
        self.gs = ElasticGaussianFixedlenDecoder(**shared)

    @property
    def gaussians_per_point(self) -> int:
        return self.gs.rep_config['num_gaussians']

    def decode(self, latent: torch.Tensor, num_gaussians: int, level: int = None, generator=None):
        # level defaults to the full octree depth, a lower level is cheaper (coarser) for live previews.
        # generator (a CPU torch.Generator) makes the octree sampling reproducible without touching global RNG.
        if level is None:
            level = self._MAX_VOXEL_LEVEL
        # Each sampled point expands into gaussians_per_point gaussians; always sample at least one point.
        token_count = max(1, num_gaussians // self.gaussians_per_point)
        sampled_points = OctreeProbabilityFixedlenDecoder.sample(
            self.octree, latent, num_points=token_count, level=level, temperature=1.0, generator=generator,
        )
        features = self.gs(x=sampled_points, cond=latent)
        return build_gaussian_models(self.gs, sampled_points, features)  # one GaussianModel per batch item
class TripoSplat(BaseModel):
    # BaseModel wrapper that runs LatentSeqMMFlowModel as the diffusion model.
    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.triposplat.model.LatentSeqMMFlowModel)

    def extra_conds(self, **kwargs):
        out = super().extra_conds(**kwargs)
        # (kwargs key, output key, wrapper) in the order the conds are added; absent / None inputs are skipped.
        cond_specs = (
            ("cross_attn", 'c_crossattn', comfy.conds.CONDRegular),  # DINOv3 token sequence -> cross-attention context.
            ("reference_latents", 'ref_latents', lambda v: comfy.conds.CONDList(list(v))),  # Flux2 VAE image latent -> additive second conditioning.
            ("latent_shapes", 'latent_shapes', comfy.conds.CONDConstant),  # {latent, camera} nested latent
        )
        for source_key, target_key, wrap in cond_specs:
            value = kwargs.get(source_key, None)
            if value is not None:
                out[target_key] = wrap(value)
        return out
# Model-detection entry for TripoSplat: image -> 3D gaussian splat flow denoiser.
class TripoSplat(supported_models_base.BASE):
    # Matched against the config returned by model detection.
    unet_config = dict(image_model="triposplat")
    unet_extra_config = dict()

    # Flow sampling shift.
    sampling_settings = dict(shift=3.0)

    latent_format = latent_formats.TripoSplat
    supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
    memory_usage_factor = 0.6

    def get_model(self, state_dict, prefix="", device=None):
        # state_dict / prefix are not read here; the wrapper is built from this config alone.
        model = model_base.TripoSplat(self, device=device)
        return model

    def clip_target(self, state_dict={}):
        # No text encoder is paired with this model.
        return None
# TripoSplat nodes: image -> 3D gaussian splat

import logging

import torch
import torch.nn.functional as F
from typing_extensions import override

import comfy.model_management
import comfy.nested_tensor
import comfy.patcher_extension
import comfy.utils
from comfy_api.latest import ComfyExtension, IO, Types


# Fixed noise-target layout: an (8192, 16) shape-code sequence plus a (1, 5) camera token.
_Q_TOKEN_LENGTH = 8192
_LATENT_CHANNELS = 16
_CAM_CHANNELS = 5
# Per-channel normalisation applied to the image before the DINOv3 encoder.
_DINOV3_MEAN = [0.485, 0.456, 0.406]
_DINOV3_STD = [0.229, 0.224, 0.225]
# Hard bounds for the decode node's gaussian count.
_NUM_GAUSSIANS_MIN = 32768
_NUM_GAUSSIANS_MAX = 1048576


def _preprocess(image: torch.Tensor, mask: torch.Tensor, erode_radius: int, size: int) -> torch.Tensor:
    """Crop one image to a padded square around its mask and composite it on black.

    Match original preprocessing:
    resize min side to `size` -> erode alpha -> alpha bbox -> 1.2x square crop
    -> resize -> composite on black.

    Args:
        image: single image, (H, W, C); only the first three channels are used.
        mask: single alpha matte, (H, W); clamped to [0, 1].
        erode_radius: erosion radius in pixels (0 disables erosion).
        size: edge length of the square output.

    Returns:
        Tensor of shape (1, size, size, 3) with values in [0, 1].

    Raises:
        ValueError: if no alpha pixel is > 0 after erosion.
    """
    rgb = image[..., :3].clamp(0, 1).movedim(-1, 0)  # (3, H, W)
    alpha = mask.clamp(0, 1)[None]  # (1, H, W)
    rgba = torch.cat([rgb, alpha], 0)[None]  # (1, 4, H, W)

    h, w = rgba.shape[-2:]
    s = size / min(w, h)
    rgba = comfy.utils.common_upscale(rgba, max(1, round(w * s)), max(1, round(h * s)), "lanczos", "disabled").clamp(0, 1)

    a = rgba[:, 3:4]
    if erode_radius > 0:
        # min filter over a (2r+1) window == morphological erosion of the alpha matte.
        a = -F.max_pool2d(-a, 2 * erode_radius + 1, stride=1, padding=erode_radius)
        rgba = torch.cat([rgba[:, :3], a], 1)

    ys, xs = torch.nonzero(a[0, 0] > 0, as_tuple=True)
    if xs.numel() == 0:
        raise ValueError("TripoSplatPreprocessImage: mask is empty (no foreground pixels).")
    x0, x1 = int(xs.min()), int(xs.max())
    y0, y1 = int(ys.min()), int(ys.max())
    cx, cy = (x0 + x1) / 2, (y0 + y1) / 2
    half = max(x1 - x0, y1 - y0) / 2 * 1.2
    left, upper, right, lower = int(cx - half), int(cy - half), int(cx + half), int(cy + half)
    # A single-pixel foreground gives half == 0 and therefore an empty crop, which the
    # resize below cannot process; keep at least one pixel on each axis.
    right = max(right, left + 1)
    lower = max(lower, upper + 1)

    H, W = rgba.shape[-2:]
    crop = rgba.new_zeros((1, 4, lower - upper, right - left))  # out-of-bounds stays 0, matching PIL.crop
    sx0, sy0, sx1, sy1 = max(left, 0), max(upper, 0), min(right, W), min(lower, H)
    if sx1 > sx0 and sy1 > sy0:
        crop[:, :, sy0 - upper:sy1 - upper, sx0 - left:sx1 - left] = rgba[:, :, sy0:sy1, sx0:sx1]

    crop = comfy.utils.common_upscale(crop, size, size, "lanczos", "disabled").clamp(0, 1)
    out = (crop[:, :3] * crop[:, 3:4])[0].movedim(0, -1)  # composite over black == rgb * alpha
    return out.unsqueeze(0)  # (1, size, size, 3)


# Node: per-image mask-guided square crop, see _preprocess.
class TripoSplatPreprocessImage(IO.ComfyNode):
    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="TripoSplatPreprocessImage",
            display_name="TripoSplat Preprocess Image",
            category="3d/conditioning",
            description="Crop center each image to a square canvas on a black background and add padding.",
            inputs=[
                IO.Image.Input("image"),
                IO.Mask.Input("mask"),
                IO.Int.Input("erode_radius", default=1, min=0, max=16,
                             tooltip="Erode the alpha matte by this pixel radius before cropping (avoids border bleed)."),
                IO.Int.Input("size", default=1024, min=256, max=4096, step=16,
                             tooltip="Square image size. The model is trained at 1024; other sizes run but are off-distribution."),
            ],
            outputs=[IO.Image.Output(display_name="image")],
        )

    @classmethod
    def execute(cls, image, mask, erode_radius, size) -> IO.NodeOutput:
        """Preprocess every image of the batch and return them stacked along dim 0."""
        size = max(16, (int(size) // 16) * 16)  # DINOv3 patch / Flux2 VAE stride is 16
        # Bring the mask to the image's batch size and spatial size before pairing them up.
        if mask.shape[0] != image.shape[0]:
            mask = comfy.utils.repeat_to_batch_size(mask, image.shape[0])
        if tuple(mask.shape[1:]) != tuple(image.shape[1:3]):
            mask = F.interpolate(mask[:, None].float(), size=tuple(image.shape[1:3]), mode="bilinear", align_corners=False)[:, 0]
        prepared = torch.cat([_preprocess(image[i], mask[i], erode_radius, size) for i in range(image.shape[0])], dim=0)
        return IO.NodeOutput(prepared)


# Node: builds positive/negative conditioning and the empty (latent + camera) sampling target.
class TripoSplatConditioning(IO.ComfyNode):
    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="TripoSplatConditioning",
            display_name="TripoSplat Conditioning",
            category="3d/conditioning",
            description="Encode the image with DINOv3 and the Flux2 VAE into TripoSplat positive/negative "
                        "conditioning, and create the fixed size noise target (latent + camera) for the KSampler",
            inputs=[
                IO.ClipVision.Input("clip_vision", tooltip="DINOv3 ViT-H/16+ image encoder"),
                IO.Vae.Input("vae", tooltip="Flux2 VAE"),
                IO.Image.Input("image"),
            ],
            outputs=[
                IO.Conditioning.Output(display_name="positive"),
                IO.Conditioning.Output(display_name="negative"),
                IO.Latent.Output(display_name="latent", tooltip="The fixed size noise target (latent +camera)."),
            ],
        )

    @classmethod
    def execute(cls, clip_vision, vae, image) -> IO.NodeOutput:
        """Encode `image` into both conditionings and allocate the zero sampling target.

        Returns positive conditioning, negative conditioning (all-zero tensors of the
        same shapes) and a latent dict whose "samples" is a NestedTensor of
        (latent sequence, camera token).
        """
        out_device = comfy.model_management.intermediate_device()
        # Drop an alpha channel if present: the 3-channel mean/std below would
        # otherwise fail to broadcast against a 4-channel image.
        rgb = image[..., :3]

        # feature1: DINOv3 token sequence (cls + registers + patches), ImageNet-normalized, with a final non-affine layer norm on top
        comfy.model_management.load_model_gpu(clip_vision.patcher)
        device = clip_vision.load_device
        model_dtype = next(clip_vision.model.parameters()).dtype
        img = rgb.movedim(-1, 1).to(device)  # (B,3,H,W) in [0,1]
        mean = torch.tensor(_DINOV3_MEAN, device=device).view(1, 3, 1, 1)
        std = torch.tensor(_DINOV3_STD, device=device).view(1, 3, 1, 1)
        img = (img - mean) / std
        seq = clip_vision.model(pixel_values=img.to(model_dtype))[0]
        feature1 = F.layer_norm(seq.float(), seq.shape[-1:]).to(out_device)

        # Second conditioning: the Flux2 VAE latent of the image, carried as a standard reference_latents entry
        ref = vae.encode(rgb).to(out_device)  # (B, 128, H, W)
        b = ref.shape[0]

        positive = [[feature1, {"reference_latents": [ref]}]]
        negative = [[torch.zeros_like(feature1), {"reference_latents": [torch.zeros_like(ref)]}]]

        # Fixed noise target: the latent is a constant-shape (8192, 16) shape-code + a (1, 5) camera token
        latent_seq = torch.zeros([b, _Q_TOKEN_LENGTH, _LATENT_CHANNELS], device=out_device)
        camera = torch.zeros([b, 1, _CAM_CHANNELS], device=out_device)
        samples = comfy.nested_tensor.NestedTensor((latent_seq, camera))
        return IO.NodeOutput(positive, negative, {"samples": samples})


# Node: decodes the sampled latent into a gaussian splat by calling the decoder directly.
class VAEDecodeTripoSplat(IO.ComfyNode):
    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="VAEDecodeTripoSplat",
            display_name="TripoSplat Decode",
            category="3d/latent",
            description="Decode the sampled TripoSplat latent into a 3D gaussian splat. "
                        "Modify the number of gaussians to vary the density.",
            inputs=[
                IO.Latent.Input("samples"),
                IO.Vae.Input("vae", tooltip="TripoSplat VAE decoder"),
                IO.Int.Input("num_gaussians", default=262144, min=_NUM_GAUSSIANS_MIN, max=_NUM_GAUSSIANS_MAX, step=32,
                             tooltip="Number of gaussians to produce (rounded to a multiple of 32). "
                                     "262144 matches the octree's point density; higher oversamples the same points "
                                     "(denser, but no new detail) and costs proportionally more VRAM/time."),
                IO.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff,
                             tooltip="Seeds the octree point sampler (global RNG) for deterministic decodes."),
            ],
            outputs=[IO.Splat.Output(display_name="splat")],
        )

    @classmethod
    def execute(cls, samples, vae, num_gaussians, seed) -> IO.NodeOutput:
        """Decode `samples` into one batched splat, stacking the per-item render tensors."""
        s = samples["samples"]
        latent = s.unbind()[0] if getattr(s, "is_nested", False) else s  # take the latent stream, drop camera

        decoder = vae.first_stage_model
        gpp = decoder.gaussians_per_point
        # Clamp to the supported range, then snap to a whole number of decoder points.
        n = max(_NUM_GAUSSIANS_MIN, min(_NUM_GAUSSIANS_MAX, int(num_gaussians)))
        if n % gpp != 0:
            n = round(n / gpp) * gpp

        # VRAM estimate passed to the loader, derived from token counts and hidden width.
        dtype_size = comfy.model_management.dtype_size(vae.vae_dtype)
        hidden = decoder.gs.model_channels
        cond_tokens = latent.shape[1]
        memory_required = (cond_tokens * 4 + (n // gpp) * 10) * hidden * dtype_size
        comfy.model_management.load_models_gpu([vae.patcher], memory_required=memory_required)
        latent = latent.to(device=vae.device, dtype=vae.vae_dtype)
        generator = torch.Generator(device="cpu").manual_seed(seed)
        parts = [g.render_tensors() for g in decoder.decode(latent, num_gaussians=n, generator=generator)]
        positions, scales, rotations, opacities, sh = (torch.stack(t) for t in zip(*parts))
        return IO.NodeOutput(Types.SPLAT(positions, scales, rotations, opacities, sh))


# Node: patches a model clone so the sampler callback also renders a splat preview each step.
class TripoSplatSamplingPreview(IO.ComfyNode):
    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="TripoSplatSamplingPreview",
            display_name="TripoSplat Sampling Preview",
            category="3d/latent",
            description="Patch the TripoSplat model for the standard Ksampler node to show a live decoded "
                        "gaussian splat preview at each step.",
            inputs=[
                IO.Model.Input("model"),
                IO.Vae.Input("vae", tooltip="TripoSplat VAE decoder"),
                IO.Int.Input("octree_level", default=5, min=2, max=8, advanced=True,
                             tooltip="Octree depth for the preview decode (lower = cheaper/coarser)."),
                IO.Int.Input("num_gaussians", default=16384, min=1024, max=262144, step=32,
                             tooltip="Number of gaussians to produce for the preview (rounded to a multiple of 32)."),
                IO.Float.Input("yaw", default=90.0, min=-360.0, max=360.0, step=1.0, tooltip="Preview camera yaw in degrees.", advanced=True,),
                IO.Float.Input("pitch", default=15.0, min=-89.0, max=89.0, step=1.0, tooltip="Preview camera pitch in degrees.", advanced=True,),
                IO.Int.Input("point_size", default=3, min=1, max=16,
                             tooltip="Maximum splat radius in pixels. Each gaussian is sized from its scale and capped here; "
                                     "lower = finer/pointier, higher = chunkier."),
            ],
            outputs=[IO.Model.Output()],
        )

    @classmethod
    def execute(cls, model, vae, octree_level, num_gaussians, yaw, pitch, point_size) -> IO.NodeOutput:
        """Return a clone of `model` carrying an OUTER_SAMPLE wrapper that adds the preview."""
        from comfy.ldm.triposplat.preview import decode_x0_to_image
        cfg = {"gaussians": num_gaussians, "level": octree_level, "yaw": yaw, "pitch": pitch,
               "point_size": point_size}

        fsm = vae.first_stage_model
        cond_tokens = model.model.diffusion_model.q_token_length
        memory_required = (cond_tokens * 4 + (num_gaussians // fsm.gaussians_per_point) * 10) * fsm.gs.model_channels * comfy.model_management.dtype_size(vae.vae_dtype)

        # Live preview via WrappersMP.OUTER_SAMPLE + ProgressBar
        # The wrapper augments the sampler's own callback to decode x0 -> gaussian splat -> preview image each step
        def outer_sample_wrapper(executor, *args, **kwargs):
            args = list(args)
            cb_idx = 5  # outer_sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
            orig_cb = args[cb_idx] if len(args) > cb_idx else kwargs.get("callback")
            # ok: preview still enabled; pbar: created on first use; loaded: decoder load requested once.
            state = {"ok": True, "pbar": None, "loaded": False}

            def callback(step, x0, x, total_steps):
                # Always run the sampler's own callback first; the preview is optional.
                if orig_cb is not None:
                    orig_cb(step, x0, x, total_steps)
                if not state["ok"]:
                    return
                try:
                    if not state["loaded"]:
                        comfy.model_management.load_models_gpu([vae.patcher], memory_required=memory_required)
                        state["loaded"] = True
                    img = decode_x0_to_image(vae, x0, cfg)
                    if state["pbar"] is None:
                        state["pbar"] = comfy.utils.ProgressBar(total_steps)
                    state["pbar"].update_absolute(step + 1, total_steps, ("JPEG", img, 512))
                except Exception as e:
                    # Best-effort: a failing preview must not abort sampling, so log once and disable it.
                    logging.warning("TripoSplatSamplingPreview: preview failed, disabling ({})".format(e))
                    state["ok"] = False

            # Put the augmented callback back wherever the caller supplied the original.
            if len(args) > cb_idx:
                args[cb_idx] = callback
            else:
                kwargs["callback"] = callback
            return executor(*args, **kwargs)

        m = model.clone()
        m.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "triposplat_sampling_preview", outer_sample_wrapper)
        return IO.NodeOutput(m)


# Extension object listing the TripoSplat nodes defined in this module.
class TripoSplatExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
        return [
            TripoSplatPreprocessImage,
            TripoSplatConditioning,
            VAEDecodeTripoSplat,
            TripoSplatSamplingPreview,
        ]


async def comfy_entrypoint() -> TripoSplatExtension:
    """Return the extension instance for this module."""
    return TripoSplatExtension()