mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-02 04:17:33 +08:00
383 lines
20 KiB
Python
383 lines
20 KiB
Python
# TripoSplat gaussian decoder ("VAE"): an octree probability decoder picks point coords, then an
# elastic-gaussian decoder predicts per-point gaussian params. OctreeGaussianDecoder.decode() returns
# a Gaussian. The octree sampler draws from the optional CPU torch.Generator passed to decode(); when none is
# given it falls back to the global torch RNG (like upstream), so seed that for repeatable decodes.
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
import comfy.model_management
|
|
import comfy.ops
|
|
from .gaussian import build_gaussian_models
|
|
from .model import MultiHeadRMSNorm, MLP, PcdAbsolutePositionEmbedder, attention
|
|
|
|
|
|
# Quasi-random sampling utilities (pure functions, dtype/device-agnostic)
|
|
|
|
# Prime bases for the Halton coordinates: halton_sequence() uses PRIMES[i] for dimension i (16 dimensions at most).
PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53]
|
|
|
|
|
|
def radical_inverse(base, n):
    """Return the radical inverse of the non-negative integer ``n`` in ``base``.

    The base-``base`` digits of ``n`` are mirrored around the radix point, so the least
    significant digit gets weight ``1 / base``, the next one ``1 / base**2``, and so on.
    ``n == 0`` yields ``0``.
    """
    result = 0
    step = 1.0 / base
    weight = step
    remaining = n
    while remaining > 0:
        # Peel off the lowest digit and give it the current fractional weight.
        remaining, digit = divmod(remaining, base)
        result += digit * weight
        weight *= step
    return result
|
|
|
|
|
|
def halton_sequence(dim, n):
    """Return the ``n``-th point of the ``dim``-dimensional Halton sequence as a list.

    Coordinate ``i`` is the radical inverse of ``n`` in base ``PRIMES[i]``, so ``dim`` may not
    exceed ``len(PRIMES)`` (an IndexError is raised otherwise).
    """
    point = []
    for axis in range(dim):
        point.append(radical_inverse(PRIMES[axis], n))
    return point
|
|
|
|
|
|
def hammersley_sequence(dim, n, num_samples):
    """Return the ``n``-th of ``num_samples`` points of a ``dim``-dimensional Hammersley set as a list.

    The first coordinate is the regular grid value ``n / num_samples``; the remaining
    ``dim - 1`` coordinates come from halton_sequence().
    """
    point = [n / num_samples]
    point.extend(halton_sequence(dim - 1, n))
    return point
