mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-02 12:27:59 +08:00
feat: Add TripoSplat support (#14210)
This commit is contained in:
parent
70a2e1a851
commit
462c27fdb2
@ -2,6 +2,7 @@ from .utils import load_torch_file, transformers_convert, state_dict_prefix_repl
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
import torch
|
||||
|
||||
import comfy.ops
|
||||
import comfy.model_patcher
|
||||
@ -9,6 +10,7 @@ import comfy.model_management
|
||||
import comfy.utils
|
||||
import comfy.clip_model
|
||||
import comfy.image_encoders.dino2
|
||||
import comfy.image_encoders.dino3
|
||||
|
||||
class Output:
|
||||
def __getitem__(self, key):
|
||||
@ -23,12 +25,16 @@ IMAGE_ENCODERS = {
|
||||
"siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
||||
"siglip2_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
||||
"dinov2": comfy.image_encoders.dino2.Dinov2Model,
|
||||
"dinov3": comfy.image_encoders.dino3.DINOv3ViTModel,
|
||||
}
|
||||
|
||||
class ClipVisionModel():
|
||||
def __init__(self, json_config):
|
||||
with open(json_config) as f:
|
||||
config = json.load(f)
|
||||
if isinstance(json_config, dict):
|
||||
config = json_config
|
||||
else:
|
||||
with open(json_config) as f:
|
||||
config = json.load(f)
|
||||
|
||||
self.image_size = config.get("image_size", 224)
|
||||
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
|
||||
@ -44,6 +50,10 @@ class ClipVisionModel():
|
||||
self.load_device = comfy.model_management.text_encoder_device()
|
||||
offload_device = comfy.model_management.text_encoder_offload_device()
|
||||
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
|
||||
if self.model_type == "dinov3" and self.dtype == torch.float16:
|
||||
# DINOv3's activations borderline fits fp16, preferring bf16 if available for better stability #TODO: further fp16 tests in practice
|
||||
if comfy.model_management.should_use_bf16(self.load_device, prioritize_performance=True):
|
||||
self.dtype = torch.bfloat16
|
||||
self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
|
||||
self.model.eval()
|
||||
|
||||
@ -134,6 +144,8 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
||||
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
|
||||
elif 'encoder.layer.23.layer_scale2.lambda1' in sd:
|
||||
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_large.json")
|
||||
elif 'layer.0.mlp.gate_proj.weight' in sd and 'layer.31.norm1.weight' in sd: # Dinov3 ViT-H/16+ (SwiGLU gated MLP, 32 layers)
|
||||
json_config = comfy.image_encoders.dino3.DINOV3_VITH_CONFIG
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
260
comfy/image_encoders/dino3.py
Normal file
260
comfy/image_encoders/dino3.py
Normal file
@ -0,0 +1,260 @@
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
from comfy.image_encoders.dino2 import LayerScale as DINOv3ViTLayerScale
|
||||
|
||||
|
||||
# DINOv3 ViT-H/16+ (SwiGLU)
|
||||
DINOV3_VITH_CONFIG = {
|
||||
"model_type": "dinov3",
|
||||
"num_hidden_layers": 32,
|
||||
"hidden_size": 1280,
|
||||
"num_attention_heads": 20,
|
||||
"num_register_tokens": 4,
|
||||
"intermediate_size": 5120,
|
||||
"layer_norm_eps": 1e-5,
|
||||
"num_channels": 3,
|
||||
"patch_size": 16,
|
||||
"rope_theta": 100.0,
|
||||
"use_gated_mlp": True,
|
||||
"gated_mlp_act": "silu",
|
||||
"image_size": 1024,
|
||||
"image_mean": [0.485, 0.456, 0.406],
|
||||
"image_std": [0.229, 0.224, 0.225],
|
||||
}
|
||||
|
||||
|
||||
class DINOv3ViTMLP(nn.Module):
|
||||
def __init__(self, hidden_size, intermediate_size, mlp_bias, device, dtype, operations):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype)
|
||||
self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=mlp_bias, device=device, dtype=dtype)
|
||||
self.act_fn = torch.nn.GELU()
|
||||
|
||||
def forward(self, x):
|
||||
return self.down_proj(self.act_fn(self.up_proj(x)))
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, **kwargs):
|
||||
num_tokens = q.shape[-2]
|
||||
num_patches = sin.shape[-2]
|
||||
num_prefix_tokens = num_tokens - num_patches
|
||||
|
||||
q_prefix_tokens, q_patches = q.split((num_prefix_tokens, num_patches), dim=-2)
|
||||
k_prefix_tokens, k_patches = k.split((num_prefix_tokens, num_patches), dim=-2)
|
||||
|
||||
q_patches = (q_patches * cos) + (rotate_half(q_patches) * sin)
|
||||
k_patches = (k_patches * cos) + (rotate_half(k_patches) * sin)
|
||||
|
||||
q = torch.cat((q_prefix_tokens, q_patches), dim=-2)
|
||||
k = torch.cat((k_prefix_tokens, k_patches), dim=-2)
|
||||
|
||||
return q, k
|
||||
|
||||
|
||||
class DINOv3ViTAttention(nn.Module):
|
||||
def __init__(self, hidden_size, num_attention_heads, device, dtype, operations):
|
||||
super().__init__()
|
||||
self.embed_dim = hidden_size
|
||||
self.num_heads = num_attention_heads
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
|
||||
self.k_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=False, device=device, dtype=dtype) # key_bias = False
|
||||
self.v_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype)
|
||||
self.q_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype)
|
||||
self.o_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, hidden_states, attention_mask=None, position_embeddings=None, **kwargs):
|
||||
batch_size, patches, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
if position_embeddings is not None:
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
attn = optimized_attention_for_device(query_states.device, mask=False)
|
||||
attn_output = attn(
|
||||
query_states, key_states, value_states, self.num_heads, attention_mask,
|
||||
skip_reshape=True, skip_output_reshape=True, low_precision_attention=False,
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = attn_output.reshape(batch_size, patches, -1).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
return attn_output
|
||||
|
||||
|
||||
class DINOv3ViTGatedMLP(nn.Module):
|
||||
def __init__(self, hidden_size, intermediate_size, mlp_bias, device, dtype, operations, act="silu"):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.gate_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype)
|
||||
self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype)
|
||||
self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=mlp_bias, device=device, dtype=dtype)
|
||||
self.act_fn = torch.nn.SiLU() if act == "silu" else torch.nn.GELU()
|
||||
|
||||
def forward(self, x):
|
||||
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
def get_patches_center_coordinates(num_patches_h, num_patches_w, dtype, device):
|
||||
coords_h = torch.arange(0.5, num_patches_h, dtype=dtype, device=device)
|
||||
coords_w = torch.arange(0.5, num_patches_w, dtype=dtype, device=device)
|
||||
coords_h = coords_h / num_patches_h
|
||||
coords_w = coords_w / num_patches_w
|
||||
coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1)
|
||||
coords = coords.flatten(0, 1)
|
||||
coords = 2.0 * coords - 1.0
|
||||
return coords
|
||||
|
||||
|
||||
class DINOv3ViTRopePositionEmbedding(nn.Module):
|
||||
inv_freq: torch.Tensor
|
||||
|
||||
def __init__(self, rope_theta, hidden_size, num_attention_heads, patch_size, device, dtype):
|
||||
super().__init__()
|
||||
self.base = rope_theta
|
||||
self.head_dim = hidden_size // num_attention_heads
|
||||
self.patch_size = patch_size
|
||||
|
||||
inv_freq = 1 / self.base ** torch.arange(0, 1, 4 / self.head_dim, dtype=torch.float32, device=device)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
def forward(self, pixel_values):
|
||||
_, _, height, width = pixel_values.shape
|
||||
num_patches_h = height // self.patch_size
|
||||
num_patches_w = width // self.patch_size
|
||||
|
||||
patch_coords = get_patches_center_coordinates(num_patches_h, num_patches_w, dtype=torch.float32, device=pixel_values.device)
|
||||
self.inv_freq = self.inv_freq.to(pixel_values.device)
|
||||
angles = 2 * math.pi * patch_coords[:, :, None] * self.inv_freq[None, None, :]
|
||||
angles = angles.flatten(1, 2)
|
||||
angles = angles.tile(2)
|
||||
cos = torch.cos(angles).to(dtype=pixel_values.dtype)
|
||||
sin = torch.sin(angles).to(dtype=pixel_values.dtype)
|
||||
return cos, sin
|
||||
|
||||
|
||||
class DINOv3ViTEmbeddings(nn.Module):
|
||||
def __init__(self, hidden_size, num_register_tokens, num_channels, patch_size, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.cls_token = nn.Parameter(torch.empty(1, 1, hidden_size, device=device, dtype=dtype))
|
||||
self.mask_token = nn.Parameter(torch.empty(1, 1, hidden_size, device=device, dtype=dtype))
|
||||
self.register_tokens = nn.Parameter(torch.empty(1, num_register_tokens, hidden_size, device=device, dtype=dtype))
|
||||
self.patch_embeddings = operations.Conv2d(
|
||||
num_channels, hidden_size, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
def forward(self, pixel_values, bool_masked_pos=None):
|
||||
batch_size = pixel_values.shape[0]
|
||||
target_dtype = self.patch_embeddings.weight.dtype
|
||||
|
||||
patch_embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype))
|
||||
patch_embeddings = patch_embeddings.flatten(2).transpose(1, 2)
|
||||
|
||||
if bool_masked_pos is not None:
|
||||
mask_token = self.mask_token.to(patch_embeddings.dtype)
|
||||
patch_embeddings = torch.where(bool_masked_pos.unsqueeze(-1), mask_token, patch_embeddings)
|
||||
|
||||
cls_token = self.cls_token.expand(batch_size, -1, -1).to(patch_embeddings.device)
|
||||
register_tokens = self.register_tokens.expand(batch_size, -1, -1).to(patch_embeddings.device)
|
||||
embeddings = torch.cat([cls_token, register_tokens, patch_embeddings], dim=1)
|
||||
return embeddings
|
||||
|
||||
|
||||
class DINOv3ViTLayer(nn.Module):
|
||||
def __init__(self, hidden_size, layer_norm_eps, use_gated_mlp, mlp_bias, intermediate_size,
|
||||
num_attention_heads, device, dtype, operations, gated_mlp_act="silu"):
|
||||
super().__init__()
|
||||
self.norm1 = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype)
|
||||
self.attention = DINOv3ViTAttention(hidden_size, num_attention_heads, device=device, dtype=dtype, operations=operations)
|
||||
self.layer_scale1 = DINOv3ViTLayerScale(hidden_size, device=device, dtype=dtype, operations=None)
|
||||
|
||||
self.norm2 = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype)
|
||||
if use_gated_mlp:
|
||||
self.mlp = DINOv3ViTGatedMLP(hidden_size, intermediate_size, mlp_bias, device=device, dtype=dtype, operations=operations, act=gated_mlp_act)
|
||||
else:
|
||||
self.mlp = DINOv3ViTMLP(hidden_size, intermediate_size=intermediate_size, mlp_bias=mlp_bias, device=device, dtype=dtype, operations=operations)
|
||||
self.layer_scale2 = DINOv3ViTLayerScale(hidden_size, device=device, dtype=dtype, operations=None)
|
||||
|
||||
def forward(self, hidden_states, attention_mask=None, position_embeddings=None):
|
||||
residual = hidden_states
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
hidden_states = self.attention(hidden_states, attention_mask=attention_mask, position_embeddings=position_embeddings)
|
||||
hidden_states = self.layer_scale1(hidden_states)
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.norm2(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = self.layer_scale2(hidden_states)
|
||||
hidden_states = hidden_states + residual
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DINOv3ViTModel(nn.Module):
|
||||
def __init__(self, config, dtype, device, operations):
|
||||
super().__init__()
|
||||
num_hidden_layers = config["num_hidden_layers"]
|
||||
hidden_size = config["hidden_size"]
|
||||
num_attention_heads = config["num_attention_heads"]
|
||||
num_register_tokens = config["num_register_tokens"]
|
||||
intermediate_size = config["intermediate_size"]
|
||||
layer_norm_eps = config["layer_norm_eps"]
|
||||
num_channels = config["num_channels"]
|
||||
patch_size = config["patch_size"]
|
||||
rope_theta = config["rope_theta"]
|
||||
use_gated_mlp = config.get("use_gated_mlp", False)
|
||||
gated_mlp_act = config.get("gated_mlp_act", "silu")
|
||||
|
||||
self.embeddings = DINOv3ViTEmbeddings(
|
||||
hidden_size, num_register_tokens, num_channels=num_channels, patch_size=patch_size,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
self.rope_embeddings = DINOv3ViTRopePositionEmbedding(
|
||||
rope_theta, hidden_size, num_attention_heads, patch_size=patch_size, dtype=dtype, device=device
|
||||
)
|
||||
self.layer = nn.ModuleList([
|
||||
DINOv3ViTLayer(hidden_size, layer_norm_eps, use_gated_mlp=use_gated_mlp, mlp_bias=True,
|
||||
intermediate_size=intermediate_size, num_attention_heads=num_attention_heads,
|
||||
dtype=dtype, device=device, operations=operations, gated_mlp_act=gated_mlp_act)
|
||||
for _ in range(num_hidden_layers)])
|
||||
self.norm = operations.LayerNorm(hidden_size, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embeddings.patch_embeddings
|
||||
|
||||
def forward(self, pixel_values, bool_masked_pos=None, **kwargs):
|
||||
pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype)
|
||||
hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
|
||||
position_embeddings = self.rope_embeddings(pixel_values)
|
||||
|
||||
for layer_module in self.layer:
|
||||
hidden_states = layer_module(hidden_states, position_embeddings=position_embeddings)
|
||||
|
||||
if kwargs.get("skip_norm_elementwise", False):
|
||||
sequence_output = F.layer_norm(hidden_states, hidden_states.shape[-1:])
|
||||
else:
|
||||
norm = self.norm.to(hidden_states.device)
|
||||
sequence_output = norm(hidden_states)
|
||||
pooled_output = sequence_output[:, 0, :]
|
||||
return sequence_output, None, pooled_output, None
|
||||
@ -239,6 +239,16 @@ class Flux2(LatentFormat):
|
||||
def process_out(self, latent):
|
||||
return latent
|
||||
|
||||
class TripoSplat(LatentFormat):
|
||||
# Sequence latent (B, 8192, 16) the camera token rides alongside as a second nested latent
|
||||
latent_channels = 16
|
||||
|
||||
def process_in(self, latent):
|
||||
return latent
|
||||
|
||||
def process_out(self, latent):
|
||||
return latent
|
||||
|
||||
class Mochi(LatentFormat):
|
||||
latent_channels = 12
|
||||
latent_dimensions = 3
|
||||
|
||||
199
comfy/ldm/triposplat/gaussian.py
Normal file
199
comfy/ldm/triposplat/gaussian.py
Normal file
@ -0,0 +1,199 @@
|
||||
# TripoSplat 3D gaussian container. Operates on already-decoded
|
||||
# tensors and exposes them as render-ready tensors (render_tensors) for the generic SPLAT type.
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.model_management
|
||||
|
||||
|
||||
class GaussianModel:
|
||||
def __init__(self, aabb: list, sh_degree: int = 0, mininum_kernel_size: float = 0.0,
|
||||
scaling_bias: float = 0.01, opacity_bias: float = 0.1,
|
||||
scaling_activation: str = "exp", device=None):
|
||||
self.sh_degree = sh_degree
|
||||
self.mininum_kernel_size = mininum_kernel_size
|
||||
self.scaling_bias = scaling_bias
|
||||
self.opacity_bias = opacity_bias
|
||||
self.device = device
|
||||
self.aabb = torch.tensor(aabb, dtype=torch.float32, device=device)
|
||||
|
||||
if scaling_activation == "exp":
|
||||
self._scaling_activation = torch.exp
|
||||
self._inverse_scaling_activation = torch.log
|
||||
elif scaling_activation == "softplus":
|
||||
self._scaling_activation = F.softplus
|
||||
self._inverse_scaling_activation = lambda x: x + torch.log(-torch.expm1(-x))
|
||||
|
||||
self._opacity_activation = torch.sigmoid
|
||||
self._inverse_opacity_activation = lambda x: torch.log(x / (1 - x))
|
||||
|
||||
self.scale_bias = self._inverse_scaling_activation(torch.tensor(self.scaling_bias)).to(self.device)
|
||||
self.rots_bias = torch.zeros(4, device=self.device)
|
||||
self.rots_bias[0] = 1
|
||||
self.opacity_bias_val = self._inverse_opacity_activation(torch.tensor(self.opacity_bias)).to(self.device)
|
||||
|
||||
self._storage = {}
|
||||
|
||||
def _get_store(self, name):
|
||||
return self._storage.get(name)
|
||||
|
||||
def _set_store(self, name, value):
|
||||
self._storage[name] = value
|
||||
|
||||
@property
|
||||
def _xyz(self):
|
||||
return self._get_store("_xyz")
|
||||
@_xyz.setter
|
||||
def _xyz(self, value):
|
||||
if value is None:
|
||||
self._set_store("_xyz", None)
|
||||
self._set_store("xyz", None)
|
||||
return
|
||||
self._set_store("_xyz", value)
|
||||
self._set_store("xyz", value * self.aabb[None, 3:] + self.aabb[None, :3])
|
||||
|
||||
@property
|
||||
def get_xyz(self):
|
||||
return self._get_store("xyz")
|
||||
|
||||
@property
|
||||
def _features_dc(self):
|
||||
return self._get_store("_features_dc")
|
||||
@_features_dc.setter
|
||||
def _features_dc(self, value):
|
||||
self._set_store("_features_dc", value)
|
||||
|
||||
@property
|
||||
def _opacity(self):
|
||||
return self._get_store("_opacity")
|
||||
@_opacity.setter
|
||||
def _opacity(self, value):
|
||||
if value is None:
|
||||
self._set_store("_opacity", None)
|
||||
self._set_store("opacity", None)
|
||||
return
|
||||
self._set_store("_opacity", value)
|
||||
self._set_store("opacity", self._opacity_activation(value + self.opacity_bias_val))
|
||||
|
||||
@property
|
||||
def get_opacity(self):
|
||||
return self._get_store("opacity")
|
||||
|
||||
@property
|
||||
def _scaling(self):
|
||||
return self._get_store("_scaling")
|
||||
@_scaling.setter
|
||||
def _scaling(self, value):
|
||||
if value is None:
|
||||
self._set_store("_scaling", None)
|
||||
self._set_store("scaling", None)
|
||||
return
|
||||
self._set_store("_scaling", value)
|
||||
s = self._scaling_activation(value + self.scale_bias)
|
||||
s = torch.square(s) + self.mininum_kernel_size ** 2
|
||||
self._set_store("scaling", torch.sqrt(s))
|
||||
|
||||
@property
|
||||
def get_scaling(self):
|
||||
return self._get_store("scaling")
|
||||
|
||||
@property
|
||||
def _rotation(self):
|
||||
return self._get_store("_rotation")
|
||||
@_rotation.setter
|
||||
def _rotation(self, value):
|
||||
self._set_store("_rotation", value)
|
||||
|
||||
_DEFAULT_TRANSFORM = [[1, 0, 0], [0, 0, -1], [0, 1, 0]]
|
||||
|
||||
def render_tensors(self):
|
||||
# Render-ready (activated, world-space) tensors for the generic SPLAT type. The axis transform
|
||||
# (a 3x3 rotation, object frame -> viewer Y-up) is baked into positions and rotations.
|
||||
# Returns float tensors on the intermediate device: positions (N,3), scales (N,3) linear,
|
||||
# rotations (N,4) wxyz, opacities (N,1) in [0,1], sh (N,K,3) coefficients.
|
||||
xyz = self.get_xyz.float()
|
||||
scaling = self.get_scaling.float()
|
||||
opacity = self.get_opacity.float()
|
||||
rotation = (self._rotation + self.rots_bias[None, :]).float()
|
||||
sh = self._features_dc.float() # (N, K, 3)
|
||||
T = torch.as_tensor(self._DEFAULT_TRANSFORM, dtype=torch.float32, device=xyz.device)
|
||||
xyz = xyz @ T.T
|
||||
rotation = _matrix_to_quat(torch.matmul(T, _quat_to_matrix(rotation)))
|
||||
rotation = rotation / torch.linalg.norm(rotation, dim=-1, keepdim=True)
|
||||
out_device = comfy.model_management.intermediate_device()
|
||||
return (
|
||||
xyz.to(out_device).contiguous(), scaling.to(out_device).contiguous(),
|
||||
rotation.to(out_device).contiguous(), opacity.to(out_device).contiguous(),
|
||||
sh.to(out_device).contiguous(),
|
||||
)
|
||||
|
||||
|
||||
def _quat_to_matrix(q):
|
||||
q = q / torch.linalg.norm(q, dim=-1, keepdim=True)
|
||||
w, x, y, z = q[:, 0], q[:, 1], q[:, 2], q[:, 3]
|
||||
R = torch.stack([
|
||||
1 - 2*(y*y + z*z), 2*(x*y - w*z), 2*(x*z + w*y),
|
||||
2*(x*y + w*z), 1 - 2*(x*x + z*z), 2*(y*z - w*x),
|
||||
2*(x*z - w*y), 2*(y*z + w*x), 1 - 2*(x*x + y*y),
|
||||
], dim=-1).reshape(-1, 3, 3)
|
||||
return R
|
||||
|
||||
|
||||
def _matrix_to_quat(R):
|
||||
trace = R[:, 0, 0] + R[:, 1, 1] + R[:, 2, 2]
|
||||
q = torch.zeros((R.shape[0], 4), dtype=R.dtype, device=R.device)
|
||||
s = torch.sqrt(torch.clamp(trace + 1, min=0)) * 2
|
||||
q[:, 0] = 0.25 * s
|
||||
denom = torch.where(s != 0, s, torch.ones_like(s))
|
||||
q[:, 1] = (R[:, 2, 1] - R[:, 1, 2]) / denom
|
||||
q[:, 2] = (R[:, 0, 2] - R[:, 2, 0]) / denom
|
||||
q[:, 3] = (R[:, 1, 0] - R[:, 0, 1]) / denom
|
||||
m01 = (R[:, 0, 0] >= R[:, 1, 1]) & (R[:, 0, 0] >= R[:, 2, 2]) & (s == 0)
|
||||
s1 = torch.sqrt(torch.clamp(1 + R[:, 0, 0] - R[:, 1, 1] - R[:, 2, 2], min=0)) * 2
|
||||
q[m01, 0] = (R[m01, 2, 1] - R[m01, 1, 2]) / s1[m01]
|
||||
q[m01, 1] = 0.25 * s1[m01]
|
||||
q[m01, 2] = (R[m01, 0, 1] + R[m01, 1, 0]) / s1[m01]
|
||||
q[m01, 3] = (R[m01, 0, 2] + R[m01, 2, 0]) / s1[m01]
|
||||
m11 = (R[:, 1, 1] > R[:, 0, 0]) & (R[:, 1, 1] >= R[:, 2, 2]) & (s == 0)
|
||||
s2 = torch.sqrt(torch.clamp(1 + R[:, 1, 1] - R[:, 0, 0] - R[:, 2, 2], min=0)) * 2
|
||||
q[m11, 0] = (R[m11, 0, 2] - R[m11, 2, 0]) / s2[m11]
|
||||
q[m11, 1] = (R[m11, 0, 1] + R[m11, 1, 0]) / s2[m11]
|
||||
q[m11, 2] = 0.25 * s2[m11]
|
||||
q[m11, 3] = (R[m11, 1, 2] + R[m11, 2, 1]) / s2[m11]
|
||||
m21 = (R[:, 2, 2] > R[:, 0, 0]) & (R[:, 2, 2] > R[:, 1, 1]) & (s == 0)
|
||||
s3 = torch.sqrt(torch.clamp(1 + R[:, 2, 2] - R[:, 0, 0] - R[:, 1, 1], min=0)) * 2
|
||||
q[m21, 0] = (R[m21, 1, 0] - R[m21, 0, 1]) / s3[m21]
|
||||
q[m21, 1] = (R[m21, 0, 2] + R[m21, 2, 0]) / s3[m21]
|
||||
q[m21, 2] = (R[m21, 1, 2] + R[m21, 2, 1]) / s3[m21]
|
||||
q[m21, 3] = 0.25 * s3[m21]
|
||||
return q / torch.linalg.norm(q, dim=-1, keepdim=True)
|
||||
|
||||
|
||||
def build_gaussian_models(decoder, points_pred: dict, pred: dict):
|
||||
# Assemble GaussianModels from the elastic decoder layout. decoder is the ElasticGaussianFixedlenDecoder
|
||||
# (carries layout / rep_config / _get_offset)
|
||||
x = points_pred
|
||||
offset = decoder._get_offset(pred['features'])
|
||||
h = pred["features"]
|
||||
ret = []
|
||||
for i in range(h.shape[0]):
|
||||
g = GaussianModel(
|
||||
sh_degree=0,
|
||||
aabb=[-0.5, -0.5, -0.5, 1.0, 1.0, 1.0],
|
||||
mininum_kernel_size=decoder.rep_config['filter_kernel_size_3d'],
|
||||
scaling_bias=decoder.rep_config['scaling_bias'],
|
||||
opacity_bias=decoder.rep_config['opacity_bias'],
|
||||
scaling_activation=decoder.rep_config['scaling_activation'],
|
||||
device=h.device,
|
||||
)
|
||||
_x = x["points"][i, :, None, :]
|
||||
for k, v in decoder.layout.items():
|
||||
if k == '_xyz':
|
||||
setattr(g, k, (offset[i] + _x).flatten(0, 1))
|
||||
elif k in ('_xyz_center', '_offset_scale'):
|
||||
continue
|
||||
else:
|
||||
feats = h[i][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape']).flatten(0, 1)
|
||||
setattr(g, k, feats * decoder.rep_config['lr'][k])
|
||||
ret.append(g)
|
||||
return ret
|
||||
326
comfy/ldm/triposplat/model.py
Normal file
326
comfy/ldm/triposplat/model.py
Normal file
@ -0,0 +1,326 @@
|
||||
# TripoSplat flow-matching denoiser (LatentSeqMMFlowModel). Registered as a ModelType.FLOW arch and
|
||||
# driven by the standard KSampler; jointly denoises the (B, 8192, 16) latent and a (B, 1, 5) camera token
|
||||
# carried as a 2-element nested latent.
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
import comfy.rmsnorm
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from comfy.ldm.flux.math import apply_rope
|
||||
|
||||
|
||||
class MultiHeadRMSNorm(nn.Module):
|
||||
def __init__(self, dim, heads, dtype=None, device=None):
|
||||
super().__init__()
|
||||
self.gamma = nn.Parameter(torch.empty(heads, dim, dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x):
|
||||
x = comfy.rmsnorm.rms_norm(x)
|
||||
return x * comfy.model_management.cast_to(self.gamma, x.dtype, x.device)
|
||||
|
||||
|
||||
# Positional embeddings
|
||||
|
||||
class RePo3DRotaryEmbedding(nn.Module):
|
||||
def __init__(self, model_channels, num_heads, head_dim, repo_hidden_ratio=0.125, max_freq=16.0,
|
||||
dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
repo_hidden_size = int(model_channels * repo_hidden_ratio)
|
||||
self.norm = operations.LayerNorm(model_channels, dtype=dtype, device=device)
|
||||
self.gate_map = operations.Linear(model_channels, repo_hidden_size, bias=False, dtype=dtype, device=device)
|
||||
self.content_map = operations.Linear(model_channels, repo_hidden_size, bias=False, dtype=dtype, device=device)
|
||||
self.act = nn.SiLU()
|
||||
self.final_map = operations.Linear(repo_hidden_size, 3 * num_heads, bias=False, dtype=dtype, device=device)
|
||||
self.dim_0 = 2 * (head_dim // 6)
|
||||
self.dim_1 = 2 * (head_dim // 6)
|
||||
self.dim_2 = head_dim - self.dim_0 - self.dim_1
|
||||
dims = [self.dim_0, self.dim_1, self.dim_2]
|
||||
freqs_list = []
|
||||
for d in dims:
|
||||
freq_dim = d // 2
|
||||
freqs_list.append(torch.linspace(1.0, float(max_freq), steps=freq_dim, dtype=torch.float32))
|
||||
self.freqs_0 = nn.Parameter(freqs_list[0])
|
||||
self.freqs_1 = nn.Parameter(freqs_list[1])
|
||||
self.freqs_2 = nn.Parameter(freqs_list[2])
|
||||
|
||||
def forward(self, hidden_states):
|
||||
h = self.norm(hidden_states)
|
||||
feat = self.act(self.gate_map(h)) * self.content_map(h)
|
||||
out = self.final_map(feat)
|
||||
B, L, _ = out.shape
|
||||
delta_pos = out.reshape(B, L, self.num_heads, 3)
|
||||
f0 = comfy.model_management.cast_to(self.freqs_0, torch.float32, out.device)
|
||||
f1 = comfy.model_management.cast_to(self.freqs_1, torch.float32, out.device)
|
||||
f2 = comfy.model_management.cast_to(self.freqs_2, torch.float32, out.device)
|
||||
ang_0 = delta_pos[..., 0].unsqueeze(-1) * f0 * torch.pi
|
||||
ang_1 = delta_pos[..., 1].unsqueeze(-1) * f1 * torch.pi
|
||||
ang_2 = delta_pos[..., 2].unsqueeze(-1) * f2 * torch.pi
|
||||
ang = torch.cat([ang_0, ang_1, ang_2], dim=-1).float() # (B, L, heads, head_dim/2)
|
||||
cos, sin = ang.cos(), ang.sin()
|
||||
return torch.stack([cos, -sin, sin, cos], dim=-1).reshape(*ang.shape, 2, 2)
|
||||
|
||||
|
||||
class PcdAbsolutePositionEmbedder(nn.Module):
|
||||
# Sinusoidal absolute position embedding. Two fixed schedules are used in TripoSplat:
|
||||
# "pow2" (flow-model latent anchors) and "log2" (octree / gaussian decoders).
|
||||
def __init__(self, channels: int, in_channels: int = 3, max_res: int = 16, schedule: str = "pow2"):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.in_channels = in_channels
|
||||
self.max_res = max_res
|
||||
self.schedule = schedule
|
||||
self.freq_dim = channels // in_channels // 2
|
||||
|
||||
def _freqs(self, device):
|
||||
if self.schedule == "pow2":
|
||||
freqs_2exp = torch.arange(self.max_res, dtype=torch.float32, device=device)
|
||||
res_dim = max(0, self.freq_dim - self.max_res)
|
||||
freqs_res = (torch.arange(res_dim, dtype=torch.float32, device=device) / max(res_dim, 1) * self.max_res
|
||||
if res_dim > 0 else torch.empty(0, device=device))
|
||||
freqs = torch.cat([freqs_2exp, freqs_res], dim=0)[:self.freq_dim]
|
||||
return torch.pow(2.0, freqs) * 2.0 # *2 folds this schedule's 2*pi into the shared *pi below
|
||||
logs = torch.linspace(0.0, float(self.max_res), steps=self.freq_dim, dtype=torch.float32, device=device)
|
||||
return torch.pow(2.0, logs)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
orig_dtype = x.dtype
|
||||
x = x.float()
|
||||
*dims, D = x.shape
|
||||
out = torch.outer(x.reshape(-1), self._freqs(x.device)) * torch.pi
|
||||
out = torch.cat([out.sin(), out.cos()], dim=-1).reshape(*dims, -1)
|
||||
if out.shape[-1] < self.channels:
|
||||
out = torch.cat([out, torch.zeros(*dims, self.channels - out.shape[-1],
|
||||
device=out.device, dtype=out.dtype)], dim=-1)
|
||||
return out.to(orig_dtype)
|
||||
|
||||
|
||||
def attention(q, k, v, transformer_options=None):
|
||||
# q, k, v: (B, L, heads, dim) -> (B, L, heads, dim). Shared optimized_attention call convention.
|
||||
out = optimized_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), heads=q.shape[2],
|
||||
skip_reshape=True, skip_output_reshape=True, low_precision_attention=False,
|
||||
transformer_options=transformer_options)
|
||||
return out.transpose(1, 2)
|
||||
|
||||
|
||||
# Transformer building blocks
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, in_channels, hidden_channels, out_channels, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
operations.Linear(in_channels, hidden_channels, dtype=dtype, device=device),
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(hidden_channels, out_channels, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.mlp(x)
|
||||
|
||||
|
||||
class RopeMultiHeadAttention(nn.Module):
|
||||
def __init__(self, channels, num_heads, qkv_bias=True, qk_rms_norm=False, use_rope=False,
|
||||
dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = channels // num_heads
|
||||
self.qk_rms_norm = qk_rms_norm
|
||||
self.use_rope = use_rope
|
||||
self.qkv = operations.Linear(channels, channels * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||
if self.qk_rms_norm:
|
||||
self.q_norm = MultiHeadRMSNorm(self.head_dim, num_heads, dtype=dtype, device=device)
|
||||
self.k_norm = MultiHeadRMSNorm(self.head_dim, num_heads, dtype=dtype, device=device)
|
||||
self.out = operations.Linear(channels, channels, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, rope_emb=None, transformer_options=None):
|
||||
B, L, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, L, 3, self.num_heads, self.head_dim)
|
||||
q, k, v = qkv.unbind(2)
|
||||
if self.use_rope:
|
||||
q, k = apply_rope(q, k, rope_emb)
|
||||
if self.qk_rms_norm:
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
h = attention(q, k, v, transformer_options) # (B, L, heads, dim)
|
||||
return self.out(h.reshape(B, L, C))
|
||||
|
||||
|
||||
class UnifiedTransformerBlock(nn.Module):
|
||||
def __init__(self, channels, num_heads, mlp_ratio=4.0,
|
||||
use_rope=False, qk_rms_norm=False, qkv_bias=True,
|
||||
modulation=True, share_mod=False,
|
||||
dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.modulation = modulation
|
||||
self.share_mod = share_mod
|
||||
self.norm1 = operations.LayerNorm(channels, elementwise_affine=not modulation, eps=1e-6, dtype=dtype, device=device)
|
||||
self.norm2 = operations.LayerNorm(channels, elementwise_affine=not modulation, eps=1e-6, dtype=dtype, device=device)
|
||||
self.attn = RopeMultiHeadAttention(channels, num_heads=num_heads,
|
||||
qkv_bias=qkv_bias, use_rope=use_rope, qk_rms_norm=qk_rms_norm,
|
||||
dtype=dtype, device=device, operations=operations)
|
||||
self.mlp = MLP(channels, int(channels * mlp_ratio), channels, dtype=dtype, device=device, operations=operations)
|
||||
if modulation:
|
||||
if not share_mod:
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(), operations.Linear(channels, 6 * channels, bias=True, dtype=dtype, device=device))
|
||||
self.shift_table = nn.Parameter(torch.empty(1, 6 * channels, dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x, mod=None, rotary_emb=None, transformer_options=None):
|
||||
if self.modulation:
|
||||
if not self.share_mod:
|
||||
mod = self.adaLN_modulation(mod)
|
||||
mod = mod + comfy.model_management.cast_to(self.shift_table, mod.dtype, mod.device)
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
|
||||
h = torch.addcmul(shift_msa.unsqueeze(1), self.norm1(x), 1 + scale_msa.unsqueeze(1))
|
||||
x = torch.addcmul(x, self.attn(h, rope_emb=rotary_emb, transformer_options=transformer_options), gate_msa.unsqueeze(1))
|
||||
h = torch.addcmul(shift_mlp.unsqueeze(1), self.norm2(x), 1 + scale_mlp.unsqueeze(1))
|
||||
x = torch.addcmul(x, self.mlp(h), gate_mlp.unsqueeze(1))
|
||||
else:
|
||||
x = x + self.attn(self.norm1(x), rope_emb=rotary_emb, transformer_options=transformer_options)
|
||||
x = x + self.mlp(self.norm2(x))
|
||||
return x
|
||||
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t, dim, max_period=10000):
|
||||
half = dim // 2
|
||||
freqs = torch.exp(-np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
|
||||
|
||||
def forward(self, t):
|
||||
emb = self.timestep_embedding(t, self.frequency_embedding_size)
|
||||
return self.mlp(emb.to(self.mlp[0].weight.dtype))
|
||||
|
||||
|
||||
class LatentSeqMMFlowModel(nn.Module):
|
||||
def __init__(self, image_model=None, q_token_length=8192, in_channels=16, model_channels=1024,
|
||||
cond_channels=1280, out_channels=16, num_blocks=24, num_refiner_blocks=2,
|
||||
num_heads=None, num_head_channels=64, cam_channels=5, cond2_channels=128,
|
||||
mlp_ratio=4, share_mod=True, qk_rms_norm=True,
|
||||
dtype=None, device=None, operations=None, **kwargs):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
self.q_token_length = q_token_length
|
||||
self.in_channels = in_channels
|
||||
self.cam_channels = cam_channels
|
||||
self.model_channels = model_channels
|
||||
self.cond_channels = cond_channels
|
||||
self.cond2_channels = cond2_channels
|
||||
self.out_channels = out_channels
|
||||
self.num_blocks = num_blocks
|
||||
self.num_refiner_blocks = num_refiner_blocks
|
||||
self.num_heads = num_heads or model_channels // num_head_channels
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.share_mod = share_mod
|
||||
self.qk_rms_norm = qk_rms_norm
|
||||
|
||||
factory_kwargs = dict(dtype=dtype, device=device)
|
||||
op_kwargs = dict(operations=operations, **factory_kwargs)
|
||||
|
||||
self.t_embedder = TimestepEmbedder(model_channels, **op_kwargs)
|
||||
if share_mod:
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(model_channels, 6 * model_channels, bias=True, **factory_kwargs))
|
||||
|
||||
self.input_layer = operations.Linear(in_channels, model_channels, **factory_kwargs)
|
||||
self.cond_embedder = operations.Linear(cond_channels, model_channels, **factory_kwargs)
|
||||
self.cond_embedder2 = operations.Linear(cond2_channels, model_channels, **factory_kwargs) if cond2_channels is not None else None
|
||||
|
||||
# Fixed Sobol (low-discrepancy) 3D anchor positions for the latent tokens, used as positional encoding.
|
||||
# The embedder is parameter-free and the anchors are fixed, precompute once.
|
||||
sobol_seq = torch.quasirandom.SobolEngine(dimension=3, scramble=True, seed=123).draw(q_token_length)
|
||||
pos_emb = PcdAbsolutePositionEmbedder(model_channels)(sobol_seq.unsqueeze(0))
|
||||
self.register_buffer("pos_emb", pos_emb, persistent=False)
|
||||
|
||||
# RePo3DRotaryEmbedding layers for the refiner and main blocks
|
||||
repo_kwargs = dict(num_heads=self.num_heads, head_dim=num_head_channels, **op_kwargs)
|
||||
self.noise_repo_layers = nn.ModuleList(
|
||||
[RePo3DRotaryEmbedding(model_channels, **repo_kwargs) for _ in range(num_refiner_blocks)])
|
||||
self.context_repo_layers = nn.ModuleList(
|
||||
[RePo3DRotaryEmbedding(model_channels, **repo_kwargs) for _ in range(num_refiner_blocks)])
|
||||
self.repo_layers = nn.ModuleList(
|
||||
[RePo3DRotaryEmbedding(model_channels, **repo_kwargs) for _ in range(num_blocks)])
|
||||
|
||||
# Refiner blocks
|
||||
block_kwargs = dict(num_heads=self.num_heads, mlp_ratio=self.mlp_ratio, use_rope=True, qk_rms_norm=self.qk_rms_norm, **op_kwargs)
|
||||
self.noise_refiner = nn.ModuleList(
|
||||
[UnifiedTransformerBlock(model_channels, modulation=True, share_mod=self.share_mod, **block_kwargs) for _ in range(num_refiner_blocks)])
|
||||
self.context_refiner = nn.ModuleList(
|
||||
[UnifiedTransformerBlock(model_channels, modulation=False, **block_kwargs) for _ in range(num_refiner_blocks)])
|
||||
|
||||
self.cam_refiner = MLP(self.cam_channels, model_channels, model_channels, **op_kwargs)
|
||||
|
||||
self.blocks = nn.ModuleList(
|
||||
[UnifiedTransformerBlock(model_channels, modulation=True, share_mod=self.share_mod, **block_kwargs) for _ in range(num_blocks)])
|
||||
|
||||
self.shift_table = nn.Parameter(torch.empty(1, 2, model_channels, **factory_kwargs))
|
||||
self.out_layer = operations.Linear(model_channels, out_channels, **factory_kwargs)
|
||||
self.cam_out_layer = operations.Linear(model_channels, cam_channels, **factory_kwargs)
|
||||
|
||||
def forward(self, x, t, context=None, ref_latents=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, t, context, ref_latents, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, t, context=None, ref_latents=None, transformer_options={}, **kwargs):
|
||||
# x is the unpacked nested latent: [latent (B,8192,in_channels), camera (B,1,cam_channels)].
|
||||
# context == feature1.
|
||||
z, camera = x[0], x[1]
|
||||
feat1 = context
|
||||
|
||||
h_x = self.input_layer(z)
|
||||
h_cond = self.cond_embedder(feat1)
|
||||
if ref_latents is not None and self.cond_embedder2 is not None:
|
||||
# Flatten the Flux2 VAE latent (B,128,h,w) to a token sequence and front-pad to feat1's length
|
||||
# (the pad count = feat1's prefix tokens: DINOv3 cls + registers), then add to the context.
|
||||
feat2 = ref_latents[0].flatten(2).transpose(1, 2)
|
||||
feat2 = F.pad(feat2, (0, 0, feat1.shape[1] - feat2.shape[1], 0))
|
||||
h_cond = h_cond + self.cond_embedder2(feat2.to(h_cond.dtype))
|
||||
t_emb = self.t_embedder(t)
|
||||
t_mod = self.adaLN_modulation(t_emb) if self.share_mod else t_emb
|
||||
|
||||
h_x = h_x + self.pos_emb.to(z)
|
||||
|
||||
for i, block in enumerate(self.noise_refiner):
|
||||
h_x = block(h_x, mod=t_mod, rotary_emb=self.noise_repo_layers[i](h_x), transformer_options=transformer_options)
|
||||
|
||||
for i, block in enumerate(self.context_refiner):
|
||||
h_cond = block(h_cond, mod=None, rotary_emb=self.context_repo_layers[i](h_cond), transformer_options=transformer_options)
|
||||
|
||||
cam = camera.to(z)
|
||||
h_cam = self.cam_refiner(cam)
|
||||
h = torch.cat([h_x, h_cond, h_cam], dim=1)
|
||||
|
||||
for i, block in enumerate(self.blocks):
|
||||
h = block(h, mod=t_mod, rotary_emb=self.repo_layers[i](h), transformer_options=transformer_options)
|
||||
|
||||
h_x = F.layer_norm(h[:, :z.shape[1]].float(), h.shape[-1:]).to(z)
|
||||
h_cam = F.layer_norm(h[:, -cam.shape[1]:].float(), h.shape[-1:]).to(z)
|
||||
|
||||
shift, scale = (comfy.model_management.cast_to(self.shift_table, t_emb.dtype, t_emb.device) + t_emb.unsqueeze(1)).chunk(2, dim=1)
|
||||
scale = 1 + scale
|
||||
h_x = torch.addcmul(shift, h_x, scale)
|
||||
h_cam = torch.addcmul(shift, h_cam, scale)
|
||||
|
||||
return self.out_layer(h_x), self.cam_out_layer(h_cam)
|
||||
91
comfy/ldm/triposplat/preview.py
Normal file
91
comfy/ldm/triposplat/preview.py
Normal file
@ -0,0 +1,91 @@
|
||||
# Live preview for TripoSplat: decode an x0 estimate into a coarse gaussian splat and render it with a perspective orbit camera.
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
_C0 = 0.28209479177387814
|
||||
_LATENT_TOKENS = 8192 # q_token_length
|
||||
_LATENT_CH = 16 # in_channels
|
||||
_OBJECT_TO_VIEWER = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]], np.float32) # object frame -> viewer Y-up frame
|
||||
|
||||
|
||||
def _view_matrix(yaw_deg, pitch_deg):
|
||||
y, p = np.radians(yaw_deg), np.radians(pitch_deg)
|
||||
Ry = np.array([[np.cos(y), 0, np.sin(y)], [0, 1, 0], [-np.sin(y), 0, np.cos(y)]], np.float32)
|
||||
Rx = np.array([[1, 0, 0], [0, np.cos(p), -np.sin(p)], [0, np.sin(p), np.cos(p)]], np.float32)
|
||||
return Rx @ Ry
|
||||
|
||||
|
||||
def render_splat(xyz, rgb, scale, opacity=None, yaw=35.0, pitch=30.0, size=320, min_px=2, gain=1.0,
|
||||
max_px=9, min_opacity=0.0, fov=35.0, dist=2.2):
|
||||
# Project gaussian centers with a perspective camera and paint each as a filled disk whose screen
|
||||
# radius follows the gaussian's world-space scale, composited with a nearest-wins z-buffer.
|
||||
# gain scales the footprint (≈ std spanned), `min_px`/`max_px` clamp the on-screen radius.
|
||||
|
||||
pts = xyz.astype(np.float32) @ _OBJECT_TO_VIEWER.T
|
||||
v = pts @ _view_matrix(yaw, pitch).T
|
||||
zc = v[:, 2] + dist
|
||||
keep = zc > 1e-2
|
||||
if opacity is not None and min_opacity > 0.0: # culls gaussians with very low opacity
|
||||
keep = keep & (opacity > min_opacity)
|
||||
v, zc, scale = v[keep], zc[keep], scale[keep]
|
||||
col = (np.clip(rgb, 0, 1)[:, :3] * 255).astype(np.uint8)[keep]
|
||||
if v.shape[0] == 0:
|
||||
return Image.fromarray(np.zeros((size, size, 3), np.uint8))
|
||||
f = (size / 2) / np.tan(np.radians(fov) / 2)
|
||||
cx = size / 2 + f * v[:, 0] / zc
|
||||
cy = size / 2 + f * v[:, 1] / zc
|
||||
radius = np.clip(np.round(f * scale / zc * gain), min_px, max_px).astype(np.int32)
|
||||
|
||||
# Expand each splat to its disk pixels, bucketed by integer radius so it stays vectorized.
|
||||
px, py, pz, pc = [], [], [], []
|
||||
for r in range(int(radius.min()), int(radius.max()) + 1):
|
||||
m = radius == r
|
||||
if not m.any():
|
||||
continue
|
||||
dy, dx = np.mgrid[-r:r + 1, -r:r + 1]
|
||||
disk = (dx * dx + dy * dy) <= r * r
|
||||
ox, oy = dx[disk], dy[disk]
|
||||
px.append((cx[m, None] + ox).ravel())
|
||||
py.append((cy[m, None] + oy).ravel())
|
||||
pz.append(np.repeat(zc[m], ox.size))
|
||||
pc.append(np.repeat(col[m], ox.size, axis=0))
|
||||
px, py = np.concatenate(px), np.concatenate(py)
|
||||
pz, pc = np.concatenate(pz), np.concatenate(pc)
|
||||
xi = np.clip(px, 0, size - 1).astype(np.int64)
|
||||
yi = np.clip(py, 0, size - 1).astype(np.int64)
|
||||
|
||||
# Nearest-wins z-buffer: pack (quantized depth, source index), per-pixel min picks the closest
|
||||
# splat, then decode the winning index back to its color.
|
||||
pid = yi * size + xi
|
||||
q = np.clip((pz * 1024.0).astype(np.int64), 0, (1 << 20) - 1) # near = small
|
||||
key = (q << 32) | np.arange(pid.size, dtype=np.int64)
|
||||
buf = np.full(size * size, 1 << 62, np.int64)
|
||||
np.minimum.at(buf, pid, key)
|
||||
img = np.zeros((size * size, 3), np.uint8)
|
||||
hit = buf < (1 << 62)
|
||||
img[hit] = pc[buf[hit] & 0xFFFFFFFF]
|
||||
return Image.fromarray(img.reshape(size, size, 3))
|
||||
|
||||
|
||||
def _extract_latent(x0):
|
||||
# x0 from the sampler callback is the nested latent packed to (B, 1, TOKENS*CH + 1*5);
|
||||
# the plain single-latent case is (B, TOKENS, CH). Return the (B, TOKENS, CH) latent stream.
|
||||
if x0.ndim == 3 and x0.shape[1] == _LATENT_TOKENS and x0.shape[2] == _LATENT_CH:
|
||||
return x0
|
||||
flat = x0.reshape(x0.shape[0], -1)
|
||||
return flat[:, :_LATENT_TOKENS * _LATENT_CH].reshape(x0.shape[0], _LATENT_TOKENS, _LATENT_CH)
|
||||
|
||||
|
||||
def decode_x0_to_image(decoder, x0, cfg):
|
||||
# Decode x0 at a coarse octree level / few gaussians and render a preview image.
|
||||
latent = _extract_latent(x0)
|
||||
fsm = decoder.first_stage_model
|
||||
gaussian = fsm.decode(latent.to(decoder.device, decoder.vae_dtype),
|
||||
num_gaussians=cfg.get("gaussians", 16384), level=cfg.get("level", 5))[0]
|
||||
xyz = gaussian.get_xyz.float().cpu().numpy()
|
||||
rgb = gaussian._features_dc.float().cpu().numpy()[:, 0, :] * _C0 + 0.5
|
||||
scale = gaussian.get_scaling.float().cpu().numpy().max(axis=1) # per-splat world radius (largest axis)
|
||||
opacity = gaussian.get_opacity.float().cpu().numpy()[:, 0]
|
||||
return render_splat(xyz, rgb, scale, opacity=opacity, yaw=cfg.get("yaw", 35.0), pitch=cfg.get("pitch", 30.0),
|
||||
size=cfg.get("size", 320), min_px=1, gain=1.0, max_px=cfg.get("point_size", 3),
|
||||
min_opacity=0.01)
|
||||
382
comfy/ldm/triposplat/vae.py
Normal file
382
comfy/ldm/triposplat/vae.py
Normal file
@ -0,0 +1,382 @@
|
||||
# TripoSplat gaussian decoder ("VAE"): an octree probability decoder picks point coords, then an
|
||||
# elastic-gaussian decoder predicts per-point gaussian params. OctreeGaussianDecoder.decode() returns
|
||||
# a Gaussian. The octree sampler uses the global torch RNG (no generator) like upstream, so seed it for repeatable decodes.
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.ops
|
||||
from .gaussian import build_gaussian_models
|
||||
from .model import MultiHeadRMSNorm, MLP, PcdAbsolutePositionEmbedder, attention
|
||||
|
||||
|
||||
# Quasi-random sampling utilities (pure functions, dtype/device-agnostic)
|
||||
|
||||
PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53]
|
||||
|
||||
|
||||
def radical_inverse(base, n):
|
||||
val = 0
|
||||
inv_base = 1.0 / base
|
||||
inv_base_n = inv_base
|
||||
while n > 0:
|
||||
digit = n % base
|
||||
val += digit * inv_base_n
|
||||
n //= base
|
||||
inv_base_n *= inv_base
|
||||
return val
|
||||
|
||||
|
||||
def halton_sequence(dim, n):
|
||||
return [radical_inverse(PRIMES[i], n) for i in range(dim)]
|
||||
|
||||
|
||||
def hammersley_sequence(dim, n, num_samples):
|
||||
return [n / num_samples] + halton_sequence(dim - 1, n)
|
||||
|
||||
|
||||
def sample_probs(probs, counts, generator=None):
|
||||
# Systematic resampling: distribute counts[r] draws across the P bins of row r
|
||||
batch_shape = counts.shape
|
||||
R = counts.numel()
|
||||
P = probs.size(-1)
|
||||
device = probs.device
|
||||
probs = probs.reshape(R, P).to(torch.float32).clamp_min(0)
|
||||
counts = counts.reshape(R).to(device=device, dtype=torch.long)
|
||||
|
||||
row_sums = probs.sum(1, keepdim=True)
|
||||
probs = torch.where(row_sums == 0, probs.new_tensor(1.0 / P), probs / row_sums.clamp_min(1))
|
||||
cdf = probs.cumsum(dim=1).clamp(max=1.0 - 1e-12)
|
||||
|
||||
Nmax = int(counts.max())
|
||||
if Nmax == 0:
|
||||
return counts.new_zeros(*batch_shape, P)
|
||||
cnt = counts.clamp_min(1).float().unsqueeze(1) # (R, 1)
|
||||
grid = torch.arange(Nmax, device=device, dtype=torch.float32).unsqueeze(0) # (1, Nmax)
|
||||
u = (torch.rand(R, 1, generator=generator).to(device) + grid) / cnt # (R, Nmax) systematic samples (CPU-seeded)
|
||||
idx = torch.searchsorted(cdf, u.clamp(max=1.0 - 1e-12)).clamp_max(P - 1)
|
||||
weight = (grid < counts.unsqueeze(1)).to(cdf.dtype) # mask out j >= counts[r]
|
||||
out = torch.zeros(R, P, dtype=torch.float32, device=device)
|
||||
out.scatter_add_(1, idx, weight)
|
||||
return out.to(torch.long).view(*batch_shape, P)
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
def __init__(self, channels, num_heads, ctx_channels=None, type="self", qkv_bias=True, qk_rms_norm=False,
|
||||
dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
assert channels % num_heads == 0
|
||||
self.channels = channels
|
||||
self.head_dim = channels // num_heads
|
||||
self.ctx_channels = ctx_channels if ctx_channels is not None else channels
|
||||
self.num_heads = num_heads
|
||||
self._type = type
|
||||
self.qk_rms_norm = qk_rms_norm
|
||||
if self._type == "self":
|
||||
self.to_qkv = operations.Linear(channels, channels * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||
else:
|
||||
self.to_q = operations.Linear(channels, channels, bias=qkv_bias, dtype=dtype, device=device)
|
||||
self.to_kv = operations.Linear(self.ctx_channels, channels * 2, bias=qkv_bias, dtype=dtype, device=device)
|
||||
if self.qk_rms_norm:
|
||||
self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, dtype=dtype, device=device)
|
||||
self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, dtype=dtype, device=device)
|
||||
self.to_out = operations.Linear(channels, channels, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, context=None):
|
||||
B, L, C = x.shape
|
||||
if self._type == "self":
|
||||
q, k, v = self.to_qkv(x).reshape(B, L, 3, self.num_heads, -1).unbind(dim=2)
|
||||
else:
|
||||
Lkv = context.shape[1]
|
||||
q = self.to_q(x).reshape(B, L, self.num_heads, -1)
|
||||
k, v = self.to_kv(context).reshape(B, Lkv, 2, self.num_heads, -1).unbind(dim=2)
|
||||
if self.qk_rms_norm:
|
||||
q = self.q_rms_norm(q)
|
||||
k = self.k_rms_norm(k)
|
||||
h = attention(q, k, v)
|
||||
return self.to_out(h.reshape(B, L, -1))
|
||||
|
||||
|
||||
# Octree probability decoder
|
||||
|
||||
class LevelEmbedder(nn.Module):
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256, max_period=1024,
|
||||
dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
self.max_period = max_period
|
||||
|
||||
@staticmethod
|
||||
def level_embedding(t, dim, max_period=1024):
|
||||
half = dim // 2
|
||||
freqs = torch.exp(-np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
|
||||
args = t[:, None].float() * freqs[None] * 2 * torch.pi
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
|
||||
|
||||
def forward(self, t):
|
||||
emb = self.level_embedding(t, self.frequency_embedding_size, self.max_period)
|
||||
return self.mlp(emb.to(self.mlp[0].weight.dtype))
|
||||
|
||||
|
||||
class ModulatedTransformerCrossOnlyBlock(nn.Module):
|
||||
def __init__(self, channels, ctx_channels, num_heads, mlp_ratio=4.0, share_mod=False,
|
||||
qk_rms_norm_cross=True, qkv_bias=True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.share_mod = share_mod
|
||||
self.norm1 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.norm2 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.cross_attn = MultiHeadAttention(channels, ctx_channels=ctx_channels, num_heads=num_heads,
|
||||
type="cross", qkv_bias=qkv_bias,
|
||||
qk_rms_norm=qk_rms_norm_cross, dtype=dtype, device=device, operations=operations)
|
||||
self.mlp = MLP(channels, int(channels * mlp_ratio), channels, dtype=dtype, device=device, operations=operations)
|
||||
if not share_mod:
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(), operations.Linear(channels, 6 * channels, bias=True, dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x, mod, context):
|
||||
if self.share_mod:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
|
||||
else:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
|
||||
h = torch.addcmul(shift_msa.unsqueeze(1), self.norm1(x), 1 + scale_msa.unsqueeze(1))
|
||||
x = torch.addcmul(x, self.cross_attn(h, context), gate_msa.unsqueeze(1))
|
||||
h = torch.addcmul(shift_mlp.unsqueeze(1), self.norm2(x), 1 + scale_mlp.unsqueeze(1))
|
||||
x = torch.addcmul(x, self.mlp(h), gate_mlp.unsqueeze(1))
|
||||
return x
|
||||
|
||||
|
||||
class OctreeProbabilityFixedlenDecoder(nn.Module):
|
||||
# Cross-attention transformer over octree coords -> per-node 8-way child occupancy logits.
|
||||
def __init__(self, model_channels=1024, cond_channels=16, num_blocks=4, num_heads=16,
|
||||
num_head_channels=64, mlp_ratio=4.0, share_mod=True,
|
||||
qk_rms_norm_cross=True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.model_channels = model_channels
|
||||
self.cond_channels = cond_channels
|
||||
self.num_blocks = num_blocks
|
||||
self.num_heads = num_heads or model_channels // num_head_channels
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.share_mod = share_mod
|
||||
self.qk_rms_norm_cross = qk_rms_norm_cross
|
||||
self.input_layer = operations.Linear(model_channels, model_channels, dtype=dtype, device=device)
|
||||
self.l_embedder = LevelEmbedder(model_channels, dtype=dtype, device=device, operations=operations)
|
||||
if share_mod:
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(), operations.Linear(model_channels, 6 * model_channels, bias=True, dtype=dtype, device=device))
|
||||
if cond_channels is not None:
|
||||
self.blocks = nn.ModuleList([
|
||||
ModulatedTransformerCrossOnlyBlock(
|
||||
model_channels, ctx_channels=cond_channels, num_heads=self.num_heads,
|
||||
mlp_ratio=self.mlp_ratio, qk_rms_norm_cross=self.qk_rms_norm_cross,
|
||||
share_mod=self.share_mod, dtype=dtype, device=device, operations=operations)
|
||||
for _ in range(num_blocks)
|
||||
])
|
||||
self.out_proj = operations.Linear(model_channels, 8, dtype=dtype, device=device)
|
||||
self.in_proj = operations.Linear(3, model_channels, dtype=dtype, device=device)
|
||||
self.pos_embedder = PcdAbsolutePositionEmbedder(channels=model_channels, in_channels=3, max_res=10, schedule="log2")
|
||||
|
||||
def forward(self, x, l, cond):
|
||||
d = next(self.parameters()).dtype
|
||||
B, L, _ = x.shape
|
||||
h = self.in_proj(x.to(d)) + self.pos_embedder(x.reshape(-1, 3)).reshape(B, L, -1).to(d)
|
||||
h = self.input_layer(h)
|
||||
l_emb = self.l_embedder(l)
|
||||
if self.share_mod:
|
||||
l_emb = self.adaLN_modulation(l_emb)
|
||||
cond = cond.to(d)
|
||||
for block in self.blocks:
|
||||
h = block(h, l_emb, cond)
|
||||
h = F.layer_norm(h.float(), h.shape[-1:]).to(d)
|
||||
logits = self.out_proj(h)
|
||||
return {"logits": logits, "probs": torch.softmax(logits, dim=-1)}
|
||||
|
||||
@staticmethod
|
||||
def sample(model, cond, num_points, level, temperature=1.0, generator=None):
|
||||
B = cond.shape[0]
|
||||
device = cond.device
|
||||
child_offset = torch.tensor([[i, j, k] for k in [0, 1] for j in [0, 1] for i in [0, 1]],
|
||||
dtype=torch.long, device=device)
|
||||
prev_coords_int = torch.zeros(B, 1, 3, dtype=torch.long, device=device)
|
||||
prev_counts = torch.full((B, 1), num_points, dtype=torch.long, device=device)
|
||||
prev_log_probs = torch.zeros(B, 1, dtype=torch.float32, device=device)
|
||||
batch_indices_range = torch.arange(B, device=device).unsqueeze(1)
|
||||
|
||||
for lv in range(1, level + 1):
|
||||
res_p = 1 << (lv - 1)
|
||||
res = 1 << lv
|
||||
parent_coords_norm = (prev_coords_int.to(torch.float32) + 0.5) / res_p
|
||||
res_tensor = torch.full((B,), res, dtype=torch.long, device=device)
|
||||
pred_logits = model(parent_coords_norm, res_tensor, cond)["logits"] / temperature
|
||||
pred_probs = torch.softmax(pred_logits, dim=-1)
|
||||
pred_log_probs = torch.log_softmax(pred_logits, dim=-1)
|
||||
sampled = sample_probs(pred_probs, prev_counts, generator=generator).flatten(1, 2)
|
||||
pred_log_probs = pred_log_probs.flatten(1, 2)
|
||||
prev_log_probs_expanded = prev_log_probs.repeat_interleave(8, dim=1)
|
||||
child_coords_int = (prev_coords_int[:, :, None, :] * 2 + child_offset[None, None, :, :]).flatten(1, 2)
|
||||
mask = sampled > 0
|
||||
max_valid = mask.sum(dim=1).max().item()
|
||||
scatter_indices = mask.cumsum(dim=1) - 1
|
||||
valid_scatter_indices = scatter_indices[mask]
|
||||
valid_batch_indices = batch_indices_range.expand_as(mask)[mask]
|
||||
next_prev_coords_int = torch.zeros(B, max_valid, 3, dtype=child_coords_int.dtype, device=device)
|
||||
next_prev_coords_int[valid_batch_indices, valid_scatter_indices] = child_coords_int[mask]
|
||||
next_prev_counts = torch.zeros(B, max_valid, dtype=sampled.dtype, device=device)
|
||||
next_prev_counts[valid_batch_indices, valid_scatter_indices] = sampled[mask]
|
||||
next_prev_log_probs = torch.zeros(B, max_valid, dtype=prev_log_probs.dtype, device=device)
|
||||
next_prev_log_probs[valid_batch_indices, valid_scatter_indices] = (prev_log_probs_expanded + pred_log_probs)[mask]
|
||||
prev_coords_int = next_prev_coords_int
|
||||
prev_counts = next_prev_counts
|
||||
prev_log_probs = next_prev_log_probs
|
||||
|
||||
res = 1 << level
|
||||
prev_log_probs = torch.repeat_interleave(prev_log_probs.flatten(0, 1), prev_counts.flatten(0, 1), dim=0).reshape(B, num_points)
|
||||
coords_int = torch.repeat_interleave(prev_coords_int.flatten(0, 1), prev_counts.flatten(0, 1), dim=0).reshape(B, num_points, -1)
|
||||
rand = torch.rand(coords_int.shape, dtype=torch.float32, generator=generator).to(device)
|
||||
coords_norm = (coords_int.to(torch.float32) + rand) / res
|
||||
return {"points": coords_norm, "log_probs": prev_log_probs}
|
||||
|
||||
|
||||
# Elastic gaussian decoder
|
||||
|
||||
class TransformerCrossBlock(nn.Module):
|
||||
def __init__(self, channels, ctx_channels, num_heads, mlp_ratio=4.0,
|
||||
qk_rms_norm=True, qk_rms_norm_cross=True, qkv_bias=True,
|
||||
dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm1 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.norm2 = operations.LayerNorm(channels, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)
|
||||
self.norm3 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.self_attn = MultiHeadAttention(channels, num_heads=num_heads, type="self", qkv_bias=qkv_bias,
|
||||
qk_rms_norm=qk_rms_norm, dtype=dtype, device=device, operations=operations)
|
||||
self.cross_attn = MultiHeadAttention(channels, ctx_channels=ctx_channels, num_heads=num_heads, type="cross",
|
||||
qkv_bias=qkv_bias, qk_rms_norm=qk_rms_norm_cross, dtype=dtype, device=device, operations=operations)
|
||||
self.mlp = MLP(channels, int(channels * mlp_ratio), channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, x, context):
|
||||
x = x + self.self_attn(self.norm1(x))
|
||||
x = x + self.cross_attn(self.norm2(x), context)
|
||||
x = x + self.mlp(self.norm3(x))
|
||||
return x
|
||||
|
||||
|
||||
class ElasticGaussianFixedlenDecoder(nn.Module):
|
||||
# Cross-attention transformer over sampled octree points -> per-point gaussian params.
|
||||
def __init__(self, in_channels=3, model_channels=1024, cond_channels=16, num_blocks=16, num_heads=16,
|
||||
num_head_channels=64, mlp_ratio=4.0, *, representation_config=None,
|
||||
qk_rms_norm=True, qk_rms_norm_cross=True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.rep_config = representation_config or dict(
|
||||
lr=dict(_xyz=1.0, _features_dc=1.0, _opacity=1.0, _scaling=1.0, _rotation=0.1),
|
||||
perturb_offset=True, perturbe_size=1.5, offset_scale=0.05, num_gaussians=32,
|
||||
filter_kernel_size_3d=0.0009, scaling_bias=0.004, opacity_bias=0.1,
|
||||
scaling_activation="softplus",
|
||||
)
|
||||
self.out_channels = self._calc_layout()
|
||||
self.model_channels = model_channels
|
||||
self.cond_channels = cond_channels
|
||||
self.num_blocks = num_blocks
|
||||
self.num_heads = num_heads or model_channels // num_head_channels
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.input_layer = operations.Linear(model_channels, model_channels, dtype=dtype, device=device)
|
||||
if cond_channels is not None:
|
||||
self.blocks = nn.ModuleList([
|
||||
TransformerCrossBlock(model_channels, ctx_channels=cond_channels,
|
||||
num_heads=self.num_heads, mlp_ratio=self.mlp_ratio,
|
||||
qk_rms_norm=qk_rms_norm, qk_rms_norm_cross=qk_rms_norm_cross,
|
||||
dtype=dtype, device=device, operations=operations)
|
||||
for _ in range(num_blocks)
|
||||
])
|
||||
self.in_proj = operations.Linear(in_channels, model_channels, dtype=dtype, device=device)
|
||||
self.pos_embedder = PcdAbsolutePositionEmbedder(channels=model_channels, in_channels=3, max_res=10, schedule="log2")
|
||||
self.out_proj = operations.Linear(model_channels, self.out_channels, dtype=dtype, device=device)
|
||||
self._build_perturbation()
|
||||
|
||||
def _calc_layout(self):
|
||||
ng = self.rep_config['num_gaussians']
|
||||
self.layout = {
|
||||
'_xyz': {'shape': (ng, 3), 'size': ng * 3},
|
||||
'_features_dc': {'shape': (ng, 1, 3), 'size': ng * 3},
|
||||
'_scaling': {'shape': (ng, 3), 'size': ng * 3},
|
||||
'_rotation': {'shape': (ng, 4), 'size': ng * 4},
|
||||
'_opacity': {'shape': (ng, 1), 'size': ng},
|
||||
}
|
||||
self.layout['_offset_scale'] = {'shape': (ng, 1), 'size': ng}
|
||||
start = 0
|
||||
for k, v in self.layout.items():
|
||||
v['range'] = (start, start + v['size'])
|
||||
start += v['size']
|
||||
return start
|
||||
|
||||
def _build_perturbation(self):
|
||||
ng = self.rep_config['num_gaussians']
|
||||
perturbation = torch.tensor([hammersley_sequence(3, i, ng) for i in range(ng)]).float()
|
||||
perturbation = torch.atanh((perturbation * 2 - 1) / self.rep_config['perturbe_size'])
|
||||
self.register_buffer('points_offset_perturbation', perturbation)
|
||||
base = torch.tensor(self.rep_config['offset_scale'])
|
||||
self.register_buffer('base_offset_scale', torch.log(torch.exp(base) - 1.0))
|
||||
|
||||
def _get_offset(self, h):
|
||||
B = h.shape[0]
|
||||
r = self.layout['_offset_scale']['range']
|
||||
_offset_scale = F.softplus(
|
||||
h[:, :, r[0]:r[1]].reshape(B, -1, *self.layout['_offset_scale']['shape'])
|
||||
+ comfy.model_management.cast_to(self.base_offset_scale, h.dtype, h.device))
|
||||
|
||||
r = self.layout['_xyz']['range']
|
||||
offset = h[:, :, r[0]:r[1]].reshape(B, -1, *self.layout['_xyz']['shape'])
|
||||
offset = offset * self.rep_config['lr']['_xyz']
|
||||
if self.rep_config['perturb_offset']:
|
||||
offset = offset + comfy.model_management.cast_to(self.points_offset_perturbation, offset.dtype, offset.device)
|
||||
offset = torch.tanh(offset) * 0.5 * self.rep_config['perturbe_size']
|
||||
offset = offset * _offset_scale
|
||||
return offset
|
||||
|
||||
def forward(self, x=None, cond=None):
|
||||
pcd = x["points"]
|
||||
d = next(self.parameters()).dtype
|
||||
B, L, _ = pcd.shape
|
||||
h = self.in_proj(pcd.to(d)) + self.pos_embedder(pcd.reshape(-1, 3)).reshape(B, L, -1).to(d)
|
||||
h = self.input_layer(h)
|
||||
cond = cond.to(d)
|
||||
for block in self.blocks:
|
||||
h = block(h, cond)
|
||||
h = F.layer_norm(h.float(), h.shape[-1:]).to(h.dtype)
|
||||
return {"features": self.out_proj(h)}
|
||||
|
||||
|
||||
# Combined octree gaussian decoder (comfy first-stage model)
|
||||
|
||||
class OctreeGaussianDecoder(nn.Module):
|
||||
_MAX_VOXEL_LEVEL = 8
|
||||
|
||||
def __init__(self, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
if operations is None:
|
||||
operations = comfy.ops.disable_weight_init
|
||||
self.octree = OctreeProbabilityFixedlenDecoder(dtype=dtype, device=device, operations=operations)
|
||||
self.gs = ElasticGaussianFixedlenDecoder(dtype=dtype, device=device, operations=operations)
|
||||
|
||||
@property
|
||||
def gaussians_per_point(self) -> int:
|
||||
return self.gs.rep_config['num_gaussians']
|
||||
|
||||
def decode(self, latent: torch.Tensor, num_gaussians: int, level: int = None, generator=None):
|
||||
# level defaults to the full octree depth, a lower level is cheaper (coarser) for live previews.
|
||||
# generator (a CPU torch.Generator) makes the octree sampling reproducible without touching global RNG.
|
||||
level = self._MAX_VOXEL_LEVEL if level is None else level
|
||||
num_decoder_tokens = max(1, num_gaussians // self.gaussians_per_point)
|
||||
points_pred = OctreeProbabilityFixedlenDecoder.sample(
|
||||
self.octree, latent, num_points=num_decoder_tokens, level=level, temperature=1.0, generator=generator,
|
||||
)
|
||||
pred = self.gs(x=points_pred, cond=latent)
|
||||
return build_gaussian_models(self.gs, points_pred, pred) # one GaussianModel per batch item
|
||||
@ -46,6 +46,7 @@ import comfy.ldm.wan.model_animate
|
||||
import comfy.ldm.wan.ar_model
|
||||
import comfy.ldm.wan.model_wandancer
|
||||
import comfy.ldm.hunyuan3d.model
|
||||
import comfy.ldm.triposplat.model
|
||||
import comfy.ldm.hidream.model
|
||||
import comfy.ldm.chroma.model
|
||||
import comfy.ldm.chroma_radiance.model
|
||||
@ -1806,6 +1807,24 @@ class Hunyuan3Dv2_1(BaseModel):
|
||||
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
|
||||
return out
|
||||
|
||||
class TripoSplat(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.triposplat.model.LatentSeqMMFlowModel)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
cross_attn = kwargs.get("cross_attn", None) # DINOv3 token sequence -> cross-attention context.
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
ref_latents = kwargs.get("reference_latents", None) # Flux2 VAE image latent -> additive second conditioning.
|
||||
if ref_latents is not None:
|
||||
out['ref_latents'] = comfy.conds.CONDList(list(ref_latents))
|
||||
latent_shapes = kwargs.get("latent_shapes", None) # {latent, camera} nested latent
|
||||
if latent_shapes is not None:
|
||||
out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes)
|
||||
return out
|
||||
|
||||
|
||||
class HiDream(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hidream.model.HiDreamImageTransformer2DModel)
|
||||
|
||||
@ -676,6 +676,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["guidance_cond_proj_dim"] = None#f"{key_prefix}t_embedder.cond_proj.weight" in state_dict_keys
|
||||
return dit_config
|
||||
|
||||
if '{}cam_out_layer.weight'.format(key_prefix) in state_dict_keys and '{}repo_layers.0.final_map.weight'.format(key_prefix) in state_dict_keys: # TripoSplat
|
||||
return {"image_model": "triposplat"}
|
||||
|
||||
if '{}t_embedder1.mlp.0.weight'.format(key_prefix) in state_dict_keys and '{}x_embedder.proj1.weight'.format(key_prefix) in state_dict_keys: # HiDream-O1
|
||||
return {"image_model": "hidream_o1"}
|
||||
|
||||
|
||||
11
comfy/sd.py
11
comfy/sd.py
@ -16,6 +16,7 @@ import comfy.ldm.cosmos.vae
|
||||
import comfy.ldm.wan.vae
|
||||
import comfy.ldm.wan.vae2_2
|
||||
import comfy.ldm.hunyuan3d.vae
|
||||
import comfy.ldm.triposplat.vae
|
||||
import comfy.ldm.ace.vae.music_dcae_pipeline
|
||||
import comfy.ldm.cogvideo.vae
|
||||
import comfy.ldm.hunyuan_video.vae
|
||||
@ -894,6 +895,16 @@ class VAE:
|
||||
#Force cast it for --disable-dynamic-vram users until there is a true core fix.
|
||||
if not comfy.memory_management.aimdo_enabled:
|
||||
self.disable_offload = True
|
||||
elif "gs.base_offset_scale" in sd and "octree.out_proj.weight" in sd: # TripoSplat octree gaussian decoder
|
||||
self.first_stage_model = comfy.ldm.triposplat.vae.OctreeGaussianDecoder()
|
||||
self.latent_channels = 16
|
||||
self.latent_dim = 1
|
||||
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
# The generic VAE.encode/decode path isn't used: VAEDecodeTripoSplat calls the gaussian
|
||||
# decoder directly (structured GaussianSplat objects, not a tensor and reserves VRAM itself from num_gaussians.
|
||||
def _no_generic_io(*args, **kwargs):
|
||||
raise RuntimeError("TripoSplat gaussian decoder: use the 'TripoSplat Decode' (VAEDecodeTripoSplat)")
|
||||
self.memory_used_encode = self.memory_used_decode = _no_generic_io
|
||||
else:
|
||||
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
|
||||
self.first_stage_model = None
|
||||
|
||||
@ -1538,6 +1538,30 @@ class Hunyuan3Dv2mini(Hunyuan3Dv2):
|
||||
|
||||
latent_format = latent_formats.Hunyuan3Dv2mini
|
||||
|
||||
class TripoSplat(supported_models_base.BASE):
|
||||
# Image -> 3D gaussian splat flow denoiser
|
||||
unet_config = {
|
||||
"image_model": "triposplat",
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 3.0,
|
||||
}
|
||||
|
||||
memory_usage_factor = 0.6
|
||||
|
||||
latent_format = latent_formats.TripoSplat
|
||||
|
||||
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.TripoSplat(self, device=device)
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return None
|
||||
|
||||
class HiDream(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "hidream",
|
||||
@ -2200,6 +2224,7 @@ models = [
|
||||
Hunyuan3Dv2mini,
|
||||
Hunyuan3Dv2,
|
||||
Hunyuan3Dv2_1,
|
||||
TripoSplat,
|
||||
HiDream,
|
||||
HiDreamO1,
|
||||
Chroma,
|
||||
|
||||
@ -968,7 +968,8 @@ class RenderSplat(IO.ComfyNode):
|
||||
bg = _hex_to_rgb(background)
|
||||
bg_imgs = None
|
||||
if bg_image is not None: # resize the plate(s) to the render size: (B,H,W,3)
|
||||
bi = comfy.utils.common_upscale(bg_image.movedim(-1, 1), width, height, "bicubic", "disabled")
|
||||
bi = bg_image[... , :3].movedim(-1, 1) # (B,3,H,W)
|
||||
bi = comfy.utils.common_upscale(bi, width, height, "bicubic", "disabled")
|
||||
bg_imgs = bi.movedim(1, -1).clamp(0, 1)
|
||||
n_frames = abs(int(frames)) or 1 # magnitude = frame count (0 -> single still)
|
||||
orbit_dir = -1.0 if frames < 0 else 1.0 # sign = orbit direction
|
||||
|
||||
269
comfy_extras/nodes_triposplat.py
Normal file
269
comfy_extras/nodes_triposplat.py
Normal file
@ -0,0 +1,269 @@
|
||||
# TripoSplat nodes: image -> 3D gaussian splat
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing_extensions import override
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.nested_tensor
|
||||
import comfy.patcher_extension
|
||||
import comfy.utils
|
||||
from comfy_api.latest import ComfyExtension, IO, Types
|
||||
|
||||
|
||||
_Q_TOKEN_LENGTH = 8192
|
||||
_LATENT_CHANNELS = 16
|
||||
_CAM_CHANNELS = 5
|
||||
_DINOV3_MEAN = [0.485, 0.456, 0.406]
|
||||
_DINOV3_STD = [0.229, 0.224, 0.225]
|
||||
_NUM_GAUSSIANS_MIN = 32768
|
||||
_NUM_GAUSSIANS_MAX = 1048576
|
||||
|
||||
|
||||
def _preprocess(image: torch.Tensor, mask: torch.Tensor, erode_radius: int, size: int) -> torch.Tensor:
|
||||
# Match original preprocessing:
|
||||
# resize min side to `size` -> erode alpha -> alpha bbox -> 1.2x square crop -> resize -> composite on black.
|
||||
rgb = image[..., :3].clamp(0, 1).movedim(-1, 0) # (3, H, W)
|
||||
alpha = mask.clamp(0, 1)[None] # (1, H, W)
|
||||
rgba = torch.cat([rgb, alpha], 0)[None] # (1, 4, H, W)
|
||||
|
||||
h, w = rgba.shape[-2:]
|
||||
s = size / min(w, h)
|
||||
rgba = comfy.utils.common_upscale(rgba, max(1, round(w * s)), max(1, round(h * s)), "lanczos", "disabled").clamp(0, 1)
|
||||
|
||||
a = rgba[:, 3:4]
|
||||
if erode_radius > 0:
|
||||
# min filter over a (2r+1) window == morphological erosion of the alpha matte.
|
||||
a = -F.max_pool2d(-a, 2 * erode_radius + 1, stride=1, padding=erode_radius)
|
||||
rgba = torch.cat([rgba[:, :3], a], 1)
|
||||
|
||||
ys, xs = torch.nonzero(a[0, 0] > 0, as_tuple=True)
|
||||
if xs.numel() == 0:
|
||||
raise ValueError("TripoSplatPreprocessImage: mask is empty (no foreground pixels).")
|
||||
x0, x1 = int(xs.min()), int(xs.max())
|
||||
y0, y1 = int(ys.min()), int(ys.max())
|
||||
cx, cy = (x0 + x1) / 2, (y0 + y1) / 2
|
||||
half = max(x1 - x0, y1 - y0) / 2 * 1.2
|
||||
left, upper, right, lower = int(cx - half), int(cy - half), int(cx + half), int(cy + half)
|
||||
|
||||
H, W = rgba.shape[-2:]
|
||||
crop = rgba.new_zeros((1, 4, lower - upper, right - left)) # out-of-bounds stays 0, matching PIL.crop
|
||||
sx0, sy0, sx1, sy1 = max(left, 0), max(upper, 0), min(right, W), min(lower, H)
|
||||
if sx1 > sx0 and sy1 > sy0:
|
||||
crop[:, :, sy0 - upper:sy1 - upper, sx0 - left:sx1 - left] = rgba[:, :, sy0:sy1, sx0:sx1]
|
||||
|
||||
crop = comfy.utils.common_upscale(crop, size, size, "lanczos", "disabled").clamp(0, 1)
|
||||
out = (crop[:, :3] * crop[:, 3:4])[0].movedim(0, -1) # composite over black == rgb * alpha
|
||||
return out.unsqueeze(0) # (1, 1024, 1024, 3)
|
||||
|
||||
|
||||
class TripoSplatPreprocessImage(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="TripoSplatPreprocessImage",
|
||||
display_name="TripoSplat Preprocess Image",
|
||||
category="3d/conditioning",
|
||||
description="Crop center each image to a square canvas on a black background and add padding.",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.Mask.Input("mask"),
|
||||
IO.Int.Input("erode_radius", default=1, min=0, max=16,
|
||||
tooltip="Erode the alpha matte by this pixel radius before cropping (avoids border bleed)."),
|
||||
IO.Int.Input("size", default=1024, min=256, max=4096, step=16,
|
||||
tooltip="Square image size. The model is trained at 1024; other sizes run but are off-distribution."),
|
||||
],
|
||||
outputs=[IO.Image.Output(display_name="image")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, image, mask, erode_radius, size) -> IO.NodeOutput:
|
||||
size = max(16, (int(size) // 16) * 16) # DINOv3 patch / Flux2 VAE stride is 16
|
||||
if mask.shape[0] != image.shape[0]:
|
||||
mask = comfy.utils.repeat_to_batch_size(mask, image.shape[0])
|
||||
if tuple(mask.shape[1:]) != tuple(image.shape[1:3]):
|
||||
mask = F.interpolate(mask[:, None].float(), size=tuple(image.shape[1:3]), mode="bilinear", align_corners=False)[:, 0]
|
||||
prepared = torch.cat([_preprocess(image[i], mask[i], erode_radius, size) for i in range(image.shape[0])], dim=0)
|
||||
return IO.NodeOutput(prepared)
|
||||
|
||||
|
||||
class TripoSplatConditioning(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="TripoSplatConditioning",
|
||||
display_name="TripoSplat Conditioning",
|
||||
category="3d/conditioning",
|
||||
description="Encode the image with DINOv3 and the Flux2 VAE into TripoSplat positive/negative "
|
||||
"conditioning, and create the fixed size noise target (latent + camera) for the KSampler",
|
||||
inputs=[
|
||||
IO.ClipVision.Input("clip_vision", tooltip="DINOv3 ViT-H/16+ image encoder"),
|
||||
IO.Vae.Input("vae", tooltip="Flux2 VAE"),
|
||||
IO.Image.Input("image"),
|
||||
],
|
||||
outputs=[
|
||||
IO.Conditioning.Output(display_name="positive"),
|
||||
IO.Conditioning.Output(display_name="negative"),
|
||||
IO.Latent.Output(display_name="latent", tooltip="The fixed size noise target (latent +camera)."),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip_vision, vae, image) -> IO.NodeOutput:
|
||||
# feature1: DINOv3 token sequence (cls + registers + patches), ImageNet-normalized, with a final non-affine layer norm on top
|
||||
comfy.model_management.load_model_gpu(clip_vision.patcher)
|
||||
device = clip_vision.load_device
|
||||
model_dtype = next(clip_vision.model.parameters()).dtype
|
||||
img = image.movedim(-1, 1).to(device) # (B,3,H,W) in [0,1]
|
||||
mean = torch.tensor(_DINOV3_MEAN, device=device).view(1, 3, 1, 1)
|
||||
std = torch.tensor(_DINOV3_STD, device=device).view(1, 3, 1, 1)
|
||||
img = (img - mean) / std
|
||||
seq = clip_vision.model(pixel_values=img.to(model_dtype))[0]
|
||||
feature1 = F.layer_norm(seq.float(), seq.shape[-1:]).to(comfy.model_management.intermediate_device())
|
||||
|
||||
# Second conditioning: the Flux2 VAE latent of the image, carried as a standard reference_latents entry
|
||||
ref = vae.encode(image).to(comfy.model_management.intermediate_device()) # (B, 128, H, W)
|
||||
b = ref.shape[0]
|
||||
|
||||
positive = [[feature1, {"reference_latents": [ref]}]]
|
||||
negative = [[torch.zeros_like(feature1), {"reference_latents": [torch.zeros_like(ref)]}]]
|
||||
|
||||
# Fixed noise target: the latent is a constant-shape (8192, 16) shape-code + a (1, 5) camera token
|
||||
dev = comfy.model_management.intermediate_device()
|
||||
latent_seq = torch.zeros([b, _Q_TOKEN_LENGTH, _LATENT_CHANNELS], device=dev)
|
||||
camera = torch.zeros([b, 1, _CAM_CHANNELS], device=dev)
|
||||
samples = comfy.nested_tensor.NestedTensor((latent_seq, camera))
|
||||
return IO.NodeOutput(positive, negative, {"samples": samples})
|
||||
|
||||
|
||||
class VAEDecodeTripoSplat(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="VAEDecodeTripoSplat",
|
||||
display_name="TripoSplat Decode",
|
||||
category="3d/latent",
|
||||
description="Decode the sampled TripoSplat latent into a 3D gaussian splat. "
|
||||
"Modify the number of gaussians to vary the density.",
|
||||
inputs=[
|
||||
IO.Latent.Input("samples"),
|
||||
IO.Vae.Input("vae", tooltip="TripoSplat VAE decoder"),
|
||||
IO.Int.Input("num_gaussians", default=262144, min=_NUM_GAUSSIANS_MIN, max=_NUM_GAUSSIANS_MAX, step=32,
|
||||
tooltip="Number of gaussians to produce (rounded to a multiple of 32). "
|
||||
"262144 matches the octree's point density; higher oversamples the same points "
|
||||
"(denser, but no new detail) and costs proportionally more VRAM/time."),
|
||||
IO.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff,
|
||||
tooltip="Seeds the octree point sampler (global RNG) for deterministic decodes."),
|
||||
],
|
||||
outputs=[IO.Splat.Output(display_name="splat")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, samples, vae, num_gaussians, seed) -> IO.NodeOutput:
|
||||
s = samples["samples"]
|
||||
latent = s.unbind()[0] if getattr(s, "is_nested", False) else s # take the latent stream, drop camera
|
||||
|
||||
decoder = vae.first_stage_model
|
||||
gpp = decoder.gaussians_per_point
|
||||
n = max(_NUM_GAUSSIANS_MIN, min(_NUM_GAUSSIANS_MAX, int(num_gaussians)))
|
||||
if n % gpp != 0:
|
||||
n = round(n / gpp) * gpp
|
||||
|
||||
dtype_size = comfy.model_management.dtype_size(vae.vae_dtype)
|
||||
hidden = decoder.gs.model_channels
|
||||
cond_tokens = latent.shape[1]
|
||||
memory_required = (cond_tokens * 4 + (n // gpp) * 10) * hidden * dtype_size
|
||||
comfy.model_management.load_models_gpu([vae.patcher], memory_required=memory_required)
|
||||
latent = latent.to(device=vae.device, dtype=vae.vae_dtype)
|
||||
generator = torch.Generator(device="cpu").manual_seed(seed)
|
||||
parts = [g.render_tensors() for g in decoder.decode(latent, num_gaussians=n, generator=generator)]
|
||||
positions, scales, rotations, opacities, sh = (torch.stack(t) for t in zip(*parts))
|
||||
return IO.NodeOutput(Types.SPLAT(positions, scales, rotations, opacities, sh))
|
||||
|
||||
|
||||
class TripoSplatSamplingPreview(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="TripoSplatSamplingPreview",
|
||||
display_name="TripoSplat Sampling Preview",
|
||||
category="3d/latent",
|
||||
description="Patch the TripoSplat model for the standard Ksampler node to show a live decoded "
|
||||
"gaussian splat preview at each step.",
|
||||
inputs=[
|
||||
IO.Model.Input("model"),
|
||||
IO.Vae.Input("vae", tooltip="TripoSplat VAE decoder"),
|
||||
IO.Int.Input("octree_level", default=5, min=2, max=8, advanced=True,
|
||||
tooltip="Octree depth for the preview decode (lower = cheaper/coarser)."),
|
||||
IO.Int.Input("num_gaussians", default=16384, min=1024, max=262144, step=32,
|
||||
tooltip="Number of gaussians to produce for the preview (rounded to a multiple of 32)."),
|
||||
IO.Float.Input("yaw", default=90.0, min=-360.0, max=360.0, step=1.0, tooltip="Preview camera yaw in degrees.", advanced=True,),
|
||||
IO.Float.Input("pitch", default=15.0, min=-89.0, max=89.0, step=1.0, tooltip="Preview camera pitch in degrees.", advanced=True,),
|
||||
IO.Int.Input("point_size", default=3, min=1, max=16,
|
||||
tooltip="Maximum splat radius in pixels. Each gaussian is sized from its scale and capped here; "
|
||||
"lower = finer/pointier, higher = chunkier."),
|
||||
],
|
||||
outputs=[IO.Model.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, vae, octree_level, num_gaussians, yaw, pitch, point_size) -> IO.NodeOutput:
|
||||
from comfy.ldm.triposplat.preview import decode_x0_to_image
|
||||
cfg = {"gaussians": num_gaussians, "level": octree_level, "yaw": yaw, "pitch": pitch,
|
||||
"point_size": point_size}
|
||||
|
||||
fsm = vae.first_stage_model
|
||||
cond_tokens = model.model.diffusion_model.q_token_length
|
||||
memory_required = (cond_tokens * 4 + (num_gaussians // fsm.gaussians_per_point) * 10) * fsm.gs.model_channels * comfy.model_management.dtype_size(vae.vae_dtype)
|
||||
|
||||
# Live preview via WrappersMP.OUTER_SAMPLE + ProgressBar
|
||||
# The wrapper augments the sampler's own callback to decode x0 -> gaussian splat -> preview image each step
|
||||
def outer_sample_wrapper(executor, *args, **kwargs):
|
||||
args = list(args)
|
||||
cb_idx = 5 # outer_sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
||||
orig_cb = args[cb_idx] if len(args) > cb_idx else kwargs.get("callback")
|
||||
state = {"ok": True, "pbar": None, "loaded": False}
|
||||
|
||||
def callback(step, x0, x, total_steps):
|
||||
if orig_cb is not None:
|
||||
orig_cb(step, x0, x, total_steps)
|
||||
if not state["ok"]:
|
||||
return
|
||||
try:
|
||||
if not state["loaded"]:
|
||||
comfy.model_management.load_models_gpu([vae.patcher], memory_required=memory_required)
|
||||
state["loaded"] = True
|
||||
img = decode_x0_to_image(vae, x0, cfg)
|
||||
if state["pbar"] is None:
|
||||
state["pbar"] = comfy.utils.ProgressBar(total_steps)
|
||||
state["pbar"].update_absolute(step + 1, total_steps, ("JPEG", img, 512))
|
||||
except Exception as e:
|
||||
logging.warning("TripoSplatSamplingPreview: preview failed, disabling ({})".format(e))
|
||||
state["ok"] = False
|
||||
|
||||
if len(args) > cb_idx:
|
||||
args[cb_idx] = callback
|
||||
else:
|
||||
kwargs["callback"] = callback
|
||||
return executor(*args, **kwargs)
|
||||
|
||||
m = model.clone()
|
||||
m.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "triposplat_sampling_preview", outer_sample_wrapper)
|
||||
return IO.NodeOutput(m)
|
||||
|
||||
|
||||
class TripoSplatExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
TripoSplatPreprocessImage,
|
||||
TripoSplatConditioning,
|
||||
VAEDecodeTripoSplat,
|
||||
TripoSplatSamplingPreview,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> TripoSplatExtension:
|
||||
return TripoSplatExtension()
|
||||
Loading…
Reference in New Issue
Block a user