ComfyUI/comfy/ldm/triposplat/vae.py
2026-06-01 07:01:50 -07:00

383 lines
20 KiB
Python

# TripoSplat gaussian decoder ("VAE"): an octree probability decoder picks point coords, then an
# elastic-gaussian decoder predicts per-point gaussian params. OctreeGaussianDecoder.decode() returns
# one gaussian model per batch item. The octree sampler draws its random numbers with torch.rand on the CPU:
# pass a CPU torch.Generator to decode() for repeatable decodes; without one the global torch RNG is used, so seed that instead.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.model_management
import comfy.ops
from .gaussian import build_gaussian_models
from .model import MultiHeadRMSNorm, MLP, PcdAbsolutePositionEmbedder, attention
# Quasi-random sampling utilities (pure functions, dtype/device-agnostic)
# Radical-inverse bases for the Halton sequence: halton_sequence() uses PRIMES[i] as the base of
# dimension i, so at most len(PRIMES) (= 16) Halton dimensions are available.
PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53]
def radical_inverse(base, n):
    """Mirror the base-``base`` digits of ``n`` around the radix point.

    The least-significant digit of ``n`` becomes the first fractional digit, the next one the
    second, and so on. Returns the accumulated value; for ``n <= 0`` the loop never runs and
    the integer ``0`` is returned.
    """
    step = 1.0 / base
    place = step  # weight of the fractional digit currently being written
    total = 0
    remaining = n
    while remaining > 0:
        remaining, digit = divmod(remaining, base)
        total += digit * place
        place *= step
    return total
def halton_sequence(dim, n):
    """Return the ``n``-th point of the ``dim``-dimensional Halton sequence as a list.

    Coordinate ``i`` is the radical inverse of ``n`` in base ``PRIMES[i]``; asking for more
    dimensions than there are entries in ``PRIMES`` raises ``IndexError``.
    """
    point = []
    for axis in range(dim):
        point.append(radical_inverse(PRIMES[axis], n))
    return point
def hammersley_sequence(dim, n, num_samples):
    """Return the ``n``-th point of a ``dim``-dimensional Hammersley set of ``num_samples`` points.

    The first coordinate is the regular grid value ``n / num_samples``; the remaining
    ``dim - 1`` coordinates come from the Halton sequence.
    """
    leading = n / num_samples
    return [leading, *halton_sequence(dim - 1, n)]
def sample_probs(probs, counts, generator=None):
    """Systematic resampling: distribute ``counts[r]`` draws across the ``P`` bins of row ``r``.

    Args:
        probs: non-negative weights whose last dimension holds the ``P`` bins; reshaped to
            ``(counts.numel(), P)``. Negative entries are clamped to 0. Rows do not need to be
            normalised: every row is divided by its own sum, and a row that sums to 0 is treated
            as uniform.
        counts: integer tensor with the number of draws per row (any shape).
        generator: optional ``torch.Generator`` for the per-row random offset. The offset is
            drawn with ``torch.rand`` on its default device and then moved to ``probs.device``.

    Returns:
        ``torch.long`` tensor of shape ``(*counts.shape, P)`` on ``probs.device``; each row
        sums to the matching entry of ``counts`` (all zeros when no row requests a draw).
    """
    batch_shape = counts.shape
    R = counts.numel()
    P = probs.size(-1)
    device = probs.device
    probs = probs.reshape(R, P).to(torch.float32).clamp_min(0)
    counts = counts.reshape(R).to(device=device, dtype=torch.long)

    # Nothing to draw (no rows, or no row asks for a sample): skip the sampling machinery, which
    # cannot take the max of an empty tensor or build a grid of non-positive length.
    Nmax = int(counts.max()) if R > 0 else 0
    if Nmax <= 0:
        return counts.new_zeros(*batch_shape, P)

    # Normalise each row by its true sum. Dividing by max(sum, 1) instead would leave rows with
    # 0 < sum < 1 unnormalised, so their CDF would stop short of 1 and the missing mass would be
    # dumped into the last bin by the final clamp below.
    row_sums = probs.sum(1, keepdim=True)
    empty = row_sums == 0
    safe_sums = torch.where(empty, torch.ones_like(row_sums), row_sums)  # avoid 0/0 on empty rows
    probs = torch.where(empty, probs.new_tensor(1.0 / P), probs / safe_sums)
    cdf = probs.cumsum(dim=1).clamp(max=1.0 - 1e-12)

    cnt = counts.clamp_min(1).float().unsqueeze(1)  # (R, 1)
    grid = torch.arange(Nmax, device=device, dtype=torch.float32).unsqueeze(0)  # (1, Nmax)
    # One random offset per row, then counts[r] evenly spaced points in [0, 1): (offset + j) / counts[r].
    u = (torch.rand(R, 1, generator=generator).to(device) + grid) / cnt  # (R, Nmax)
    # searchsorted can return P when u lies above the last CDF entry (rounding); clamp_max keeps
    # the index inside the row.
    idx = torch.searchsorted(cdf, u.clamp(max=1.0 - 1e-12)).clamp_max(P - 1)
    weight = (grid < counts.unsqueeze(1)).to(cdf.dtype)  # mask out j >= counts[r]
    out = torch.zeros(R, P, dtype=torch.float32, device=device)
    out.scatter_add_(1, idx, weight)
    return out.to(torch.long).view(*batch_shape, P)