|
|
|
|
|
|
def sample_probs(probs, counts, generator=None):
    """Split ``counts[r]`` draws across the ``P`` bins of row ``r`` by systematic resampling.

    Args:
        probs: Tensor whose last dimension holds ``P`` bin weights; it is reshaped to
            ``(counts.numel(), P)``. Negative weights are clamped to zero, every row is
            normalised to sum to one, and all-zero rows are treated as uniform.
        counts: Integer tensor with the number of draws for each row.
        generator: Optional CPU ``torch.Generator`` used for the per-row random offset; the
            global torch RNG is used when it is ``None``.

    Returns:
        A ``torch.long`` tensor of shape ``(*counts.shape, P)`` on ``probs.device`` holding the
        number of draws assigned to each bin. Rows with a non-positive count are all zero.
    """
    batch_shape = counts.shape
    R = counts.numel()
    P = probs.size(-1)
    device = probs.device
    probs = probs.reshape(R, P).to(torch.float32).clamp_min(0)
    counts = counts.reshape(R).to(device=device, dtype=torch.long)

    # Nothing to draw. This also covers an empty batch (counts.max() would raise) and
    # all-negative counts (torch.arange would reject a negative end).
    Nmax = int(counts.max()) if R > 0 else 0
    if Nmax <= 0:
        return counts.new_zeros(*batch_shape, P)

    row_sums = probs.sum(1, keepdim=True)
    empty_rows = row_sums == 0
    # Divide by the true row sum, guarding only the all-zero rows. Clamping the divisor to >= 1
    # would leave rows with mass below one un-normalised and dump the deficit into the last bin.
    safe_sums = torch.where(empty_rows, torch.ones_like(row_sums), row_sums)
    probs = torch.where(empty_rows, probs.new_tensor(1.0 / P), probs / safe_sums)
    cdf = probs.cumsum(dim=1).clamp(max=1.0 - 1e-12)

    cnt = counts.clamp_min(1).float().unsqueeze(1)  # (R, 1)
    grid = torch.arange(Nmax, device=device, dtype=torch.float32).unsqueeze(0)  # (1, Nmax)
    # One random offset per row, then evenly spaced samples: (R, Nmax) systematic samples (CPU-seeded).
    u = (torch.rand(R, 1, generator=generator).to(device) + grid) / cnt
    # clamp_max keeps samples that land past the last cdf entry inside the final bin.
    idx = torch.searchsorted(cdf, u.clamp(max=1.0 - 1e-12)).clamp_max(P - 1)
    weight = (grid < counts.unsqueeze(1)).to(cdf.dtype)  # mask out j >= counts[r]
    out = torch.zeros(R, P, dtype=torch.float32, device=device)
    out.scatter_add_(1, idx, weight)
    return out.to(torch.long).view(*batch_shape, P)
|
|
|
|
|
|
class MultiHeadAttention(nn.Module):
    """Multi-head self- or cross-attention built from the injected ``operations`` layers.

    ``type="self"`` uses one fused QKV projection of ``x``; any other value projects queries
    from ``x`` and keys/values from ``context``. With ``qk_rms_norm`` set, queries and keys
    are passed through ``MultiHeadRMSNorm`` before attention.
    """

    def __init__(self, channels, num_heads, ctx_channels=None, type="self", qkv_bias=True, qk_rms_norm=False,
                 dtype=None, device=None, operations=None):
        super().__init__()
        assert channels % num_heads == 0
        factory = dict(dtype=dtype, device=device)
        self.channels = channels
        self.head_dim = channels // num_heads
        # Without an explicit context width the context is assumed to be as wide as x.
        self.ctx_channels = channels if ctx_channels is None else ctx_channels
        self.num_heads = num_heads
        self._type = type
        self.qk_rms_norm = qk_rms_norm
        if self._type != "self":
            self.to_q = operations.Linear(channels, channels, bias=qkv_bias, **factory)
            self.to_kv = operations.Linear(self.ctx_channels, channels * 2, bias=qkv_bias, **factory)
        else:
            self.to_qkv = operations.Linear(channels, channels * 3, bias=qkv_bias, **factory)
        if self.qk_rms_norm:
            self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, **factory)
            self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, **factory)
        self.to_out = operations.Linear(channels, channels, **factory)

    def forward(self, x, context=None):
        """Attend from ``x`` (B, L, C) to itself or to ``context`` and return a (B, L, -1) tensor."""
        batch, q_len, _ = x.shape
        heads = self.num_heads
        if self._type == "self":
            fused = self.to_qkv(x).reshape(batch, q_len, 3, heads, -1)
            q, k, v = fused.unbind(dim=2)
        else:
            kv_len = context.shape[1]
            q = self.to_q(x).reshape(batch, q_len, heads, -1)
            fused_kv = self.to_kv(context).reshape(batch, kv_len, 2, heads, -1)
            k, v = fused_kv.unbind(dim=2)
        if self.qk_rms_norm:
            q, k = self.q_rms_norm(q), self.k_rms_norm(k)
        attended = attention(q, k, v)
        return self.to_out(attended.reshape(batch, q_len, -1))