class MultiHeadAttention(nn.Module):
    """Multi-head attention layer usable as self- or cross-attention.

    With ``type="self"`` a single fused ``to_qkv`` projection is built and q/k/v all come from
    ``x``. Any other ``type`` builds separate ``to_q`` / ``to_kv`` projections, and keys/values
    are computed from ``context`` (``ctx_channels`` wide, defaulting to ``channels``).
    ``qk_rms_norm`` adds a ``MultiHeadRMSNorm`` on q and k before attention.
    """

    def __init__(self, channels, num_heads, ctx_channels=None, type="self", qkv_bias=True, qk_rms_norm=False,
                 dtype=None, device=None, operations=None):
        super().__init__()
        assert channels % num_heads == 0
        self.channels = channels
        self.head_dim = channels // num_heads
        self.ctx_channels = channels if ctx_channels is None else ctx_channels
        self.num_heads = num_heads
        self._type = type
        self.qk_rms_norm = qk_rms_norm

        factory = dict(dtype=dtype, device=device)
        if self._type == "self":
            self.to_qkv = operations.Linear(channels, channels * 3, bias=qkv_bias, **factory)
        else:
            self.to_q = operations.Linear(channels, channels, bias=qkv_bias, **factory)
            self.to_kv = operations.Linear(self.ctx_channels, channels * 2, bias=qkv_bias, **factory)
        if self.qk_rms_norm:
            self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, **factory)
            self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, **factory)
        self.to_out = operations.Linear(channels, channels, **factory)

    def forward(self, x, context=None):
        """Attend from ``x`` (batch, length, channels); ``context`` is required in cross mode."""
        batch, q_len, _ = x.shape
        heads = self.num_heads
        if self._type == "self":
            # Fused projection, split into q / k / v along the size-3 axis.
            fused = self.to_qkv(x).reshape(batch, q_len, 3, heads, -1)
            q, k, v = fused.unbind(dim=2)
        else:
            kv_len = context.shape[1]
            q = self.to_q(x).reshape(batch, q_len, heads, -1)
            fused_kv = self.to_kv(context).reshape(batch, kv_len, 2, heads, -1)
            k, v = fused_kv.unbind(dim=2)
        if self.qk_rms_norm:
            q = self.q_rms_norm(q)
            k = self.k_rms_norm(k)
        attended = attention(q, k, v)
        return self.to_out(attended.reshape(batch, q_len, -1))
# Octree probability decoder
class LevelEmbedder(nn.Module):
    """Embed a per-sample scalar with sinusoidal features followed by a two-layer SiLU MLP."""

    def __init__(self, hidden_size, frequency_embedding_size=256, max_period=1024,
                 dtype=None, device=None, operations=None):
        super().__init__()
        factory = dict(dtype=dtype, device=device)
        self.mlp = nn.Sequential(
            operations.Linear(frequency_embedding_size, hidden_size, bias=True, **factory),
            nn.SiLU(),
            operations.Linear(hidden_size, hidden_size, bias=True, **factory),
        )
        self.frequency_embedding_size = frequency_embedding_size
        self.max_period = max_period

    @staticmethod
    def level_embedding(t, dim, max_period=1024):
        """Return a ``(len(t), dim)`` float32 embedding of the 1-D tensor ``t``.

        The first ``dim // 2`` columns are cosines and the next ``dim // 2`` are sines over
        geometrically spaced frequencies; an odd ``dim`` gets one trailing zero column.
        """
        half = dim // 2
        steps = torch.arange(start=0, end=half, dtype=torch.float32)
        freqs = torch.exp(-np.log(max_period) * steps / half).to(device=t.device)
        phase = t[:, None].float() * freqs[None] * 2 * torch.pi
        embedding = torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
        if dim % 2:
            pad = embedding.new_zeros(embedding.shape[0], 1)
            embedding = torch.cat([embedding, pad], dim=-1)
        return embedding

    def forward(self, t):
        features = self.level_embedding(t, self.frequency_embedding_size, self.max_period)
        # Match the MLP's weight dtype before projecting.
        first_layer = self.mlp[0]
        return self.mlp(features.to(first_layer.weight.dtype))
class ModulatedTransformerCrossOnlyBlock(nn.Module):
    """Transformer block made of a cross-attention branch and an MLP branch, each with
    shift / scale / gate modulation.

    ``mod`` is split into six chunks along dim 1. With ``share_mod=True`` the caller passes the
    already-projected modulation; otherwise the block projects ``mod`` through its own
    ``adaLN_modulation`` first.
    """

    def __init__(self, channels, ctx_channels, num_heads, mlp_ratio=4.0, share_mod=False,
                 qk_rms_norm_cross=True, qkv_bias=True, dtype=None, device=None, operations=None):
        super().__init__()
        factory = dict(dtype=dtype, device=device)
        self.share_mod = share_mod
        self.norm1 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, **factory)
        self.norm2 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, **factory)
        self.cross_attn = MultiHeadAttention(
            channels, ctx_channels=ctx_channels, num_heads=num_heads, type="cross",
            qkv_bias=qkv_bias, qk_rms_norm=qk_rms_norm_cross, operations=operations, **factory)
        self.mlp = MLP(channels, int(channels * mlp_ratio), channels, operations=operations, **factory)
        if not share_mod:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                operations.Linear(channels, 6 * channels, bias=True, **factory),
            )

    def forward(self, x, mod, context):
        params = mod if self.share_mod else self.adaLN_modulation(mod)
        # Add a length-1 axis at dim 1 so each chunk broadcasts against x along that dimension.
        shift_attn, scale_attn, gate_attn, shift_ff, scale_ff, gate_ff = (
            chunk.unsqueeze(1) for chunk in params.chunk(6, dim=1)
        )
        # addcmul(a, b, c) == a + b * c
        attn_in = torch.addcmul(shift_attn, self.norm1(x), 1 + scale_attn)
        x = torch.addcmul(x, self.cross_attn(attn_in, context), gate_attn)
        ff_in = torch.addcmul(shift_ff, self.norm2(x), 1 + scale_ff)
        x = torch.addcmul(x, self.mlp(ff_in), gate_ff)
        return x