|
|
|
|
|
|
# Octree probability decoder
|
|
|
|
class LevelEmbedder(nn.Module):
    """Embed a scalar level: sinusoidal features followed by a Linear-SiLU-Linear MLP."""

    def __init__(self, hidden_size, frequency_embedding_size=256, max_period=1024,
                 dtype=None, device=None, operations=None):
        super().__init__()
        linear_kwargs = dict(bias=True, dtype=dtype, device=device)
        self.mlp = nn.Sequential(
            operations.Linear(frequency_embedding_size, hidden_size, **linear_kwargs),
            nn.SiLU(),
            operations.Linear(hidden_size, hidden_size, **linear_kwargs),
        )
        self.frequency_embedding_size = frequency_embedding_size
        self.max_period = max_period

    @staticmethod
    def level_embedding(t, dim, max_period=1024):
        """Return ``(len(t), dim)`` sinusoidal features of the 1-D tensor ``t``.

        The first ``dim // 2`` columns are cosines and the next ``dim // 2`` are sines, over
        geometrically spaced frequencies. An odd ``dim`` gets one trailing zero column.
        """
        half = dim // 2
        exponent = -np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        freqs = torch.exp(exponent).to(device=t.device)
        phase = t[:, None].float() * freqs[None] * 2 * torch.pi
        columns = [torch.cos(phase), torch.sin(phase)]
        if dim % 2:
            # Pad to the requested odd width.
            columns.append(torch.zeros_like(columns[0][:, :1]))
        return torch.cat(columns, dim=-1)

    def forward(self, t):
        """Embed ``t`` and run it through the MLP in the MLP's weight dtype."""
        features = self.level_embedding(t, self.frequency_embedding_size, self.max_period)
        first_linear = self.mlp[0]
        return self.mlp(features.to(first_linear.weight.dtype))
|
|
|
|
|
|
class ModulatedTransformerCrossOnlyBlock(nn.Module):
    """Cross-attention + MLP block whose two branches are modulated by shift/scale/gate terms.

    With ``share_mod`` the caller passes the six modulation chunks directly in ``mod``;
    otherwise the block derives them from ``mod`` with its own SiLU + Linear head.
    """

    def __init__(self, channels, ctx_channels, num_heads, mlp_ratio=4.0, share_mod=False,
                 qk_rms_norm_cross=True, qkv_bias=True, dtype=None, device=None, operations=None):
        super().__init__()
        factory = dict(dtype=dtype, device=device)
        self.share_mod = share_mod
        self.norm1 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, **factory)
        self.norm2 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, **factory)
        self.cross_attn = MultiHeadAttention(
            channels, ctx_channels=ctx_channels, num_heads=num_heads, type="cross", qkv_bias=qkv_bias,
            qk_rms_norm=qk_rms_norm_cross, operations=operations, **factory)
        self.mlp = MLP(channels, int(channels * mlp_ratio), channels, operations=operations, **factory)
        if not share_mod:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(), operations.Linear(channels, 6 * channels, bias=True, **factory))

    def forward(self, x, mod, context):
        """Apply the modulated cross-attention and MLP branches to ``x`` and return the result."""
        params = mod if self.share_mod else self.adaLN_modulation(mod)
        # Six chunks along dim 1, each given a broadcast axis at dim 1 to line up with x.
        shift_attn, scale_attn, gate_attn, shift_mlp, scale_mlp, gate_mlp = (
            chunk.unsqueeze(1) for chunk in params.chunk(6, dim=1))

        attn_in = torch.addcmul(shift_attn, self.norm1(x), 1 + scale_attn)
        x = torch.addcmul(x, self.cross_attn(attn_in, context), gate_attn)
        mlp_in = torch.addcmul(shift_mlp, self.norm2(x), 1 + scale_mlp)
        return torch.addcmul(x, self.mlp(mlp_in), gate_mlp)