class OctreeProbabilityFixedlenDecoder(nn.Module):
    """Cross-attention transformer over octree node coords -> per-node 8-way child occupancy logits.

    ``forward`` scores the eight children of every node; the static ``sample`` helper descends
    the octree level by level with such a model to place a fixed number of points.
    """

    def __init__(self, model_channels=1024, cond_channels=16, num_blocks=4, num_heads=16,
                 num_head_channels=64, mlp_ratio=4.0, share_mod=True,
                 qk_rms_norm_cross=True, dtype=None, device=None, operations=None):
        super().__init__()
        self.model_channels = model_channels
        self.cond_channels = cond_channels
        self.num_blocks = num_blocks
        # A falsy num_heads falls back to the head count derived from num_head_channels.
        self.num_heads = num_heads or model_channels // num_head_channels
        self.mlp_ratio = mlp_ratio
        self.share_mod = share_mod
        self.qk_rms_norm_cross = qk_rms_norm_cross

        factory = dict(dtype=dtype, device=device)
        self.input_layer = operations.Linear(model_channels, model_channels, **factory)
        self.l_embedder = LevelEmbedder(model_channels, operations=operations, **factory)
        if share_mod:
            # One modulation projection shared by every block.
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                operations.Linear(model_channels, 6 * model_channels, bias=True, **factory),
            )
        if cond_channels is not None:
            stack = []
            for _ in range(num_blocks):
                stack.append(ModulatedTransformerCrossOnlyBlock(
                    model_channels, ctx_channels=cond_channels, num_heads=self.num_heads,
                    mlp_ratio=self.mlp_ratio, qk_rms_norm_cross=self.qk_rms_norm_cross,
                    share_mod=self.share_mod, operations=operations, **factory))
            self.blocks = nn.ModuleList(stack)
        self.out_proj = operations.Linear(model_channels, 8, **factory)
        self.in_proj = operations.Linear(3, model_channels, **factory)
        self.pos_embedder = PcdAbsolutePositionEmbedder(
            channels=model_channels, in_channels=3, max_res=10, schedule="log2")

    def forward(self, x, l, cond):
        """Score the 8 children of each node.

        ``x`` is ``(batch, nodes, 3)`` node coordinates, ``l`` the per-sample value fed to the
        level embedder, ``cond`` the cross-attention context. Returns a dict with ``"logits"``
        and their softmax under ``"probs"`` (last dimension of size 8).
        """
        param_dtype = next(self.parameters()).dtype
        batch, num_nodes, _ = x.shape
        tokens = self.in_proj(x.to(param_dtype))
        tokens = tokens + self.pos_embedder(x.reshape(-1, 3)).reshape(batch, num_nodes, -1).to(param_dtype)
        tokens = self.input_layer(tokens)
        mod = self.l_embedder(l)
        if self.share_mod:
            mod = self.adaLN_modulation(mod)
        cond = cond.to(param_dtype)
        for block in self.blocks:
            tokens = block(tokens, mod, cond)
        # Final non-affine layer norm is computed in float32, then cast back to the parameter dtype.
        tokens = F.layer_norm(tokens.float(), tokens.shape[-1:]).to(param_dtype)
        logits = self.out_proj(tokens)
        return {"logits": logits, "probs": torch.softmax(logits, dim=-1)}

    @staticmethod
    def sample(model, cond, num_points, level, temperature=1.0, generator=None):
        """Place ``num_points`` points per batch item by descending ``level`` octree levels.

        Starting from a single root node holding all points, each level asks ``model`` for
        child logits, splits every node's point count across its 8 children with
        ``sample_probs`` and keeps only the children that received at least one point.
        Returns ``{"points": (B, num_points, 3) coords in [0, 1), "log_probs": (B, num_points)}``,
        where ``log_probs`` sums the log-probabilities of the child choices along each path.
        """
        batch = cond.shape[0]
        device = cond.device
        # Child offsets with the first coordinate varying fastest.
        child_offset = torch.tensor(
            [[i, j, k] for k in (0, 1) for j in (0, 1) for i in (0, 1)],
            dtype=torch.long, device=device)
        node_coords = torch.zeros(batch, 1, 3, dtype=torch.long, device=device)
        node_counts = torch.full((batch, 1), num_points, dtype=torch.long, device=device)
        node_log_probs = torch.zeros(batch, 1, dtype=torch.float32, device=device)
        batch_index = torch.arange(batch, device=device).unsqueeze(1)

        for lv in range(1, level + 1):
            parent_res = 1 << (lv - 1)
            child_res = 1 << lv
            # Parent cell centres, normalised by the parent resolution.
            parent_centers = (node_coords.to(torch.float32) + 0.5) / parent_res
            level_input = torch.full((batch,), child_res, dtype=torch.long, device=device)
            logits = model(parent_centers, level_input, cond)["logits"] / temperature
            child_probs = torch.softmax(logits, dim=-1)
            child_log_probs = torch.log_softmax(logits, dim=-1)
            child_counts = sample_probs(child_probs, node_counts, generator=generator).flatten(1, 2)
            child_log_probs = child_log_probs.flatten(1, 2)
            parent_log_probs = node_log_probs.repeat_interleave(8, dim=1)
            child_coords = (node_coords[:, :, None, :] * 2 + child_offset[None, None, :, :]).flatten(1, 2)

            # Compact the occupied children to the front of each row; rows with fewer occupied
            # children than the widest row stay zero-padded (count 0).
            occupied = child_counts > 0
            width = occupied.sum(dim=1).max().item()
            slot = (occupied.cumsum(dim=1) - 1)[occupied]
            row = batch_index.expand_as(occupied)[occupied]

            packed_coords = torch.zeros(batch, width, 3, dtype=child_coords.dtype, device=device)
            packed_coords[row, slot] = child_coords[occupied]
            packed_counts = torch.zeros(batch, width, dtype=child_counts.dtype, device=device)
            packed_counts[row, slot] = child_counts[occupied]
            packed_log_probs = torch.zeros(batch, width, dtype=node_log_probs.dtype, device=device)
            packed_log_probs[row, slot] = (parent_log_probs + child_log_probs)[occupied]

            node_coords, node_counts, node_log_probs = packed_coords, packed_counts, packed_log_probs

        res = 1 << level
        # Expand every surviving node into one entry per point it holds.
        flat_counts = node_counts.flatten(0, 1)
        log_probs = torch.repeat_interleave(node_log_probs.flatten(0, 1), flat_counts, dim=0).reshape(batch, num_points)
        coords_int = torch.repeat_interleave(node_coords.flatten(0, 1), flat_counts, dim=0).reshape(batch, num_points, -1)
        # Uniform jitter inside each leaf cell, drawn from `generator` before moving to `device`.
        jitter = torch.rand(coords_int.shape, dtype=torch.float32, generator=generator).to(device)
        points = (coords_int.to(torch.float32) + jitter) / res
        return {"points": points, "log_probs": log_probs}