|
|
|
|
|
|
class OctreeProbabilityFixedlenDecoder(nn.Module):
    """Cross-attention transformer over octree coords -> per-node 8-way child occupancy logits."""

    def __init__(self, model_channels=1024, cond_channels=16, num_blocks=4, num_heads=16,
                 num_head_channels=64, mlp_ratio=4.0, share_mod=True,
                 qk_rms_norm_cross=True, dtype=None, device=None, operations=None):
        super().__init__()
        factory = dict(dtype=dtype, device=device)
        self.model_channels = model_channels
        self.cond_channels = cond_channels
        self.num_blocks = num_blocks
        # A falsy num_heads means "derive the head count from the per-head width".
        self.num_heads = num_heads or model_channels // num_head_channels
        self.mlp_ratio = mlp_ratio
        self.share_mod = share_mod
        self.qk_rms_norm_cross = qk_rms_norm_cross
        # forward() reads its working dtype from the first registered parameter, so input_layer stays first.
        self.input_layer = operations.Linear(model_channels, model_channels, **factory)
        self.l_embedder = LevelEmbedder(model_channels, operations=operations, **factory)
        if share_mod:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(), operations.Linear(model_channels, 6 * model_channels, bias=True, **factory))
        if cond_channels is not None:
            blocks = []
            for _ in range(num_blocks):
                blocks.append(ModulatedTransformerCrossOnlyBlock(
                    model_channels, ctx_channels=cond_channels, num_heads=self.num_heads,
                    mlp_ratio=self.mlp_ratio, qk_rms_norm_cross=self.qk_rms_norm_cross,
                    share_mod=self.share_mod, operations=operations, **factory))
            self.blocks = nn.ModuleList(blocks)
        self.out_proj = operations.Linear(model_channels, 8, **factory)
        self.in_proj = operations.Linear(3, model_channels, **factory)
        self.pos_embedder = PcdAbsolutePositionEmbedder(
            channels=model_channels, in_channels=3, max_res=10, schedule="log2")

    def forward(self, x, l, cond):
        """Predict child-occupancy logits for the nodes in ``x``.

        ``x`` is unpacked as (B, L, 3) node coordinates, ``l`` is fed to the level embedder and
        ``cond`` is the cross-attention context. Returns ``{"logits", "probs"}`` where
        ``probs`` is the softmax of ``logits`` over the last axis.
        """
        work_dtype = next(self.parameters()).dtype
        batch, nodes, _ = x.shape
        tokens = self.in_proj(x.to(work_dtype))
        tokens = tokens + self.pos_embedder(x.reshape(-1, 3)).reshape(batch, nodes, -1).to(work_dtype)
        tokens = self.input_layer(tokens)
        modulation = self.l_embedder(l)
        if self.share_mod:
            # One shared modulation head feeds every block.
            modulation = self.adaLN_modulation(modulation)
        cond = cond.to(work_dtype)
        for block in self.blocks:
            tokens = block(tokens, modulation, cond)
        # Final parameter-free layer norm runs in float32, then returns to the working dtype.
        tokens = F.layer_norm(tokens.float(), tokens.shape[-1:]).to(work_dtype)
        logits = self.out_proj(tokens)
        return {"logits": logits, "probs": torch.softmax(logits, dim=-1)}

    @staticmethod
    def sample(model, cond, num_points, level, temperature=1.0, generator=None):
        """Sample ``num_points`` points per batch item by descending the octree ``level`` times.

        At each level the model predicts 8-way child logits for every kept cell, the cell's point
        budget is split among its children with sample_probs(), and only children that received
        points are kept (left-packed per batch item and zero-padded to the widest item).

        Returns a dict with "points" of shape (B, num_points, 3), which are the final cell
        coordinates plus a uniform jitter divided by the final resolution, and "log_probs" of
        shape (B, num_points), the accumulated log-probability of each point's cell.
        """
        batch = cond.shape[0]
        device = cond.device
        # Child corner offsets with the first coordinate varying fastest.
        corners = [[i, j, k] for k in (0, 1) for j in (0, 1) for i in (0, 1)]
        child_offset = torch.tensor(corners, dtype=torch.long, device=device)
        cells = torch.zeros(batch, 1, 3, dtype=torch.long, device=device)
        budgets = torch.full((batch, 1), num_points, dtype=torch.long, device=device)
        path_log_probs = torch.zeros(batch, 1, dtype=torch.float32, device=device)
        batch_ids = torch.arange(batch, device=device).unsqueeze(1)

        def pack(values, keep, rows, slots, width, dtype):
            # Left-pack the kept entries of each batch row into a zero-padded (B, width, ...) tensor.
            packed = torch.zeros(batch, width, *values.shape[2:], dtype=dtype, device=device)
            packed[rows, slots] = values[keep]
            return packed

        for lv in range(1, level + 1):
            parent_res = 1 << (lv - 1)
            parent_centers = (cells.to(torch.float32) + 0.5) / parent_res
            level_ids = torch.full((batch,), 1 << lv, dtype=torch.long, device=device)
            logits = model(parent_centers, level_ids, cond)["logits"] / temperature
            child_probs = torch.softmax(logits, dim=-1)
            child_log_probs = torch.log_softmax(logits, dim=-1).flatten(1, 2)
            child_counts = sample_probs(child_probs, budgets, generator=generator).flatten(1, 2)
            child_path = path_log_probs.repeat_interleave(8, dim=1) + child_log_probs
            child_cells = (cells[:, :, None, :] * 2 + child_offset[None, None, :, :]).flatten(1, 2)

            keep = child_counts > 0
            width = keep.sum(dim=1).max().item()
            slots = (keep.cumsum(dim=1) - 1)[keep]
            rows = batch_ids.expand_as(keep)[keep]

            cells = pack(child_cells, keep, rows, slots, width, child_cells.dtype)
            budgets = pack(child_counts, keep, rows, slots, width, child_counts.dtype)
            path_log_probs = pack(child_path, keep, rows, slots, width, path_log_probs.dtype)

        # Expand every kept cell into as many points as its budget; padded cells have budget 0.
        repeats = budgets.flatten(0, 1)
        log_probs = torch.repeat_interleave(path_log_probs.flatten(0, 1), repeats, dim=0).reshape(batch, num_points)
        coords_int = torch.repeat_interleave(cells.flatten(0, 1), repeats, dim=0).reshape(batch, num_points, -1)
        jitter = torch.rand(coords_int.shape, dtype=torch.float32, generator=generator).to(device)
        points = (coords_int.to(torch.float32) + jitter) / (1 << level)
        return {"points": points, "log_probs": log_probs}
|
|
|
|
|
|
# Elastic gaussian decoder
|
|
|
|
class TransformerCrossBlock(nn.Module):
    """Pre-norm transformer block: self-attention, cross-attention on ``context``, then an MLP.

    Each branch is added back to its input. Only the cross-attention norm (``norm2``) has
    learnable affine parameters.
    """

    def __init__(self, channels, ctx_channels, num_heads, mlp_ratio=4.0,
                 qk_rms_norm=True, qk_rms_norm_cross=True, qkv_bias=True,
                 dtype=None, device=None, operations=None):
        super().__init__()
        factory = dict(dtype=dtype, device=device)
        self.norm1 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, **factory)
        self.norm2 = operations.LayerNorm(channels, elementwise_affine=True, eps=1e-6, **factory)
        self.norm3 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, **factory)
        self.self_attn = MultiHeadAttention(
            channels, num_heads=num_heads, type="self", qkv_bias=qkv_bias,
            qk_rms_norm=qk_rms_norm, operations=operations, **factory)
        self.cross_attn = MultiHeadAttention(
            channels, ctx_channels=ctx_channels, num_heads=num_heads, type="cross",
            qkv_bias=qkv_bias, qk_rms_norm=qk_rms_norm_cross, operations=operations, **factory)
        self.mlp = MLP(channels, int(channels * mlp_ratio), channels, operations=operations, **factory)

    def forward(self, x, context):
        """Run the three residual branches in order and return the updated tokens."""
        after_self = x + self.self_attn(self.norm1(x))
        after_cross = after_self + self.cross_attn(self.norm2(after_self), context)
        return after_cross + self.mlp(self.norm3(after_cross))