# Elastic gaussian decoder
class TransformerCrossBlock(nn.Module):
    """Pre-norm transformer block: self-attention, cross-attention on ``context``, then an MLP,
    each wrapped in a residual connection. Only ``norm2`` (ahead of cross-attention) has
    learnable affine parameters.
    """

    def __init__(self, channels, ctx_channels, num_heads, mlp_ratio=4.0,
                 qk_rms_norm=True, qk_rms_norm_cross=True, qkv_bias=True,
                 dtype=None, device=None, operations=None):
        super().__init__()
        factory = dict(dtype=dtype, device=device)
        self.norm1 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, **factory)
        self.norm2 = operations.LayerNorm(channels, elementwise_affine=True, eps=1e-6, **factory)
        self.norm3 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, **factory)
        self.self_attn = MultiHeadAttention(
            channels, num_heads=num_heads, type="self", qkv_bias=qkv_bias,
            qk_rms_norm=qk_rms_norm, operations=operations, **factory)
        self.cross_attn = MultiHeadAttention(
            channels, ctx_channels=ctx_channels, num_heads=num_heads, type="cross",
            qkv_bias=qkv_bias, qk_rms_norm=qk_rms_norm_cross, operations=operations, **factory)
        self.mlp = MLP(channels, int(channels * mlp_ratio), channels, operations=operations, **factory)

    def forward(self, x, context):
        self_out = self.self_attn(self.norm1(x))
        x = x + self_out
        cross_out = self.cross_attn(self.norm2(x), context)
        x = x + cross_out
        ff_out = self.mlp(self.norm3(x))
        return x + ff_out