|
|
|
|
|
|
class ElasticGaussianFixedlenDecoder(nn.Module):
    """Cross-attention transformer over sampled octree points -> per-point gaussian params."""

    def __init__(self, in_channels=3, model_channels=1024, cond_channels=16, num_blocks=16, num_heads=16,
                 num_head_channels=64, mlp_ratio=4.0, *, representation_config=None,
                 qk_rms_norm=True, qk_rms_norm_cross=True, dtype=None, device=None, operations=None):
        super().__init__()
        if not representation_config:
            representation_config = dict(
                lr=dict(_xyz=1.0, _features_dc=1.0, _opacity=1.0, _scaling=1.0, _rotation=0.1),
                perturb_offset=True, perturbe_size=1.5, offset_scale=0.05, num_gaussians=32,
                filter_kernel_size_3d=0.0009, scaling_bias=0.004, opacity_bias=0.1,
                scaling_activation="softplus",
            )
        self.rep_config = representation_config
        # The layout decides how many output channels the final projection needs.
        self.out_channels = self._calc_layout()
        self.model_channels = model_channels
        self.cond_channels = cond_channels
        self.num_blocks = num_blocks
        # A falsy num_heads means "derive the head count from the per-head width".
        self.num_heads = num_heads or model_channels // num_head_channels
        self.mlp_ratio = mlp_ratio
        factory = dict(dtype=dtype, device=device)
        self.input_layer = operations.Linear(model_channels, model_channels, **factory)
        if cond_channels is not None:
            blocks = []
            for _ in range(num_blocks):
                blocks.append(TransformerCrossBlock(
                    model_channels, ctx_channels=cond_channels,
                    num_heads=self.num_heads, mlp_ratio=self.mlp_ratio,
                    qk_rms_norm=qk_rms_norm, qk_rms_norm_cross=qk_rms_norm_cross,
                    operations=operations, **factory))
            self.blocks = nn.ModuleList(blocks)
        self.in_proj = operations.Linear(in_channels, model_channels, **factory)
        self.pos_embedder = PcdAbsolutePositionEmbedder(
            channels=model_channels, in_channels=3, max_res=10, schedule="log2")
        self.out_proj = operations.Linear(model_channels, self.out_channels, **factory)
        self._build_perturbation()

    def _calc_layout(self):
        """Build ``self.layout`` (shape, size and channel range per field) and return the total channel count."""
        ng = self.rep_config['num_gaussians']
        field_shapes = {
            '_xyz': (ng, 3),
            '_features_dc': (ng, 1, 3),
            '_scaling': (ng, 3),
            '_rotation': (ng, 4),
            '_opacity': (ng, 1),
            '_offset_scale': (ng, 1),
        }
        self.layout = {}
        cursor = 0
        for name, shape in field_shapes.items():
            size = 1
            for extent in shape:
                size *= extent
            # Fields occupy consecutive channel ranges in declaration order.
            self.layout[name] = {'shape': shape, 'size': size, 'range': (cursor, cursor + size)}
            cursor += size
        return cursor

    def _build_perturbation(self):
        """Register the fixed per-gaussian offset perturbation and the base offset scale as buffers."""
        ng = self.rep_config['num_gaussians']
        unit_points = torch.tensor([hammersley_sequence(3, idx, ng) for idx in range(ng)]).float()
        # Stored in pre-tanh space: _get_offset() adds it before tanh(...) * 0.5 * perturbe_size.
        pre_tanh = torch.atanh((unit_points * 2 - 1) / self.rep_config['perturbe_size'])
        self.register_buffer('points_offset_perturbation', pre_tanh)
        # Inverse softplus, so softplus(0 + base_offset_scale) equals the configured offset_scale.
        base = torch.tensor(self.rep_config['offset_scale'])
        self.register_buffer('base_offset_scale', torch.log(torch.exp(base) - 1.0))

    def _get_offset(self, h):
        """Decode the per-gaussian position offsets from the feature tensor ``h``."""
        batch = h.shape[0]

        def field(name):
            # Slice one field's channel range and unflatten it to (B, -1, *field shape).
            start, stop = self.layout[name]['range']
            return h[:, :, start:stop].reshape(batch, -1, *self.layout[name]['shape'])

        base_scale = comfy.model_management.cast_to(self.base_offset_scale, h.dtype, h.device)
        offset_scale = F.softplus(field('_offset_scale') + base_scale)

        offset = field('_xyz') * self.rep_config['lr']['_xyz']
        if self.rep_config['perturb_offset']:
            perturbation = comfy.model_management.cast_to(
                self.points_offset_perturbation, offset.dtype, offset.device)
            offset = offset + perturbation
        offset = torch.tanh(offset) * 0.5 * self.rep_config['perturbe_size']
        return offset * offset_scale

    def forward(self, x=None, cond=None):
        """Run the decoder on ``x["points"]`` with context ``cond`` and return ``{"features": ...}``."""
        pcd = x["points"]
        work_dtype = next(self.parameters()).dtype
        batch, n_points, _ = pcd.shape
        tokens = self.in_proj(pcd.to(work_dtype))
        tokens = tokens + self.pos_embedder(pcd.reshape(-1, 3)).reshape(batch, n_points, -1).to(work_dtype)
        tokens = self.input_layer(tokens)
        cond = cond.to(work_dtype)
        for block in self.blocks:
            tokens = block(tokens, cond)
        # Final parameter-free layer norm runs in float32, then returns to the tokens' own dtype.
        token_dtype = tokens.dtype
        tokens = F.layer_norm(tokens.float(), tokens.shape[-1:]).to(token_dtype)
        return {"features": self.out_proj(tokens)}