class ElasticGaussianFixedlenDecoder(nn.Module):
    """Cross-attention transformer over sampled octree points -> per-point gaussian params.

    Each input point yields one feature vector of ``out_channels`` values; ``layout`` records
    which slice of that vector belongs to which gaussian attribute and how to reshape it.
    """

    def __init__(self, in_channels=3, model_channels=1024, cond_channels=16, num_blocks=16, num_heads=16,
                 num_head_channels=64, mlp_ratio=4.0, *, representation_config=None,
                 qk_rms_norm=True, qk_rms_norm_cross=True, dtype=None, device=None, operations=None):
        super().__init__()
        # A missing (or empty) representation_config falls back to the built-in defaults.
        self.rep_config = representation_config or dict(
            lr=dict(_xyz=1.0, _features_dc=1.0, _opacity=1.0, _scaling=1.0, _rotation=0.1),
            perturb_offset=True, perturbe_size=1.5, offset_scale=0.05, num_gaussians=32,
            filter_kernel_size_3d=0.0009, scaling_bias=0.004, opacity_bias=0.1,
            scaling_activation="softplus",
        )
        self.out_channels = self._calc_layout()
        self.model_channels = model_channels
        self.cond_channels = cond_channels
        self.num_blocks = num_blocks
        self.num_heads = num_heads or model_channels // num_head_channels
        self.mlp_ratio = mlp_ratio

        factory = dict(dtype=dtype, device=device)
        self.input_layer = operations.Linear(model_channels, model_channels, **factory)
        if cond_channels is not None:
            stack = []
            for _ in range(num_blocks):
                stack.append(TransformerCrossBlock(
                    model_channels, ctx_channels=cond_channels,
                    num_heads=self.num_heads, mlp_ratio=self.mlp_ratio,
                    qk_rms_norm=qk_rms_norm, qk_rms_norm_cross=qk_rms_norm_cross,
                    operations=operations, **factory))
            self.blocks = nn.ModuleList(stack)
        self.in_proj = operations.Linear(in_channels, model_channels, **factory)
        self.pos_embedder = PcdAbsolutePositionEmbedder(
            channels=model_channels, in_channels=3, max_res=10, schedule="log2")
        self.out_proj = operations.Linear(model_channels, self.out_channels, **factory)
        self._build_perturbation()

    def _calc_layout(self):
        """Build ``self.layout`` (shape / size / channel range per attribute) and return the
        total number of output channels."""
        ng = self.rep_config['num_gaussians']
        fields = (
            ('_xyz', (ng, 3), ng * 3),
            ('_features_dc', (ng, 1, 3), ng * 3),
            ('_scaling', (ng, 3), ng * 3),
            ('_rotation', (ng, 4), ng * 4),
            ('_opacity', (ng, 1), ng),
            ('_offset_scale', (ng, 1), ng),
        )
        self.layout = {}
        cursor = 0
        for name, shape, size in fields:
            # Attributes are packed back to back in declaration order.
            self.layout[name] = {'shape': shape, 'size': size, 'range': (cursor, cursor + size)}
            cursor += size
        return cursor

    def _build_perturbation(self):
        """Register the fixed per-gaussian offset perturbation and the offset-scale bias buffers."""
        ng = self.rep_config['num_gaussians']
        samples = torch.tensor([hammersley_sequence(3, i, ng) for i in range(ng)]).float()
        # Map [0, 1) samples to [-1, 1), shrink by perturbe_size and take atanh, so that the
        # tanh(...) * 0.5 * perturbe_size applied in _get_offset returns them to [-0.5, 0.5).
        centered = (samples * 2 - 1) / self.rep_config['perturbe_size']
        self.register_buffer('points_offset_perturbation', torch.atanh(centered))
        base = torch.tensor(self.rep_config['offset_scale'])
        # log(exp(v) - 1) is the inverse of softplus, so softplus(bias) gives back offset_scale.
        self.register_buffer('base_offset_scale', torch.log(torch.exp(base) - 1.0))

    def _get_offset(self, h):
        """Turn decoder features ``h`` (batch, points, out_channels) into per-gaussian xyz offsets.

        NOTE(review): not called inside this file; presumably used by build_gaussian_models — confirm.
        """
        batch = h.shape[0]

        lo, hi = self.layout['_offset_scale']['range']
        raw_scale = h[:, :, lo:hi].reshape(batch, -1, *self.layout['_offset_scale']['shape'])
        scale_bias = comfy.model_management.cast_to(self.base_offset_scale, h.dtype, h.device)
        scale = F.softplus(raw_scale + scale_bias)

        lo, hi = self.layout['_xyz']['range']
        offset = h[:, :, lo:hi].reshape(batch, -1, *self.layout['_xyz']['shape'])
        offset = offset * self.rep_config['lr']['_xyz']
        if self.rep_config['perturb_offset']:
            perturbation = comfy.model_management.cast_to(
                self.points_offset_perturbation, offset.dtype, offset.device)
            offset = offset + perturbation
        offset = torch.tanh(offset) * 0.5 * self.rep_config['perturbe_size']
        return offset * scale

    def forward(self, x=None, cond=None):
        """Decode ``x["points"]`` (batch, points, 3) against ``cond``; returns ``{"features": ...}``."""
        points = x["points"]
        param_dtype = next(self.parameters()).dtype
        batch, num_points, _ = points.shape
        tokens = self.in_proj(points.to(param_dtype))
        tokens = tokens + self.pos_embedder(points.reshape(-1, 3)).reshape(batch, num_points, -1).to(param_dtype)
        tokens = self.input_layer(tokens)
        cond = cond.to(param_dtype)
        for block in self.blocks:
            tokens = block(tokens, cond)
        # Final non-affine layer norm in float32, cast back to the dtype the tokens had before it.
        token_dtype = tokens.dtype
        tokens = F.layer_norm(tokens.float(), tokens.shape[-1:]).to(token_dtype)
        return {"features": self.out_proj(tokens)}