|
|
|
|
|
|
# Combined octree gaussian decoder (comfy first-stage model)
|
|
|
|
class OctreeGaussianDecoder(nn.Module):
    """Octree point sampler followed by the elastic gaussian decoder."""

    # Octree depth used by decode() when no level is given.
    _MAX_VOXEL_LEVEL = 8

    def __init__(self, dtype=None, device=None, operations=None):
        super().__init__()
        ops = comfy.ops.disable_weight_init if operations is None else operations
        self.octree = OctreeProbabilityFixedlenDecoder(dtype=dtype, device=device, operations=ops)
        self.gs = ElasticGaussianFixedlenDecoder(dtype=dtype, device=device, operations=ops)

    @property
    def gaussians_per_point(self) -> int:
        """Number of gaussians the gaussian decoder emits for each sampled point."""
        return self.gs.rep_config['num_gaussians']

    def decode(self, latent: torch.Tensor, num_gaussians: int, level: int = None, generator=None):
        """Sample octree points from ``latent`` and decode them into gaussians.

        ``level`` defaults to the full octree depth; a lower level is cheaper (coarser) for live
        previews. ``generator`` (a CPU torch.Generator) makes the octree sampling reproducible
        without touching the global RNG. At least one point is always sampled.
        """
        depth = level if level is not None else self._MAX_VOXEL_LEVEL
        token_count = max(1, num_gaussians // self.gaussians_per_point)
        sampled = OctreeProbabilityFixedlenDecoder.sample(
            self.octree, latent, num_points=token_count, level=depth, temperature=1.0, generator=generator,
        )
        decoded = self.gs(x=sampled, cond=latent)
        return build_gaussian_models(self.gs, sampled, decoded)  # one GaussianModel per batch item
|