# Combined octree gaussian decoder (comfy first-stage model)
class OctreeGaussianDecoder(nn.Module):
    """First-stage decoder: the octree decoder samples point positions, then the gaussian
    decoder predicts the gaussian parameters attached to every point."""

    _MAX_VOXEL_LEVEL = 8  # octree depth used when decode() is called without a level

    def __init__(self, dtype=None, device=None, operations=None):
        super().__init__()
        ops = comfy.ops.disable_weight_init if operations is None else operations
        self.octree = OctreeProbabilityFixedlenDecoder(dtype=dtype, device=device, operations=ops)
        self.gs = ElasticGaussianFixedlenDecoder(dtype=dtype, device=device, operations=ops)

    @property
    def gaussians_per_point(self) -> int:
        """Number of gaussians the gaussian decoder emits for each sampled point."""
        return self.gs.rep_config['num_gaussians']

    def decode(self, latent: torch.Tensor, num_gaussians: int, level: int = None, generator=None):
        """Decode ``latent`` into gaussian models.

        ``level`` defaults to the full octree depth; a lower level is cheaper (coarser) for live
        previews. ``generator`` (a CPU torch.Generator) makes the octree sampling reproducible
        without touching the global RNG. At least one point is always sampled, so the result
        holds ``max(1, num_gaussians // gaussians_per_point)`` points per batch item.
        """
        if level is None:
            level = self._MAX_VOXEL_LEVEL
        num_decoder_tokens = max(1, num_gaussians // self.gaussians_per_point)
        sampled = OctreeProbabilityFixedlenDecoder.sample(
            self.octree, latent, num_points=num_decoder_tokens, level=level,
            temperature=1.0, generator=generator,
        )
        features = self.gs(x=sampled, cond=latent)
        # one GaussianModel per batch item
        return build_gaussian_models(self.gs, sampled, features)