initial Pixal3D support

This commit is contained in:
kijai 2026-05-22 01:50:48 +03:00
parent 5b981ed295
commit 9e4794da5c
9 changed files with 1316 additions and 123 deletions

View File

@ -5,8 +5,12 @@ from typing import Tuple, Union, List
from comfy.ldm.trellis2.vae import VarLenTensor
import comfy.ops
try:
from torch.nn.attention.varlen import varlen_attn as _varlen_attn
except ImportError:
_varlen_attn = None
# replica of the seedvr2 code
def var_attn_arg(kwargs):
cu_seqlens_q = kwargs.get("cu_seqlens_q", None)
max_seqlen_q = kwargs.get("max_seqlen_q", None)
@ -16,42 +20,30 @@ def var_attn_arg(kwargs):
return cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
var_length = True
if var_length:
cu_seqlens_q, cu_seqlens_k, _, _ = var_attn_arg(kwargs)
if not skip_reshape:
# assumes 2D q, k,v [total_tokens, embed_dim]
total_tokens, embed_dim = q.shape
head_dim = embed_dim // heads
q = q.view(total_tokens, heads, head_dim)
k = k.view(k.shape[0], heads, head_dim)
v = v.view(v.shape[0], heads, head_dim)
cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = var_attn_arg(kwargs)
if not skip_reshape:
total_tokens, embed_dim = q.shape
head_dim = embed_dim // heads
q = q.view(total_tokens, heads, head_dim)
k = k.view(k.shape[0], heads, head_dim)
v = v.view(v.shape[0], heads, head_dim)
b = q.size(0)
dim_head = q.shape[-1]
q = torch.nested.nested_tensor_from_jagged(q, offsets=cu_seqlens_q.long())
k = torch.nested.nested_tensor_from_jagged(k, offsets=cu_seqlens_k.long())
v = torch.nested.nested_tensor_from_jagged(v, offsets=cu_seqlens_k.long())
mask = None
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
if mask is not None:
if mask.ndim == 2:
mask = mask.unsqueeze(0)
if mask.ndim == 3:
mask = mask.unsqueeze(1)
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
if var_length:
return out.transpose(1, 2).values()
if not skip_output_reshape:
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
if _varlen_attn is not None:
return _varlen_attn(
q, k, v,
cu_seqlens_q, cu_seqlens_k,
int(max_seqlen_q), int(max_seqlen_k),
)
return out
# Fallback: nested-tensor SDPA (PyTorch < the version that introduced varlen_attn)
q = torch.nested.nested_tensor_from_jagged(q, offsets=cu_seqlens_q.long())
k = torch.nested.nested_tensor_from_jagged(k, offsets=cu_seqlens_k.long())
v = torch.nested.nested_tensor_from_jagged(v, offsets=cu_seqlens_k.long())
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
return out.transpose(1, 2).values()
def scaled_dot_product_attention(*args, **kwargs):
num_all_args = len(args) + len(kwargs)

View File

@ -26,16 +26,26 @@ class TorchHashMap:
self.default_value = torch.tensor(default_value, dtype=torch.long, device=device)
self._n = self.sorted_keys.numel()
# Chunk size for lookup_flat. At ~530M flat keys (large mesh extraction),
# the unchunked path allocates ~5 full-size int64 temporaries (4 GB each) +
# bool masks + the int32 output. Chunking caps each transient to ~CHUNK rows.
_LOOKUP_CHUNK = 1 << 23 # 8M rows ≈ 64 MB per int64 temp
def lookup_flat(self, flat_keys: torch.Tensor) -> torch.Tensor:
flat = flat_keys.to(torch.long)
if self._n == 0:
return torch.full((flat.shape[0],), -1, device=flat.device, dtype=torch.int32)
idx = torch.searchsorted(self.sorted_keys, flat)
idx_safe = torch.clamp(idx, max=self._n - 1)
found = (idx < self._n) & (self.sorted_keys[idx_safe] == flat)
out = torch.full((flat.shape[0],), -1, device=flat.device, dtype=torch.int32)
if found.any():
out[found] = self.sorted_vals[idx_safe[found]].to(torch.int32)
N = flat_keys.shape[0]
out = torch.full((N,), -1, device=flat_keys.device, dtype=torch.int32)
if self._n == 0 or N == 0:
return out
for s in range(0, N, self._LOOKUP_CHUNK):
e = min(s + self._LOOKUP_CHUNK, N)
flat_chunk = flat_keys[s:e].to(torch.long)
idx = torch.searchsorted(self.sorted_keys, flat_chunk)
in_range = idx < self._n
idx.clamp_(max=self._n - 1) # reuse idx as the "safe" index
found = in_range & (self.sorted_keys[idx] == flat_chunk)
if found.any():
found_idx = found.nonzero(as_tuple=True)[0]
out[s + found_idx] = self.sorted_vals[idx[found_idx]].to(torch.int32)
return out
@ -212,10 +222,10 @@ def sparse_submanifold_conv3d(
if accumulate_f32:
weight_T = weight.view(Co, V * Ci).to(torch.float32).T.contiguous()
output = torch.zeros(N_pts, Co, device=device, dtype=torch.float32)
else:
weight_T = weight.view(Co, V * Ci).to(feats.dtype).T.contiguous()
output = torch.zeros(N_pts, Co, device=device, dtype=feats.dtype)
output = torch.empty(N_pts, Co, device=device, dtype=feats.dtype)
# ------------------------------------------------------------------
# Chunk size from memory budget
@ -226,6 +236,9 @@ def sparse_submanifold_conv3d(
chunk_size = max(1, int(max_chunk_mem / mem_per_row))
chunk_size = min(chunk_size, N_pts)
# fp32 matmul scratch — sized to the largest chunk, reused each iteration.
chunk_buf = torch.empty(chunk_size, Co, device=device, dtype=torch.float32) if accumulate_f32 else None
# ------------------------------------------------------------------
# Chunked forward pass
# Each iteration:
@ -233,7 +246,8 @@ def sparse_submanifold_conv3d(
# 2. mask zero invalids in-place, no extra alloc
# 3. reshape (chunk, V*Ci)
# 4. GEMM (chunk, V*Ci) @ (V*Ci, Co) → (chunk, Co) cuBLAS
# written directly into output slice via out= argument
# written into the scratch buf (fp32) or output slice (fp16) via out=
# 5. (fp32 path) cast scratch chunk to fp16 and copy into output slice
# ------------------------------------------------------------------
for start in range(0, N_pts, chunk_size):
end = min(start + chunk_size, N_pts)
@ -257,16 +271,13 @@ def sparse_submanifold_conv3d(
gathered_flat = gathered.view(actual_chunk, V * Ci)
if accumulate_f32:
gathered_flat = gathered_flat.to(torch.float32)
# Single GEMM call per chunk, written directly into output.
# This avoids allocating a temporary (chunk, Co) tensor.
torch.matmul(gathered_flat, weight_T, out=output[start:end])
if accumulate_f32:
output = output.to(feats.dtype)
torch.matmul(gathered_flat, weight_T, out=chunk_buf[:actual_chunk])
output[start:end] = chunk_buf[:actual_chunk].to(feats.dtype)
else:
torch.matmul(gathered_flat, weight_T, out=output[start:end])
if bias is not None:
output = output + bias.unsqueeze(0).to(output.dtype)
output += bias.unsqueeze(0).to(output.dtype)
return output, neighbor

View File

@ -25,15 +25,12 @@ class SparseFeedForwardNet(nn.Module):
def forward(self, x: VarLenTensor) -> VarLenTensor:
return self.mlp(x)
def manual_cast(obj, dtype):
return obj.to(dtype=dtype)
class LayerNorm32(nn.LayerNorm):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x_dtype = x.dtype
x = manual_cast(x, torch.float32)
x = x.to(dtype=torch.float32)
o = super().forward(x)
return manual_cast(o, x_dtype)
return o.to(dtype=x_dtype)
class SparseMultiHeadRMSNorm(nn.Module):
@ -249,6 +246,51 @@ class SparseMultiHeadAttention(nn.Module):
h = self._linear(self.to_out, h)
return h
def _split_proj_context(context):
if not isinstance(context, dict):
return context, None
global_ctx = context["global"]
if "proj" in context:
return global_ctx, context["proj"]
if "proj_semantic" in context and "proj_color" in context:
return global_ctx, (context["proj_semantic"], context["proj_color"])
return global_ctx, None
class ProjectAttentionSparse(nn.Module):
def __init__(self, cross_attn_block: nn.Module, channels: int, proj_in_channels: int,
device=None, dtype=None, operations=None):
super().__init__()
self.cross_attn_block = cross_attn_block
self.proj_linear = operations.Linear(proj_in_channels, channels, bias=True,
device=device, dtype=dtype)
def forward(self, x: SparseTensor, context) -> SparseTensor:
global_ctx, proj_in = _split_proj_context(context)
global_out = self.cross_attn_block(x, global_ctx)
if isinstance(proj_in, tuple):
proj_in = torch.cat([proj_in[0], proj_in[1]], dim=-1)
proj_out = self.proj_linear(proj_in.to(self.proj_linear.weight.dtype))
return global_out.replace(global_out.feats + proj_out.to(global_out.feats.dtype))
class ProjectAttentionDense(nn.Module):
def __init__(self, cross_attn_block: nn.Module, channels: int, proj_in_channels: int,
device=None, dtype=None, operations=None):
super().__init__()
self.cross_attn_block = cross_attn_block
self.proj_linear = operations.Linear(proj_in_channels, channels, bias=True,
device=device, dtype=dtype)
def forward(self, x: torch.Tensor, context) -> torch.Tensor:
global_ctx, proj_in = _split_proj_context(context)
global_out = self.cross_attn_block(x, global_ctx)
if isinstance(proj_in, tuple):
proj_in = torch.cat([proj_in[0], proj_in[1]], dim=-1)
proj_out = self.proj_linear(proj_in.to(self.proj_linear.weight.dtype))
return global_out + proj_out.to(global_out.dtype)
class ModulatedSparseTransformerCrossBlock(nn.Module):
"""
Sparse Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning.
@ -269,11 +311,14 @@ class ModulatedSparseTransformerCrossBlock(nn.Module):
qk_rms_norm_cross: bool = False,
qkv_bias: bool = True,
share_mod: bool = False,
image_attn_mode: Literal["global", "proj", "gated_proj"] = "global",
proj_in_channels: Optional[int] = None,
device=None, dtype=None, operations=None
):
super().__init__()
self.use_checkpoint = use_checkpoint
self.share_mod = share_mod
self.image_attn_mode = image_attn_mode
self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device)
self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6, device=device)
self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device)
@ -290,7 +335,7 @@ class ModulatedSparseTransformerCrossBlock(nn.Module):
qk_rms_norm=qk_rms_norm,
device=device, dtype=dtype, operations=operations
)
self.cross_attn = SparseMultiHeadAttention(
cross_inner = SparseMultiHeadAttention(
channels,
ctx_channels=ctx_channels,
num_heads=num_heads,
@ -300,6 +345,15 @@ class ModulatedSparseTransformerCrossBlock(nn.Module):
qk_rms_norm=qk_rms_norm_cross,
device=device, dtype=dtype, operations=operations
)
if image_attn_mode == "global":
self.cross_attn = cross_inner
else:
if proj_in_channels is None:
raise ValueError("proj_in_channels must be set when image_attn_mode != 'global'")
self.cross_attn = ProjectAttentionSparse(
cross_inner, channels, proj_in_channels,
device=device, dtype=dtype, operations=operations,
)
self.mlp = SparseFeedForwardNet(
channels,
mlp_ratio=mlp_ratio,
@ -313,7 +367,7 @@ class ModulatedSparseTransformerCrossBlock(nn.Module):
else:
self.modulation = nn.Parameter(torch.randn(6 * channels, device=device, dtype=dtype) / channels ** 0.5)
def _forward(self, x: SparseTensor, mod: torch.Tensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor:
def _forward(self, x: SparseTensor, mod: torch.Tensor, context) -> SparseTensor:
if self.share_mod:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1)
else:
@ -324,7 +378,11 @@ class ModulatedSparseTransformerCrossBlock(nn.Module):
h = h * gate_msa
x = x + h
h = x.replace(self.norm2(x.feats))
h = self.cross_attn(h, context)
if self.image_attn_mode == "global":
global_ctx, _ = _split_proj_context(context)
h = self.cross_attn(h, global_ctx)
else:
h = self.cross_attn(h, context)
x = x + h
h = x.replace(self.norm3(x.feats))
h = h * (1 + scale_mlp) + shift_mlp
@ -333,7 +391,7 @@ class ModulatedSparseTransformerCrossBlock(nn.Module):
x = x + h
return x
def forward(self, x: SparseTensor, mod: torch.Tensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor:
def forward(self, x: SparseTensor, mod: torch.Tensor, context) -> SparseTensor:
return self._forward(x, mod, context)
@ -356,6 +414,8 @@ class SLatFlowModel(nn.Module):
initialization: str = 'vanilla',
qk_rms_norm: bool = False,
qk_rms_norm_cross: bool = False,
image_attn_mode: Literal["global", "proj", "gated_proj"] = "global",
proj_in_channels: Optional[int] = None,
dtype = None,
device = None,
operations = None,
@ -375,6 +435,8 @@ class SLatFlowModel(nn.Module):
self.initialization = initialization
self.qk_rms_norm = qk_rms_norm
self.qk_rms_norm_cross = qk_rms_norm_cross
self.image_attn_mode = image_attn_mode
self.proj_in_channels = proj_in_channels
self.dtype = dtype
self.t_embedder = TimestepEmbedder(model_channels, device=device, dtype=dtype, operations=operations)
@ -399,6 +461,8 @@ class SLatFlowModel(nn.Module):
share_mod=self.share_mod,
qk_rms_norm=self.qk_rms_norm,
qk_rms_norm_cross=self.qk_rms_norm_cross,
image_attn_mode=image_attn_mode,
proj_in_channels=proj_in_channels,
device=device, dtype=dtype, operations=operations
)
for _ in range(num_blocks)
@ -426,19 +490,15 @@ class SLatFlowModel(nn.Module):
dtype = next(self.input_layer.parameters()).dtype
x = x.to(dtype)
h = self.input_layer(x)
h = manual_cast(h, self.dtype)
t = t.to(dtype)
t_embedder = self.t_embedder.to(dtype)
t_emb = t_embedder(t, out_dtype = t.dtype)
if self.share_mod:
t_emb = self.adaLN_modulation(t_emb)
t_emb = manual_cast(t_emb, self.dtype)
cond = manual_cast(cond, self.dtype)
for block in self.blocks:
h = block(h, t_emb, cond)
h = manual_cast(h, x.dtype)
h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
h = self.out_layer(h)
return h
@ -561,11 +621,14 @@ class ModulatedTransformerCrossBlock(nn.Module):
qk_rms_norm_cross: bool = False,
qkv_bias: bool = True,
share_mod: bool = False,
image_attn_mode: Literal["global", "proj", "gated_proj"] = "global",
proj_in_channels: Optional[int] = None,
device=None, dtype=None, operations=None
):
super().__init__()
self.use_checkpoint = use_checkpoint
self.share_mod = share_mod
self.image_attn_mode = image_attn_mode
self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device)
self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6, device=device)
self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device)
@ -582,7 +645,7 @@ class ModulatedTransformerCrossBlock(nn.Module):
qk_rms_norm=qk_rms_norm,
device=device, dtype=dtype, operations=operations
)
self.cross_attn = MultiHeadAttention(
cross_inner = MultiHeadAttention(
channels,
ctx_channels=ctx_channels,
num_heads=num_heads,
@ -592,6 +655,15 @@ class ModulatedTransformerCrossBlock(nn.Module):
qk_rms_norm=qk_rms_norm_cross,
device=device, dtype=dtype, operations=operations
)
if image_attn_mode == "global":
self.cross_attn = cross_inner
else:
if proj_in_channels is None:
raise ValueError("proj_in_channels must be set when image_attn_mode != 'global'")
self.cross_attn = ProjectAttentionDense(
cross_inner, channels, proj_in_channels,
device=device, dtype=dtype, operations=operations,
)
self.mlp = FeedForwardNet(
channels,
mlp_ratio=mlp_ratio,
@ -605,7 +677,7 @@ class ModulatedTransformerCrossBlock(nn.Module):
else:
self.modulation = nn.Parameter(torch.randn(6 * channels, device=device, dtype=dtype) / channels ** 0.5)
def _forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
def _forward(self, x: torch.Tensor, mod: torch.Tensor, context, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.share_mod:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1)
else:
@ -616,7 +688,11 @@ class ModulatedTransformerCrossBlock(nn.Module):
h = h * gate_msa.unsqueeze(1)
x = x + h
h = self.norm2(x)
h = self.cross_attn(h, context)
if self.image_attn_mode == "global":
global_ctx, _ = _split_proj_context(context)
h = self.cross_attn(h, global_ctx)
else:
h = self.cross_attn(h, context)
x = x + h
h = self.norm3(x)
h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
@ -625,7 +701,7 @@ class ModulatedTransformerCrossBlock(nn.Module):
x = x + h
return x
def forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
def forward(self, x: torch.Tensor, mod: torch.Tensor, context, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
return self._forward(x, mod, context, phases)
@ -648,6 +724,8 @@ class SparseStructureFlowModel(nn.Module):
initialization: str = 'vanilla',
qk_rms_norm: bool = False,
qk_rms_norm_cross: bool = False,
image_attn_mode: Literal["global", "proj", "gated_proj"] = "global",
proj_in_channels: Optional[int] = None,
operations=None,
device = None,
dtype = torch.float32,
@ -669,6 +747,8 @@ class SparseStructureFlowModel(nn.Module):
self.initialization = initialization
self.qk_rms_norm = qk_rms_norm
self.qk_rms_norm_cross = qk_rms_norm_cross
self.image_attn_mode = image_attn_mode
self.proj_in_channels = proj_in_channels
self.dtype = dtype
self.device = device
@ -703,6 +783,8 @@ class SparseStructureFlowModel(nn.Module):
share_mod=share_mod,
qk_rms_norm=self.qk_rms_norm,
qk_rms_norm_cross=self.qk_rms_norm_cross,
image_attn_mode=image_attn_mode,
proj_in_channels=proj_in_channels,
device=device, dtype=dtype, operations=operations
)
for _ in range(num_blocks)
@ -720,14 +802,9 @@ class SparseStructureFlowModel(nn.Module):
t_emb = self.t_embedder(t, out_dtype = t.dtype)
if self.share_mod:
t_emb = self.adaLN_modulation(t_emb)
t_emb = manual_cast(t_emb, self.dtype)
h = manual_cast(h, self.dtype)
cond = manual_cast(cond, self.dtype)
for block in self.blocks:
h = block(h, t_emb, cond, self.rope_phases)
h = manual_cast(h, x.dtype)
h = F.layer_norm(h, h.shape[-1:])
h = h.to(next(self.out_layer.parameters()).dtype)
h = self.out_layer(h)
h = h.permute(0, 2, 1).view(h.shape[0], h.shape[2], *[self.resolution] * 3).contiguous()
@ -741,6 +818,221 @@ def timestep_reshift(t_shifted, old_shift=3.0, new_shift=5.0):
t_new *= 1000.0
return t_new
# Pixal3D ProjGrid math — port of upstream's ProjGrid + project_points_to_image_batch.
# World frame uses world Y as depth (Blender convention), camera looks along -Z local;
# transform_matrix is camera-to-world (inverted internally). Intrinsics: fx = 16 / tan(fov/2)
# with sensor_width = 32mm.
_PROJ_GRID_ROTATION = torch.tensor(
[[1.0, 0.0, 0.0],
[0.0, 0.0, -1.0],
[0.0, 1.0, 0.0]]
)
_PROJ_FRONT_VIEW_TRANSFORM = torch.tensor(
[[1.0, 0.0, 0.0, 0.0],
[0.0, 0.0, -1.0, -2.0],
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 1.0]]
)
def _build_proj_transform_matrix(distance: torch.Tensor, batch_size: int,
device, dtype=torch.float32) -> torch.Tensor:
T = _PROJ_FRONT_VIEW_TRANSFORM.to(device=device, dtype=dtype)
T = T.unsqueeze(0).expand(batch_size, -1, -1).clone()
if distance.ndim == 0:
distance = distance.expand(batch_size)
T[:, 1, 3] = -distance.to(device=device, dtype=dtype)
return T
def _project_points_to_image(points_world: torch.Tensor, transform_matrix: torch.Tensor,
camera_angle_x: torch.Tensor, resolution: int):
B, N, _ = points_world.shape
ones = torch.ones((B, N, 1), device=points_world.device, dtype=points_world.dtype)
homo = torch.cat([points_world, ones], dim=-1)
world_to_camera = torch.linalg.inv(transform_matrix.float()).to(transform_matrix.dtype)
p_cam = torch.bmm(homo, world_to_camera.transpose(-2, -1))[..., :3]
x_cam, y_cam, z_cam = p_cam.unbind(dim=-1)
depth = -z_cam
sensor_width = 32.0
focal_length = 16.0 / torch.tan(camera_angle_x / 2.0)
focal_px = focal_length * resolution / sensor_width
focal_px = focal_px.to(p_cam.dtype).unsqueeze(1)
denom = (-z_cam + 1e-8)
x_pix = focal_px * x_cam / denom + resolution / 2.0
y_pix = -focal_px * y_cam / denom + resolution / 2.0
valid = ((x_pix >= 0) & (x_pix < resolution) &
(y_pix >= 0) & (y_pix < resolution) & (depth > 0))
return torch.stack([x_pix, y_pix], dim=-1), depth, valid
def _sample_features(feature_map: torch.Tensor, uv_ndc: torch.Tensor) -> torch.Tensor:
B, C, _, _ = feature_map.shape
grid = uv_ndc.view(B, -1, 1, 2).to(feature_map.dtype)
feat = F.grid_sample(feature_map, grid, mode="bilinear",
padding_mode="border", align_corners=False)
return feat.squeeze(-1)
def _coords_to_proj_world(coords: torch.Tensor, resolution: int, mesh_scale: torch.Tensor):
if resolution < 1:
raise ValueError(f"resolution must be positive, got {resolution}")
batch_ids = coords[:, 0].long()
if resolution == 1:
norm = coords[:, 1:].to(torch.float32) * 0.0
else:
norm = coords[:, 1:].to(torch.float32) / (resolution - 1) * 2.0 - 1.0
R = _PROJ_GRID_ROTATION.to(device=coords.device, dtype=torch.float32)
rotated = norm @ R.T
if mesh_scale.ndim == 0:
scale_per_voxel = mesh_scale.expand(coords.shape[0])
else:
scale_per_voxel = mesh_scale.to(coords.device)[batch_ids]
world = rotated / scale_per_voxel.unsqueeze(-1) / 2.0
return world, batch_ids
def _dense_grid_proj_world(resolution: int, mesh_scale: torch.Tensor,
batch_size: int, device, dtype=torch.float32) -> torch.Tensor:
one = torch.linspace(-1.0, 1.0, resolution, device=device, dtype=dtype)
x, y, z = torch.meshgrid(one, one, one, indexing="ij")
grid = torch.stack([x, y, z], dim=-1).reshape(-1, 3)
R_rot = _PROJ_GRID_ROTATION.to(device=device, dtype=dtype)
grid = grid @ R_rot.T
grid = grid.unsqueeze(0).expand(batch_size, -1, -1).clone()
if mesh_scale.ndim == 0:
mesh_scale = mesh_scale.expand(batch_size)
grid = grid / mesh_scale.to(device=device, dtype=dtype).view(-1, 1, 1) / 2.0
return grid
def _back_project_to_tokens(
coords_world: torch.Tensor,
feature_map: torch.Tensor,
transform_matrix: torch.Tensor,
camera_angle_x: torch.Tensor,
image_resolution: int,
batch_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if coords_world.dim() == 2:
assert batch_ids is not None
B = transform_matrix.shape[0]
out = torch.zeros((coords_world.shape[0], feature_map.shape[1]),
device=feature_map.device, dtype=feature_map.dtype)
for b in range(B):
mask = batch_ids == b
if not mask.any():
continue
p = coords_world[mask].unsqueeze(0)
uv, depth, valid = _project_points_to_image(
p, transform_matrix[b:b+1], camera_angle_x[b:b+1], image_resolution)
uv_ndc = (uv + 0.5) / image_resolution * 2.0 - 1.0
# padding_mode='border' is load-bearing: masking out-of-frame voxels confuses
# the SS DiT (~half the voxels go to zero, producing low poly + rotation drift).
sampled = _sample_features(feature_map[b:b+1], uv_ndc)
sampled = sampled.squeeze(0).transpose(0, 1)
out[mask] = sampled
return out
else:
uv, depth, valid = _project_points_to_image(
coords_world, transform_matrix, camera_angle_x, image_resolution)
uv_ndc = (uv + 0.5) / image_resolution * 2.0 - 1.0
sampled = _sample_features(feature_map, uv_ndc)
out = sampled.transpose(1, 2)
return out
def _pack_per_voxel_scalar(proj_pack: Optional[dict], key: str, eval_batch: int, device) -> torch.Tensor:
if proj_pack is None or key not in proj_pack:
return torch.ones((eval_batch,), device=device, dtype=torch.float32)
t = proj_pack[key].to(device=device, dtype=torch.float32)
if t.ndim == 0:
return t.expand(eval_batch).clone()
return _expand_pack(t, eval_batch)
def _expand_pack(t: torch.Tensor, eval_batch: int) -> torch.Tensor:
if eval_batch == t.shape[0]:
return t
if eval_batch % t.shape[0] != 0:
raise ValueError(f"eval batch {eval_batch} is not a multiple of pack batch {t.shape[0]}")
return t.repeat((eval_batch // t.shape[0],) + (1,) * (t.ndim - 1))
def _select_stage_entry(proj_pack: dict, stage: Optional[str]):
"""Returns (feature_map_lr, feature_map_hr_or_None, image_resolution)."""
stages = proj_pack.get("stages")
if stages is not None and stage is not None and stage in stages:
entry = stages[stage]
return entry["feature_map"], entry.get("feature_map_hr"), int(entry.get("image_resolution", 1024))
if "feature_map" in proj_pack:
return proj_pack["feature_map"], proj_pack.get("feature_map_hr"), int(proj_pack.get("image_resolution", 1024))
raise ValueError(f"proj_feat_pack has no usable feature_map (stage={stage!r})")
def _build_proj_cond(global_cond: torch.Tensor, image_attn_mode: str, proj_pack: Optional[dict],
coords_world: torch.Tensor, batch_ids: Optional[torch.Tensor] = None,
eval_batch: Optional[int] = None,
proj_in_channels: Optional[int] = None,
stage: Optional[str] = None,
cond_or_uncond: Optional[list] = None):
if image_attn_mode == "global":
return global_cond
if proj_pack is None:
raise ValueError(f"image_attn_mode={image_attn_mode!r} but proj_feat_pack is missing")
device = coords_world.device
T = proj_pack["transform_matrix"].to(device)
cam_angle = proj_pack["camera_angle_x"].to(device)
feat_map_lr, feat_map_hr, image_resolution = _select_stage_entry(proj_pack, stage)
feat_map_lr = feat_map_lr.to(device)
if feat_map_hr is not None:
feat_map_hr = feat_map_hr.to(device)
if eval_batch is not None:
T = _expand_pack(T, eval_batch)
cam_angle = _expand_pack(cam_angle, eval_batch) if cam_angle.ndim >= 1 else cam_angle
feat_map_lr = _expand_pack(feat_map_lr, eval_batch)
if feat_map_hr is not None:
feat_map_hr = _expand_pack(feat_map_hr, eval_batch)
# Channel-count check against the trained proj_linear input. If HR is present, the
# block expects (LR_channels + HR_channels) since we concat the sampled features.
expected_channels = feat_map_lr.shape[1] + (feat_map_hr.shape[1] if feat_map_hr is not None else 0)
if proj_in_channels is not None and expected_channels != proj_in_channels:
hint = ""
if feat_map_hr is None and expected_channels < proj_in_channels:
hint = (" — feature_map_hr is missing for this stage. Connect a NAFModel "
"input to Pixal3DConditioning; the shape/texture stages of this "
"checkpoint need a NAF-upsampled HR feature map.")
raise ValueError(
f"proj_feat_pack[{stage!r}] has LR={feat_map_lr.shape[1]} "
f"+ HR={feat_map_hr.shape[1] if feat_map_hr is not None else 0} "
f"= {expected_channels} channels, sub-model expects {proj_in_channels}.{hint}"
)
proj_feats_lr = _back_project_to_tokens(coords_world, feat_map_lr, T, cam_angle,
image_resolution=image_resolution,
batch_ids=batch_ids)
if feat_map_hr is not None:
proj_feats_hr = _back_project_to_tokens(coords_world, feat_map_hr, T, cam_angle,
image_resolution=image_resolution,
batch_ids=batch_ids)
proj_feats = torch.cat([proj_feats_lr, proj_feats_hr], dim=-1)
else:
proj_feats = proj_feats_lr
# Mirror upstream's neg_cond by zeroing proj for any uncond batch slot.
if cond_or_uncond is not None and eval_batch is not None:
uncond_slots = [i for i, v in enumerate(cond_or_uncond) if v == 1]
if uncond_slots:
uncond_idx = torch.tensor(uncond_slots, device=proj_feats.device, dtype=torch.long)
if batch_ids is None:
proj_feats = proj_feats.clone()
proj_feats[uncond_idx] = 0
else:
neg_mask = torch.isin(batch_ids, uncond_idx).unsqueeze(-1).to(proj_feats.dtype)
proj_feats = proj_feats * (1.0 - neg_mask)
return {"global": global_cond, "proj": proj_feats}
class Trellis2(nn.Module):
def __init__(self, resolution,
in_channels = 32,
@ -754,6 +1046,12 @@ class Trellis2(nn.Module):
qk_rms_norm = True,
qk_rms_norm_cross = True,
init_txt_model=False, # for now
image_attn_mode_structure: str = "global",
proj_in_channels_structure: Optional[int] = None,
image_attn_mode_shape: str = "global",
proj_in_channels_shape: Optional[int] = None,
image_attn_mode_texture: str = "global",
proj_in_channels_texture: Optional[int] = None,
dtype=None, device=None, operations=None, **kwargs):
super().__init__()
@ -767,22 +1065,29 @@ class Trellis2(nn.Module):
"model_channels":model_channels, "num_heads":num_heads, "mlp_ratio": mlp_ratio, "share_mod": share_mod,
"qk_rms_norm": qk_rms_norm, "qk_rms_norm_cross": qk_rms_norm_cross, "device": device, "dtype": dtype, "operations": operations
}
self.image_attn_mode_structure = image_attn_mode_structure
self.image_attn_mode_shape = image_attn_mode_shape
self.image_attn_mode_texture = image_attn_mode_texture
shape_proj_kwargs = {"image_attn_mode": image_attn_mode_shape, "proj_in_channels": proj_in_channels_shape}
tex_proj_kwargs = {"image_attn_mode": image_attn_mode_texture, "proj_in_channels": proj_in_channels_texture}
struct_proj_kwargs = {"image_attn_mode": image_attn_mode_structure, "proj_in_channels": proj_in_channels_structure}
txt_only = kwargs.get("txt_only", False)
if not txt_only:
self.img2shape = SLatFlowModel(resolution=resolution, in_channels=in_channels, **args)
self.img2shape = SLatFlowModel(resolution=resolution, in_channels=in_channels, **shape_proj_kwargs, **args)
self.shape2txt = None
if init_txt_model:
self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args)
self.img2shape_512 = SLatFlowModel(resolution=32, in_channels=in_channels, **args)
self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **tex_proj_kwargs, **args)
self.img2shape_512 = SLatFlowModel(resolution=32, in_channels=in_channels, **shape_proj_kwargs, **args)
args.pop("out_channels")
self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **args)
self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **struct_proj_kwargs, **args)
else:
self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args)
self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **tex_proj_kwargs, **args)
self.guidance_interval = [0.6, 1.0]
self.guidance_interval_txt = [0.6, 0.9]
def forward(self, x, timestep, context, **kwargs):
transformer_options = kwargs.get("transformer_options", {})
cond_or_uncond = transformer_options.get("cond_or_uncond")
model_options = {}
if hasattr(self, "meta"):
model_options = self.meta
@ -795,6 +1100,8 @@ class Trellis2(nn.Module):
coords = model_options.get("coords", None)
coord_counts = model_options.get("coord_counts", None)
mode = model_options.get("generation_mode", "structure_generation")
proj_feat_pack = model_options.get("proj_feat_pack", None)
coord_resolution = model_options.get("coord_resolution", None)
is_512_run = False
if mode == "shape_generation_512":
@ -884,6 +1191,20 @@ class Trellis2(nn.Module):
x_st = SparseTensor(feats=feats_flat, coords=batched_coords.to(torch.int32))
if mode == "shape_generation":
shape_attn = self.image_attn_mode_shape
if shape_attn != "global":
if coord_resolution is None:
raise ValueError("Pixal3D shape_generation requires coord_resolution in model_options; "
"EmptyTrellis2ShapeLatent should set it from the input voxel.")
mesh_scale = _pack_per_voxel_scalar(proj_feat_pack, "mesh_scale", B, batched_coords.device)
xyz_world, batch_ids = _coords_to_proj_world(batched_coords, coord_resolution, mesh_scale)
sub_model = self.img2shape_512 if is_512_run else self.img2shape
stage_name = "shape_512" if is_512_run else "shape_1024"
c_eval = _build_proj_cond(c_eval, shape_attn, proj_feat_pack, xyz_world, batch_ids,
eval_batch=B,
proj_in_channels=sub_model.proj_in_channels,
stage=stage_name,
cond_or_uncond=cond_or_uncond)
if is_512_run:
out = self.img2shape_512(x_st, t_eval, c_eval)
else:
@ -904,18 +1225,49 @@ class Trellis2(nn.Module):
slat_feats = slat_feats[:N].repeat(B, 1)
x_st = x_st.replace(feats=torch.cat([x_st.feats, slat_feats.to(x_st.feats.device)], dim=-1))
tex_attn = self.image_attn_mode_texture
if tex_attn != "global":
if coord_resolution is None:
raise ValueError("Pixal3D texture_generation requires coord_resolution in model_options; "
"EmptyTrellis2LatentTexture should set it from the input voxel.")
mesh_scale = _pack_per_voxel_scalar(proj_feat_pack, "mesh_scale", B, batched_coords.device)
xyz_world, batch_ids = _coords_to_proj_world(batched_coords, coord_resolution, mesh_scale)
c_eval = _build_proj_cond(c_eval, tex_attn, proj_feat_pack, xyz_world, batch_ids,
eval_batch=B,
proj_in_channels=self.shape2txt.proj_in_channels,
stage="tex_1024",
cond_or_uncond=cond_or_uncond)
out = self.shape2txt(x_st, t_eval, c_eval)
else: # structure
orig_bsz = x.shape[0]
struct_attn = self.image_attn_mode_structure
if shape_rule and orig_bsz > 1:
half = orig_bsz // 2
x_eval = x[half:]
t_eval = timestep[half:] if timestep.shape[0] > 1 else timestep
out = self.structure_model(x_eval, t_eval, cond)
struct_cond = cond
if struct_attn != "global":
mesh_scale = _pack_per_voxel_scalar(proj_feat_pack, "mesh_scale", half, x.device)
grid_xyz = _dense_grid_proj_world(16, mesh_scale, half, device=x.device)
struct_cond = _build_proj_cond(cond, struct_attn, proj_feat_pack, grid_xyz,
eval_batch=half,
proj_in_channels=self.structure_model.proj_in_channels,
stage="ss",
cond_or_uncond=cond_or_uncond)
out = self.structure_model(x_eval, t_eval, struct_cond)
out = out.repeat(2, 1, 1, 1, 1)
else:
out = self.structure_model(x, timestep, context)
struct_cond = context
if struct_attn != "global":
mesh_scale = _pack_per_voxel_scalar(proj_feat_pack, "mesh_scale", orig_bsz, x.device)
grid_xyz = _dense_grid_proj_world(16, mesh_scale, orig_bsz, device=x.device)
struct_cond = _build_proj_cond(context, struct_attn, proj_feat_pack, grid_xyz,
eval_batch=orig_bsz,
proj_in_channels=self.structure_model.proj_in_channels,
stage="ss",
cond_or_uncond=cond_or_uncond)
out = self.structure_model(x, timestep, struct_cond)
if not_struct_mode:
if mask is not None:

View File

@ -0,0 +1,301 @@
"""NAF (Neighborhood Attention Filtering) feature upsampler.
Vendored from valeoai/NAF (Apache-2.0):
https://github.com/valeoai/NAF src/model/naf.py + src/layers/{convolutions,attentions,rope}.py
Used by Pixal3D's shape/texture conditioning to produce
the 2x-upsampled half of the 2048-channel proj feature map.
"""
import math
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
# Pure-torch neighborhood attention (replaces natten.na2d / na2d_qk + na2d_av).
def upsample_lr_slice(src_lr: torch.Tensor, lr_dh: int, lr_dw: int,
hr_h_range: Tuple[int, int], hr_w_range: Tuple[int, int]) -> torch.Tensor:
"""Slice a LR-layout tensor [B, h_lr, w_lr, n, C], permute to BCHW, and
nearest-exact upsample only the region covering [hr_h_range, hr_w_range].
Returns BCHW at hr_h_end-hr_h_start x hr_w_end-hr_w_start (no padding for
out-of-bounds regions)."""
B = src_lr.shape[0]
n = src_lr.shape[-2]
C = src_lr.shape[-1]
h_hr_start, h_hr_end = hr_h_range
w_hr_start, w_hr_end = hr_w_range
# LR positions covering [h_hr_start, h_hr_end). Nearest-exact maps HR p → p // D.
lr_h_start = h_hr_start // lr_dh
lr_h_end = (h_hr_end - 1) // lr_dh + 1
lr_w_start = w_hr_start // lr_dw
lr_w_end = (w_hr_end - 1) // lr_dw + 1
lr_slice = src_lr[:, lr_h_start:lr_h_end, lr_w_start:lr_w_end]
lh, lw = lr_slice.shape[1], lr_slice.shape[2]
lr_bcd = lr_slice.permute(0, 3, 4, 1, 2).reshape(B * n, C, lh, lw).contiguous()
up = F.interpolate(lr_bcd, scale_factor=(lr_dh, lr_dw), mode="nearest-exact")
offset_h = h_hr_start - lr_h_start * lr_dh
offset_w = w_hr_start - lr_w_start * lr_dw
return up[:, :, offset_h:offset_h + (h_hr_end - h_hr_start),
offset_w:offset_w + (w_hr_end - w_hr_start)]
def na2d_pure(
q: torch.Tensor, # [B, H, W, n_heads, d_qk] at HR.
k_lr: torch.Tensor, # [B, h_lr, w_lr, n_heads, d_qk] at LR
v_lr: torch.Tensor, # [B, h_lr, w_lr, n_heads, d_v] at LR
kernel_size: Tuple[int, int], # (Kh, Kw) attention window.
dilation: Tuple[int, int], # (Dh, Dw) stride within the unrolled K/V grid; also the LR→HR upsample factor.
scale: float, # 1 / sqrt(d_qk) scaling for the Q·K scores.
tile: int = 128, # Spatial tile size (output positions per tile)
v_chunk: int = 64 # Sub-divide d_v into chunks of this size when computing attn·V. None disables chunking.
) -> torch.Tensor: # [B, H, W, n_heads, d_v] attended features.
"""Neighborhood attention in pure torch via F.unfold + per-tile slicing.
K and V are passed at LR resolution and upsampled (nearest-exact) per-tile only
for the slice the unfold needs. Avoids the [B, n*d, H, W] HR allocations for K
(512 MB) and V (2 GB) at tex_1024 fp16. Spatial tiling bounds the per-tile
F.unfold blob; `v_chunk` further slices d_v so attn·V is computed in C-sized
chunks (attn is reused, computed once from Q/K).
"""
B, H, W, n, d_qk = q.shape
d_v = v_lr.shape[-1]
Kh, Kw = kernel_size
Dh, Dw = dilation
pad_h, pad_w = (Kh // 2) * Dh, (Kw // 2) * Dw
q_ = q.permute(0, 3, 4, 1, 2).contiguous() # [B, n, d_qk, H, W]
out = torch.empty((B, n, d_v, H, W), device=q.device, dtype=q.dtype)
th = min(tile, H) if tile else H
tw = min(tile, W) if tile else W
chunk = v_chunk if (v_chunk and v_chunk < d_v) else d_v
for h0 in range(0, H, th):
for w0 in range(0, W, tw):
h1, w1 = min(h0 + th, H), min(w0 + tw, W)
t_h, t_w = h1 - h0, w1 - w0
# Padded HR region the unfold needs (kernel span = (K-1)*D + 1).
h_src_start = max(0, h0 - pad_h)
h_src_end = min(H, h1 + pad_h)
w_src_start = max(0, w0 - pad_w)
w_src_end = min(W, w1 + pad_w)
pad_top = max(0, pad_h - h0)
pad_bot = max(0, (h1 + pad_h) - H)
pad_lft = max(0, pad_w - w0)
pad_rgt = max(0, (w1 + pad_w) - W)
# Upsample only the tile region from k_lr / v_lr.
k_tile = upsample_lr_slice(k_lr, Dh, Dw,
(h_src_start, h_src_end),
(w_src_start, w_src_end))
v_tile = upsample_lr_slice(v_lr, Dh, Dw,
(h_src_start, h_src_end),
(w_src_start, w_src_end))
if pad_top or pad_bot or pad_lft or pad_rgt:
k_tile = F.pad(k_tile, [pad_lft, pad_rgt, pad_top, pad_bot])
v_tile = F.pad(v_tile, [pad_lft, pad_rgt, pad_top, pad_bot])
# Q·K → attention weights (small: KK=81 per output position).
KK = Kh * Kw
k_w = F.unfold(k_tile, kernel_size=(Kh, Kw), dilation=(Dh, Dw), padding=0)
k_w = k_w.view(B, n, d_qk, KK, t_h * t_w).permute(0, 1, 4, 3, 2) # [B, n, t, KK, d_qk]
q_tile = q_[:, :, :, h0:h1, w0:w1].permute(0, 1, 3, 4, 2).reshape(B, n, t_h * t_w, 1, d_qk)
scores = torch.matmul(q_tile, k_w.transpose(-1, -2)) * scale
attn = scores.softmax(dim=-1)
del k_w, scores, q_tile, k_tile
# attn · V, chunked over d_v.
for c0 in range(0, d_v, chunk):
c1 = min(c0 + chunk, d_v)
v_w = F.unfold(v_tile[:, c0:c1], kernel_size=(Kh, Kw),dilation=(Dh, Dw), padding=0) # [B*n, (c1-c0)*KK, t]
v_w = v_w.view(B, n, c1 - c0, KK, t_h * t_w).permute(0, 1, 4, 3, 2)
out_chunk = torch.matmul(attn, v_w).squeeze(-2) # [B, n, t, c1-c0]
out_chunk = out_chunk.view(B, n, t_h, t_w, c1 - c0).permute(0, 1, 4, 2, 3)
out[:, :, c0:c1, h0:h1, w0:w1] = out_chunk
del v_w, out_chunk
del attn, v_tile
return out.permute(0, 3, 4, 1, 2).contiguous() # [B, H, W, n, d_v]
class CrossAttention(nn.Module):
"""Window-restricted cross-attention. No learnable parameters; the model's
capacity lives entirely in the ImageEncoder convs."""
def __init__(self, dim: int, num_heads: int, kernel_size: Tuple[int, int] = (9, 9)):
super().__init__()
assert dim % num_heads == 0, "dim must be divisible by num_heads"
self.num_heads = num_heads
self.kernel_size = kernel_size
self.scale = (dim // num_heads) ** -0.5
@staticmethod
def _split_heads_lr(x: torch.Tensor, num_heads: int) -> torch.Tensor:
"""[B, n*d, h, w] -> [B, h, w, n, d] at the input resolution (no upsample)."""
B, C, H, W = x.shape
return x.view(B, num_heads, C // num_heads, H, W).permute(0, 3, 4, 1, 2).contiguous()
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
# q is [B, C, Hq, Wq] at HR; k and v are at LR (Hk, Wk). We KEEP k and v at LR
# na2d_pure upsamples only the tile slice it needs.
hq, wq = q.shape[-2:]
hk, wk = k.shape[-2:]
dilation = (hq // hk, wq // wk)
B, C, _, _ = q.shape
q = q.view(B, self.num_heads, C // self.num_heads, hq, wq).permute(0, 3, 4, 1, 2).contiguous()
k_lr = self._split_heads_lr(k, self.num_heads).to(q.dtype)
v_lr = self._split_heads_lr(v, self.num_heads).to(q.dtype)
out = na2d_pure(q, k_lr, v_lr, self.kernel_size, dilation, self.scale)
# [B, H, W, n, d] -> [B, n*d, H, W]
return out.permute(0, 3, 4, 1, 2).contiguous().view(B, -1, hq, wq)
# RoPE positional embedding
def rope_rotate_half(x: torch.Tensor) -> torch.Tensor:
x1, x2 = x.chunk(2, dim=-1)
return torch.cat([-x2, x1], dim=-1)
class RoPE(nn.Module):
def __init__(self, embed_dim: int, num_heads: int, base: float = 100.0):
super().__init__()
assert embed_dim % (4 * num_heads) == 0
self.num_heads = num_heads
self.D_head = embed_dim // num_heads
self.base = base
self.register_buffer("periods", torch.empty(self.D_head // 4), persistent=True) # loaded from the checkpoint
self._cached_key = None
self._cached_cos_sin = None
def _cos_sin(self, H: int, W: int, dtype: torch.dtype):
"""cos/sin only depend on (H, W) and the output dtype (periods are fixed
once loaded from the checkpoint), so cache them saves the meshgrid /
angle / cos / sin / tile / flatten on every forward."""
key = (H, W, dtype)
if self._cached_key == key and self._cached_cos_sin is not None:
return self._cached_cos_sin
device = self.periods.device
coords_h = torch.arange(0.5, H, device=device, dtype=torch.float32) / H
coords_w = torch.arange(0.5, W, device=device, dtype=torch.float32) / W
coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1) # [H, W, 2]
coords = coords.flatten(0, 1) * 2.0 - 1.0 # [HW, 2]
angles = 2 * math.pi * coords[:, :, None] / self.periods.to(coords.dtype)[None, None, :] # [HW, 2, D//4]
angles = angles.flatten(1, 2).tile(2) # [HW, D]
cos = torch.cos(angles).to(dtype)
sin = torch.sin(angles).to(dtype)
self._cached_cos_sin = (cos, sin)
self._cached_key = key
return cos, sin
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: [B, n*D_head, H, W]
B, C, H, W = x.shape
n = self.num_heads
D = C // n
x = x.view(B, n, D, H, W).permute(0, 1, 3, 4, 2).reshape(B, n, H * W, D)
cos, sin = self._cos_sin(H, W, x.dtype)
x = (x * cos) + (rope_rotate_half(x) * sin)
x = x.view(B, n, H, W, D).permute(0, 1, 4, 2, 3).reshape(B, n * D, H, W)
return x
# Image encoder
class EncBlock(nn.Module):
def __init__(self, channels: int, kernel_size: int, num_groups: int = 8):
super().__init__()
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=channels)
self.conv1 = nn.Conv2d(channels, channels, kernel_size=kernel_size,
padding=kernel_size // 2, padding_mode="reflect", bias=True)
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=channels)
self.conv2 = nn.Conv2d(channels, channels, kernel_size=kernel_size,
padding=kernel_size // 2, padding_mode="reflect", bias=True)
self.activation_fn = nn.SiLU()
def forward(self, x):
x = self.norm1(x)
x = self.activation_fn(x)
x = self.conv1(x)
x = self.norm2(x)
x = self.activation_fn(x)
x = self.conv2(x)
return x # no skip connection
def _encoder(in_dim: int, hidden_dim: int, kernel_size: int = 1, ks_res: int = 1, num_layers: int = 2) -> nn.Sequential:
return nn.Sequential(
nn.Conv2d(in_dim, hidden_dim, kernel_size=kernel_size, padding=kernel_size // 2, padding_mode="reflect", bias=True),
*[EncBlock(hidden_dim, kernel_size=ks_res) for _ in range(num_layers)],
)
class ImageEncoder(nn.Module):
"""Two parallel conv stacks (1x1 + 3x3) producing dim/2 channels each, then concat,
spatial average-pool to target size, RoPE-embed positions."""
def __init__(self, in_channels: int = 3, out_channels: int = 256,
heads_rope: int = 4, rope_base: float = 100.0, img_layers: int = 2):
super().__init__()
half = out_channels // 2
self.encoder = _encoder(in_channels, half, kernel_size=1, ks_res=1, num_layers=img_layers)
self.sem_encoder = _encoder(in_channels, half, kernel_size=3, ks_res=3, num_layers=img_layers)
self.rope = RoPE(embed_dim=out_channels, num_heads=heads_rope, base=rope_base)
def forward(self, x: torch.Tensor, output_size: Tuple[int, int]) -> torch.Tensor:
# Avoid running the conv stacks on >4× the target resolution.
out_h, out_w = output_size
if x.shape[-2] > 4 * out_h or x.shape[-1] > 4 * out_w:
x = F.interpolate(x, size=(min(x.shape[-2], 4 * out_h),
min(x.shape[-1], 4 * out_w)),
mode="bilinear", align_corners=False)
x = torch.cat([self.encoder(x), self.sem_encoder(x)], dim=1)
x = F.adaptive_avg_pool2d(x, output_size=output_size)
x = self.rope(x)
return x
# Top-level NAF model.
class NAF(nn.Module):
"""NAF feature upsampler."""
def __init__(
self, dim: int = 256, # internal channel dimension of the ImageEncoder
heads_attn: int = 4, # attention heads in the windowed cross-attn
heads_rope: int = 4, # heads for RoPE position encoding (must divide dim)
kernel_size: int = 9, # square kernel for the neighborhood attention window
rope_base: float = 100.0, # base for RoPE frequency periods
img_layers: int = 2 # number of EncBlocks in each conv stack
):
super().__init__()
self.image_encoder = ImageEncoder(in_channels=3, out_channels=dim, heads_rope=heads_rope, rope_base=rope_base, img_layers=img_layers)
self.upsampler = CrossAttention(dim=dim, num_heads=heads_attn, kernel_size=(kernel_size, kernel_size))
def forward(
self,
image: torch.Tensor, # [B, 3, H_img, W_img] in [0, 1].
features: torch.Tensor, # [B, C, H_feat, W_feat] low-resolution features (any C).
output_size: Tuple[int, int] # (H_out, W_out) target spatial resolution for the upsampled features.
) -> torch.Tensor: # [B, C, H_out, W_out] upsampled features.
"""Upsample low-res feature map to output_size, guided by the image."""
q = self.image_encoder(image, output_size=output_size)
k = F.adaptive_avg_pool2d(q, output_size=features.shape[-2:])
return self.upsampler(q, k, features)
def build_naf_from_state_dict(state_dict: dict) -> NAF:
"""Instantiate NAF with the default hyperparams and load the given state_dict.
The published NAF release uses the default constructor (dim=256, heads_attn=4,
heads_rope=4, kernel_size=9, rope_base=100, img_layers=2)."""
model = NAF()
missing, unexpected = model.load_state_dict(state_dict, strict=False)
if unexpected:
raise ValueError(f"Unexpected keys in NAF state_dict: {sorted(unexpected)[:8]}...")
return model

View File

@ -75,13 +75,9 @@ def sparse_conv3d_forward(self, x):
class LayerNorm32(nn.LayerNorm):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x_dtype = x.dtype
x = x.to(torch.float32)
w = self.weight.to(torch.float32) if self.weight is not None else None
b = self.bias.to(torch.float32) if self.bias is not None else None
o = F.layer_norm(x, self.normalized_shape, w, b, self.eps)
return o.to(x_dtype)
w = self.weight.to(x.dtype) if self.weight is not None else None
b = self.bias.to(x.dtype) if self.bias is not None else None
return F.layer_norm(x, self.normalized_shape, w, b, self.eps)
class SparseConvNeXtBlock3d(nn.Module):
def __init__(
@ -204,7 +200,6 @@ class SparseResBlockC2S3d(nn.Module):
self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6)
self.conv1 = SparseConv3d(channels, self.out_channels * 8, 3)
self.conv2 = SparseConv3d(self.out_channels, self.out_channels, 3)
self.skip_connection = lambda x: x.replace(x.feats.repeat_interleave(out_channels // (channels // 8), dim=1))
if pred_subdiv:
self.to_subdiv = SparseLinear(channels, 8)
self.updown = SparseChannel2Spatial(2)
@ -215,15 +210,16 @@ class SparseResBlockC2S3d(nn.Module):
x = x.to(dtype)
subdiv = self.to_subdiv(x)
h = x.replace(self.norm1(x.feats))
h = h.replace(F.silu(h.feats))
h = h.replace(F.silu(h.feats, inplace=True))
h = self.conv1(h)
subdiv_binarized = subdiv.replace(subdiv.feats > 0) if subdiv is not None else None
h = self.updown(h, subdiv_binarized)
x = self.updown(x, subdiv_binarized)
h = h.replace(self.norm2(h.feats))
h = h.replace(F.silu(h.feats))
h = h.replace(F.silu(h.feats, inplace=True))
h = self.conv2(h)
h = h + self.skip_connection(x)
skip_repeat = self.out_channels // (self.channels // 8)
h.feats.view(h.feats.shape[0], x.feats.shape[1], skip_repeat).add_(x.feats.unsqueeze(-1))
if self.pred_subdiv:
return h, subdiv
else:
@ -1211,13 +1207,12 @@ def flexible_dual_grid_to_mesh(
edge_neighbor_voxel = coords.reshape(N, 1, 1, 3) + flexible_dual_grid_to_mesh.edge_neighbor_voxel_offset # (N, 3, 4, 3)
connected_voxel = edge_neighbor_voxel[intersected_flag] # (M, 4, 3)
M = connected_voxel.shape[0]
# flatten connected voxel coords and lookup
conn_flat_b = torch.zeros((M * 4,), dtype=torch.long, device=coords.device)
conn_x = connected_voxel.reshape(-1, 3)[:, 0].to(torch.int32)
conn_y = connected_voxel.reshape(-1, 3)[:, 1].to(torch.int32)
conn_z = connected_voxel.reshape(-1, 3)[:, 2].to(torch.int32)
# flatten connected voxel coords and lookup. In-place to avoid extra memory allocation.
W, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item())
conn_flat = conn_flat_b * (W * H * D) + conn_x * (H * D) + conn_y * D + conn_z
cv = connected_voxel.reshape(-1, 3)
conn_flat = cv[:, 0].long() * (H * D)
conn_flat.add_(cv[:, 1].long() * D)
conn_flat.add_(cv[:, 2].long())
conn_indices = torch_hashmap.lookup_flat(conn_flat).reshape(M, 4).int()
connected_voxel_valid = (conn_indices != 0xffffffff).all(dim=1)

View File

@ -113,13 +113,21 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
unet_config['block_repeat'] = [[1, 1, 1, 1], [2, 2, 2, 2]]
return unet_config
if '{}img2shape.blocks.1.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys:
def _detect_proj(sub_prefix: str, name: str):
key = '{}{}.blocks.0.cross_attn.proj_linear.weight'.format(key_prefix, sub_prefix)
if key in state_dict_keys:
unet_config["image_attn_mode_{}".format(name)] = "proj"
unet_config["proj_in_channels_{}".format(name)] = int(state_dict[key].shape[1])
if '{}img2shape.blocks.0.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys or \
'{}img2shape.blocks.0.cross_attn.cross_attn_block.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys:
unet_config = {}
unet_config["image_model"] = "trellis2"
unet_config["init_txt_model"] = False
if '{}shape2txt.blocks.29.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys:
unet_config["init_txt_model"] = True
unet_config["init_txt_model"] = (
'{}shape2txt.blocks.29.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys or
'{}shape2txt.blocks.29.cross_attn.cross_attn_block.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys
)
unet_config["resolution"] = 64
if metadata is not None:
@ -127,14 +135,20 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
unet_config["resolution"] = 32
unet_config["num_heads"] = 12
_detect_proj("img2shape", "shape")
_detect_proj("shape2txt", "texture")
_detect_proj("structure_model", "structure")
return unet_config
if '{}shape2txt.blocks.29.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys: # trellis2 texture
if '{}shape2txt.blocks.29.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys or \
'{}shape2txt.blocks.29.cross_attn.cross_attn_block.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys: # trellis2 texture
unet_config = {}
unet_config["image_model"] = "trellis2"
unet_config["resolution"] = 64
unet_config["num_heads"] = 12
unet_config["txt_only"] = True
_detect_proj("shape2txt", "texture")
return unet_config
if '{}transformer.rotary_pos_emb.inv_freq'.format(key_prefix) in state_dict_keys: #stable audio dit

View File

@ -1325,6 +1325,7 @@ class Trellis2(supported_models_base.BASE):
sampling_settings = {
"shift": 3.0,
"multiplier": 1.0
}
memory_usage_factor = 3.5

View File

@ -276,13 +276,30 @@ class RescaleCFG:
CATEGORY = "advanced/model"
def patch(self, model, multiplier):
model_sampling = model.get_model_object("model_sampling")
is_x0_space = not isinstance(model_sampling, comfy.model_sampling.EPS)
def rescale_cfg(args):
x_orig = args["input"]
cond_scale = args["cond_scale"]
if is_x0_space:
# Flow-matching / X0 models: cond_denoised/uncond_denoised are x_0 estimates,
# so the eps↔v conversion below would be wrong. Rescale directly in x_0 space.
x_0_cond = args["cond_denoised"]
x_0_uncond = args["uncond_denoised"]
x_0_cfg = x_0_uncond + cond_scale * (x_0_cond - x_0_uncond)
dims = tuple(range(1, x_0_cond.ndim))
ro_pos = x_0_cond.std(dim=dims, keepdim=True)
ro_cfg = x_0_cfg.std(dim=dims, keepdim=True).clamp(min=1e-8)
x_0_rescaled = x_0_cfg * (ro_pos / ro_cfg)
x_0_final = multiplier * x_0_rescaled + (1.0 - multiplier) * x_0_cfg
return x_orig - x_0_final
cond = args["cond"]
uncond = args["uncond"]
cond_scale = args["cond_scale"]
sigma = args["sigma"]
sigma = sigma.view(sigma.shape[:1] + (1,) * (cond.ndim - 1))
x_orig = args["input"]
#rescale cfg has to be done on v-pred model output
x = x_orig / (sigma * sigma + 1.0)

View File

@ -1,13 +1,60 @@
from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO, Types, io
from comfy.ldm.trellis2.vae import SparseTensor
from comfy.ldm.trellis2.model import _build_proj_transform_matrix, _project_points_to_image
from comfy.ldm.trellis2.naf.model import build_naf_from_state_dict
from comfy_extras.nodes_mesh_postprocess import pack_variable_mesh_batch
import comfy.model_management
import comfy.utils
import folder_paths
from PIL import Image
import logging
import numpy as np
import math
import torch
ShapeSubdivides = io.Custom("SHAPE_SUBDIVIDES")
Pixal3DProjPack = io.Custom("PIXAL3D_PROJ_PACK")
NAFModel = io.Custom("NAF_MODEL")
# Pixal3D trains in a 90°-X-rotated grid frame (F_p). We un-rotate decoder outputs for
# user-facing previews/meshes, then re-rotate before feeding coords back to the shape DiT.
def _pixal3d_unrotate_voxel_data(data: torch.Tensor) -> torch.Tensor:
if data.ndim == 4:
return data.flip(-1).permute(0, 1, 3, 2).contiguous()
if data.ndim == 5:
return data.flip(-1).permute(0, 1, 2, 4, 3).contiguous()
raise ValueError(f"unexpected voxel shape {tuple(data.shape)}")
def _pixal3d_rerotate_voxel_data(data: torch.Tensor) -> torch.Tensor:
if data.ndim == 4:
return data.permute(0, 1, 3, 2).flip(-1).contiguous()
if data.ndim == 5:
return data.permute(0, 1, 2, 4, 3).flip(-1).contiguous()
raise ValueError(f"unexpected voxel shape {tuple(data.shape)}")
def _pixal3d_unrotate_vertices(vertices: torch.Tensor) -> torch.Tensor:
if vertices.numel() == 0:
return vertices
x, y, z = vertices.unbind(-1)
return torch.stack([-x, y, -z], dim=-1).contiguous()
def _pixal3d_unrotate_sparse_coords(coords: torch.Tensor, resolution: int) -> torch.Tensor:
if coords.numel() == 0:
return coords
R1 = resolution - 1
if coords.shape[-1] == 4:
b, i, j, k = coords.unbind(-1)
return torch.stack([b, R1 - i, j, R1 - k], dim=-1).contiguous()
if coords.shape[-1] == 3:
i, j, k = coords.unbind(-1)
return torch.stack([R1 - i, j, R1 - k], dim=-1).contiguous()
raise ValueError(f"unexpected coord shape {tuple(coords.shape)}")
def prepare_trellis_vae_for_decode(vae, sample_shape):
memory_required = vae.memory_used_decode(sample_shape, vae.vae_dtype)
@ -163,6 +210,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
prepare_trellis_vae_for_decode(vae, sample_tensor.shape)
trellis_vae = vae.first_stage_model
coord_counts = samples.get("coord_counts")
pixal3d_mode = samples.get("model_options", {}).get("proj_feat_pack") is not None
samples = samples["samples"]
if coord_counts is None:
@ -188,6 +236,10 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
coords_list = [stage_tensor.coords for stage_tensor in stage_tensors]
subs.append(SparseTensor.from_tensor_list(feats_list, coords_list))
if pixal3d_mode:
for m in mesh:
m.vertices = _pixal3d_unrotate_vertices(m.vertices)
face_list = [m.faces for m in mesh]
vert_list = [m.vertices for m in mesh]
if all(v.shape == vert_list[0].shape for v in vert_list) and all(f.shape == face_list[0].shape for f in face_list):
@ -224,6 +276,7 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
prepare_trellis_vae_for_decode(vae, sample_tensor.shape)
trellis_vae = vae.first_stage_model
coord_counts = samples.get("coord_counts")
pixal3d_mode = samples.get("model_options", {}).get("proj_feat_pack") is not None
samples = samples["samples"]
samples, coords = flatten_batched_sparse_latent(samples, coords, coord_counts)
@ -237,7 +290,17 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
color_feats = voxel.feats[:, :3]
voxel_coords = voxel.coords#[:, 1:]
voxel = Types.VOXEL(voxel_coords, color_feats, 1024)
if voxel_coords.numel() > 0 and voxel_coords.shape[-1] >= 3:
spatial = voxel_coords[:, -3:] if voxel_coords.shape[-1] == 4 else voxel_coords
max_idx = int(spatial.max().item()) + 1
tex_resolution = next((r for r in (256, 512, 1024, 1536, 2048) if r >= max_idx), max_idx)
else:
tex_resolution = 1024
if pixal3d_mode:
voxel_coords = _pixal3d_unrotate_sparse_coords(voxel_coords, resolution=tex_resolution)
voxel = Types.VOXEL(voxel_coords, color_feats, tex_resolution)
return IO.NodeOutput(voxel)
class VaeDecodeStructureTrellis2(IO.ComfyNode):
@ -274,7 +337,10 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode):
if current_res != resolution:
ratio = current_res // resolution
decoded = torch.nn.functional.max_pool3d(decoded.float(), ratio, ratio, 0) > 0.5
out = Types.VOXEL(decoded.squeeze(1).float())
voxel_data = decoded.squeeze(1).float()
if samples.get("model_options", {}).get("proj_feat_pack") is not None:
voxel_data = _pixal3d_unrotate_voxel_data(voxel_data)
out = Types.VOXEL(voxel_data)
return IO.NodeOutput(out)
class Trellis2UpsampleCascade(IO.ComfyNode):
@ -540,7 +606,6 @@ class Trellis2Conditioning(IO.ComfyNode):
cropped_rgba = rgba_pil.crop((crop_x1, crop_y1, crop_x2, crop_y2))
cropped_np = np.array(cropped_rgba).astype(np.float32) / 255.0
else:
import logging
logging.warning("Mask for the image is empty. Trellis2 requires an image with a mask for the best mesh quality.")
cropped_np = rgba_np.astype(np.float32) / 255.0
@ -587,7 +652,12 @@ class EmptyTrellis2ShapeLatent(IO.ComfyNode):
"Shape structure input. Accepts either a voxel structure "
"or upsampled voxel coordinates from a previous cascade stage."
)
)
),
Pixal3DProjPack.Input(
"proj_feat_pack",
optional=True,
tooltip="Pixal3D pixel-aligned projection pack from Pixal3DConditioning. Leave empty for vanilla Trellis2.",
),
],
outputs=[
IO.Latent.Output(),
@ -595,21 +665,26 @@ class EmptyTrellis2ShapeLatent(IO.ComfyNode):
)
@classmethod
def execute(cls, voxel):
# to accept the upscaled coords
def execute(cls, voxel, proj_feat_pack=None):
is_512_pass = False
coord_resolution = None
upsampled = hasattr(voxel, "upsampled")
if upsampled:
if hasattr(voxel, "resolutions") and voxel.resolutions is not None:
coord_resolution = int(voxel.resolutions[0].item()) // 16
voxel = voxel.data
if not upsampled:
decoded = voxel.data.unsqueeze(1)
voxel_data = voxel.data
if proj_feat_pack is not None:
voxel_data = _pixal3d_rerotate_voxel_data(voxel_data)
decoded = voxel_data.unsqueeze(1)
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
is_512_pass = True
coord_resolution = int(decoded.shape[-1])
else:
coords = voxel.int()
is_512_pass = False
batch_size, counts, max_tokens = infer_batched_coord_layout(coords)
in_channels = 32
@ -620,8 +695,13 @@ class EmptyTrellis2ShapeLatent(IO.ComfyNode):
generation_mode = "shape_generation_512"
else:
generation_mode = "shape_generation"
model_options = {"generation_mode": generation_mode, "coords": coords, "coord_counts": counts}
if coord_resolution is not None:
model_options["coord_resolution"] = coord_resolution
if proj_feat_pack is not None:
model_options["proj_feat_pack"] = proj_feat_pack
return IO.NodeOutput({"samples": latent, "coords": coords, "coord_counts": counts, "type": "trellis2",
"model_options": {"generation_mode": generation_mode, "coords": coords, "coord_counts": counts}})
"model_options": model_options})
class EmptyTrellis2LatentTexture(IO.ComfyNode):
@classmethod
@ -638,6 +718,11 @@ class EmptyTrellis2LatentTexture(IO.ComfyNode):
)
),
IO.Latent.Input("shape_latent"),
Pixal3DProjPack.Input(
"proj_feat_pack",
optional=True,
tooltip="Pixal3D pixel-aligned projection pack from Pixal3DConditioning. Leave empty for vanilla Trellis2.",
),
],
outputs=[
IO.Latent.Output(),
@ -645,15 +730,22 @@ class EmptyTrellis2LatentTexture(IO.ComfyNode):
)
@classmethod
def execute(cls, voxel, shape_latent):
def execute(cls, voxel, shape_latent, proj_feat_pack=None):
channels = 32
coord_resolution = None
upsampled = hasattr(voxel, "upsampled")
if upsampled:
if hasattr(voxel, "resolutions") and voxel.resolutions is not None:
coord_resolution = int(voxel.resolutions[0].item()) // 16
voxel = voxel.data
if not upsampled:
decoded = voxel.data.unsqueeze(1)
voxel_data = voxel.data
if proj_feat_pack is not None:
voxel_data = _pixal3d_rerotate_voxel_data(voxel_data)
decoded = voxel_data.unsqueeze(1)
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
coord_resolution = int(decoded.shape[-1])
else:
coords = voxel.int()
@ -664,9 +756,13 @@ class EmptyTrellis2LatentTexture(IO.ComfyNode):
shape_latent = shape_latent.squeeze(-1).transpose(1, 2).reshape(-1, channels)
latent = torch.zeros(batch_size, channels, max_tokens, 1)
model_options = {"generation_mode": "texture_generation", "coords": coords, "coord_counts": counts, "shape_slat": shape_latent}
if coord_resolution is not None:
model_options["coord_resolution"] = coord_resolution
if proj_feat_pack is not None:
model_options["proj_feat_pack"] = proj_feat_pack
return IO.NodeOutput({"samples": latent, "type": "trellis2", "coords": coords, "coord_counts": counts,
"model_options": {"generation_mode": "texture_generation",
"coords": coords, "coord_counts": counts, "shape_slat": shape_latent}})
"model_options": model_options})
class EmptyTrellis2LatentStructure(IO.ComfyNode):
@ -677,27 +773,441 @@ class EmptyTrellis2LatentStructure(IO.ComfyNode):
category="latent/3d",
inputs=[
IO.Int.Input("batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."),
Pixal3DProjPack.Input(
"proj_feat_pack",
optional=True,
tooltip="Pixal3D pixel-aligned projection pack. Leave empty for vanilla Trellis2.",
),
],
outputs=[
IO.Latent.Output(),
]
)
@classmethod
def execute(cls, batch_size):
in_channels = 8
def execute(cls, batch_size, proj_feat_pack=None):
# Trellis2.forward slices x[:, :8] and pads out to 32; KSampler residual math
# needs the empty latent to match latent_format (32-channel).
in_channels = 32
resolution = 16
latent = torch.zeros(batch_size, in_channels, resolution, resolution, resolution)
output = {
"samples": latent,
"type": "trellis2",
}
if proj_feat_pack is not None:
output["model_options"] = {"proj_feat_pack": proj_feat_pack}
return IO.NodeOutput(output)
def _dinov3_patches_to_2d(tokens, image_size, patch_size=16):
h_p = w_p = image_size // patch_size
n_patches = h_p * w_p
n_reg = tokens.shape[1] - 1 - n_patches
if n_reg < 0 or tokens.shape[1] != 1 + n_reg + n_patches:
raise ValueError(
f"_dinov3_patches_to_2d: got {tokens.shape[1]} tokens, expected "
f"1 (CLS) + N_reg + {h_p}*{w_p}={n_patches} patches at image_size={image_size}, "
f"patch_size={patch_size}. Inferred N_reg={n_reg} which is invalid."
)
start = 1 + n_reg
patches = tokens[:, start:start + n_patches]
return patches.transpose(1, 2).reshape(tokens.shape[0], -1, h_p, w_p).contiguous()
def _fov_from_moge_intrinsics(moge_intrinsics: torch.Tensor) -> float:
fx = moge_intrinsics[..., 0, 0].float()
fov = 2.0 * torch.atan(0.5 / fx.clamp(min=1e-4))
return float(fov.mean().item())
def _run_dinov3_with_patches(model, cropped_pil, image_size):
# Pixal3D's cross-attn was trained against CLS + registers only (~5 tokens), not the
# full patch grid. The patch grid goes to the proj branch via patches_2d.
model_internal = model.model
torch_device = comfy.model_management.get_torch_device()
resized = cropped_pil.resize((image_size, image_size), Image.Resampling.LANCZOS)
img_np = np.array(resized).astype(np.float32) / 255.0
img_t = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(torch_device)
img_t = (img_t - dino_mean.to(torch_device)) / dino_std.to(torch_device)
model_internal.image_size = image_size
tokens = model_internal(img_t, skip_norm_elementwise=True)[0]
patches = _dinov3_patches_to_2d(tokens, image_size)
h_p = w_p = image_size // 16
n_reg = tokens.shape[1] - 1 - h_p * w_p
global_tokens = tokens[:, :1 + n_reg]
return {"tokens": global_tokens, "patches_2d": patches}
def _crop_image_with_mask(item_image, item_mask, max_image_size=1024):
img_np = (item_image.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
mask_np = (item_mask.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
pil_img = Image.fromarray(img_np)
pil_mask = Image.fromarray(mask_np)
max_size = max(pil_img.size)
scale = min(1.0, max_image_size / max_size)
if scale < 1.0:
new_w, new_h = int(pil_img.width * scale), int(pil_img.height * scale)
pil_img = pil_img.resize((new_w, new_h), Image.Resampling.LANCZOS)
pil_mask = pil_mask.resize((new_w, new_h), Image.Resampling.NEAREST)
scene_size = (pil_img.width, pil_img.height)
rgba_np = np.zeros((pil_img.height, pil_img.width, 4), dtype=np.uint8)
rgba_np[:, :, :3] = np.array(pil_img)
rgba_np[:, :, 3] = np.array(pil_mask)
alpha = rgba_np[:, :, 3]
bbox_coords = np.argwhere(alpha > 0.8 * 255)
if len(bbox_coords) > 0:
y_min, x_min = np.min(bbox_coords[:, 0]), np.min(bbox_coords[:, 1])
y_max, x_max = np.max(bbox_coords[:, 0]), np.max(bbox_coords[:, 1])
center_y, center_x = (y_min + y_max) / 2.0, (x_min + x_max) / 2.0
# Upstream pads the bbox by 10% — encoders were trained with that breathing room.
size = max(y_max - y_min, x_max - x_min)
size = int(size * 1.1)
half = size // 2
crop_x1 = int(center_x - half)
crop_y1 = int(center_y - half)
crop_x2 = crop_x1 + 2 * half
crop_y2 = crop_y1 + 2 * half
crop_bbox = (crop_x1, crop_y1, crop_x2, crop_y2)
rgba_pil = Image.fromarray(rgba_np)
cropped_rgba = rgba_pil.crop(crop_bbox)
cropped_np = np.array(cropped_rgba).astype(np.float32) / 255.0
else:
logging.warning("Mask for the image is empty. Pixal3D requires a clean foreground mask.")
cropped_np = rgba_np.astype(np.float32) / 255.0
crop_bbox = (0, 0, scene_size[0], scene_size[1])
fg = cropped_np[:, :, :3]
alpha_float = cropped_np[:, :, 3:4]
composite_np = fg * alpha_float
composite_uint8 = (composite_np * 255.0).round().clip(0, 255).astype(np.uint8)
return Image.fromarray(composite_uint8), crop_bbox, scene_size
class Pixal3DConditioning(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="Pixal3DConditioning",
category="conditioning/video_models",
inputs=[
IO.ClipVision.Input("clip_vision_model", tooltip="DINOv3 ViT-L/16 ClipVision."),
IO.Image.Input("image"),
IO.Mask.Input("mask"),
IO.Float.Input(
"camera_angle_x", default=0.2, min=0.0175, max=2.9671, step=0.001,
tooltip="Horizontal FOV in radians (upstream demo default 0.2). "
"Overridden by moge_geometry if connected.",
),
IO.Float.Input(
"mesh_scale", default=1.0, min=0.1, max=4.0, step=0.01,
tooltip="Mesh scale; 1.0 means unit cube.",
),
IO.Float.Input(
"distance_override", default=0.0, min=0.0, max=10.0, step=0.001,
tooltip="Override camera distance directly. 0 = auto-derive from FOV.",
),
io.Custom("MOGE_GEOMETRY").Input(
"moge_geometry",
optional=True,
tooltip="If connected, camera_angle_x is recovered from MoGe.",
),
NAFModel.Input(
"naf_model",
optional=True,
tooltip="Optional NAF feature upsampler. Required for shape/texture stages "
"to match upstream's trained feature distribution.",
),
],
outputs=[
IO.Conditioning.Output(display_name="positive"),
IO.Conditioning.Output(display_name="negative"),
Pixal3DProjPack.Output(display_name="proj_feat_pack"),
],
)
@classmethod
def execute(cls, clip_vision_model, image, mask, camera_angle_x, mesh_scale,
distance_override=0.0,
moge_geometry=None, naf_model=None) -> IO.NodeOutput:
if image.ndim == 3:
image = image.unsqueeze(0)
if mask.ndim == 2:
mask = mask.unsqueeze(0)
batch_size = image.shape[0]
if mask.shape[0] == 1 and batch_size > 1:
mask = mask.expand(batch_size, -1, -1)
elif mask.shape[0] != batch_size:
raise ValueError(f"Pixal3DConditioning mask batch {mask.shape[0]} != image batch {batch_size}")
if moge_geometry is not None and "intrinsics" in moge_geometry:
camera_angle_x = _fov_from_moge_intrinsics(moge_geometry["intrinsics"])
device = comfy.model_management.intermediate_device()
cond_512_list, cond_1024_list = [], []
patches_512_list, patches_1024_list = [], []
cropped_pil_list = []
crop_bbox_list, scene_size_list = [], []
torch_device = comfy.model_management.get_torch_device()
for b in range(batch_size):
item_image = image[b]
item_mask = mask[b] if mask.size(0) > 1 else mask[0]
cropped_pil, crop_bbox, scene_size = _crop_image_with_mask(
item_image, item_mask, max_image_size=1024)
crop_bbox_list.append(crop_bbox)
scene_size_list.append(scene_size)
cropped_pil_list.append(cropped_pil)
cond_512 = _run_dinov3_with_patches(clip_vision_model, cropped_pil, 512)
cond_1024 = _run_dinov3_with_patches(clip_vision_model, cropped_pil, 1024)
cond_512_list.append(cond_512["tokens"].to(device))
cond_1024_list.append(cond_1024["tokens"].to(device))
patches_512_list.append(cond_512["patches_2d"].to(device))
patches_1024_list.append(cond_1024["patches_2d"].to(device))
global_512 = torch.cat(cond_512_list, dim=0)
global_1024 = torch.cat(cond_1024_list, dim=0)
fm_512_dino = torch.cat(patches_512_list, dim=0)
fm_1024_dino = torch.cat(patches_1024_list, dim=0)
# Upstream samples the LR DINO grid AND the NAF HR grid separately at projected
# 3D points, then cats sampled features along channels. Back-projection (in model.py)
# mirrors that — here we just stash LR + optional HR per stage.
# NAF targets per stage: shape_512=512, shape_1024=512, tex_1024=1024.
def _naf_hr(lr_feat, image_pil_list, image_size, naf_target):
if naf_model is None or naf_target is None:
return None
# Run NAF in the input feature dtype (typically fp16 since DINO/ClipVision
# loads that way). The previous .float() cast doubled NAF memory by forcing
# full fp32 — at tex_1024/target=1024 that's ~10 GB on its own. Model
# weights need to match input dtype since PyTorch conv ops error out on
# mixed fp16-input/fp32-weight.
target_dtype = lr_feat.dtype
if next(naf_model.parameters()).dtype != target_dtype:
naf_model.to(dtype=target_dtype)
imgs = torch.stack([
torch.from_numpy(
np.array(p.resize((image_size, image_size), Image.Resampling.LANCZOS))
.astype(np.float32) / 255.0
).permute(2, 0, 1)
for p in image_pil_list
], dim=0).to(torch_device).to(target_dtype)
hr = naf_model(imgs, lr_feat.to(torch_device).to(target_dtype), naf_target)
return hr.to(device)
hr_shape_512 = _naf_hr(fm_512_dino, cropped_pil_list, 512, (512, 512))
hr_shape_1024 = _naf_hr(fm_1024_dino, cropped_pil_list, 1024, (512, 512))
hr_tex_1024 = _naf_hr(fm_1024_dino, cropped_pil_list, 1024, (1024, 1024))
# distance_from_fov: grid_point (-1, 0, 0) projects to pixel (0, image_resolution-1).
camera_angle_x = float(camera_angle_x)
if distance_override > 0:
distance = float(distance_override)
else:
distance = 0.5 / math.tan(camera_angle_x / 2.0) / float(mesh_scale)
cam_angle_t = torch.tensor([camera_angle_x] * batch_size, device=device, dtype=torch.float32)
dist_t = torch.tensor([distance] * batch_size, device=device, dtype=torch.float32)
scale_t = torch.tensor([float(mesh_scale)] * batch_size, device=device, dtype=torch.float32)
T = _build_proj_transform_matrix(dist_t, batch_size, device=device, dtype=torch.float32)
proj_pack = {
"stages": {
"ss": {"feature_map": fm_512_dino, "feature_map_hr": None, "image_resolution": 512},
"shape_512": {"feature_map": fm_512_dino, "feature_map_hr": hr_shape_512, "image_resolution": 512},
"shape_1024": {"feature_map": fm_1024_dino, "feature_map_hr": hr_shape_1024,"image_resolution": 1024},
"tex_1024": {"feature_map": fm_1024_dino, "feature_map_hr": hr_tex_1024, "image_resolution": 1024},
},
"transform_matrix": T,
"camera_angle_x": cam_angle_t,
"mesh_scale": scale_t,
"distance": dist_t,
"patch_size": 16,
"crop_bboxes": crop_bbox_list,
"scene_sizes": scene_size_list,
}
# global_512 → SS/shape_512 cross-attn; global_1024 → shape_1024/tex_1024
# (Trellis2.forward swaps context↔embeds for non-structure HR stages).
neg_global = torch.zeros_like(global_512)
neg_embeds = torch.zeros_like(global_1024)
positive = [[global_512, {"embeds": global_1024}]]
negative = [[neg_global, {"embeds": neg_embeds}]]
return IO.NodeOutput(positive, negative, proj_pack)
def _project_vertices_to_image_uv(vertices_world, transform_matrix, camera_angle_x, image_resolution):
points = vertices_world.unsqueeze(0).float()
T = transform_matrix.unsqueeze(0).float() if transform_matrix.ndim == 2 else transform_matrix.float()
cam = camera_angle_x.unsqueeze(0) if camera_angle_x.ndim == 0 else camera_angle_x
uv_pix, depth, valid = _project_points_to_image(points, T, cam.float(), image_resolution)
uv = uv_pix.squeeze(0) / image_resolution
return uv, depth.squeeze(0), valid.squeeze(0)
def _crop_uv_to_scene_pixels(uv_crop, crop_bbox, scene_image_size):
crop_x1, crop_y1, crop_x2, crop_y2 = crop_bbox
crop_w = max(1, crop_x2 - crop_x1)
crop_h = max(1, crop_y2 - crop_y1)
px = uv_crop[:, 0] * crop_w + crop_x1
py = uv_crop[:, 1] * crop_h + crop_y1
W, H = scene_image_size
return torch.stack([px.clamp(0, W - 1), py.clamp(0, H - 1)], dim=-1)
class Pixal3DAlignObject(IO.ComfyNode):
"""Pixal3D paper §3.3 Global Alignment for a single object.
Solves (scale, translation) aligning the mesh to MoGe's per-pixel point map. Requires
MoGe to have been computed on the same resized scene image as Pixal3DConditioning."""
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="Pixal3DAlignObject",
category="latent/3d",
inputs=[
IO.Mesh.Input("mesh"),
Pixal3DProjPack.Input("proj_feat_pack", tooltip="The proj pack produced by Pixal3DConditioning for this object."),
io.Custom("MOGE_GEOMETRY").Input("moge_geometry", tooltip="MoGe geometry computed on the original scene image."),
IO.Mask.Input(
"object_mask",
optional=True,
tooltip="Optional per-object scene-space mask. If connected, only vertices whose projected pixel falls inside the mask contribute to the alignment solve.",
),
IO.Int.Input(
"batch_index",
default=0, min=0, max=1024,
tooltip="Which batch slot of the proj_feat_pack/MoGe geometry corresponds to this object.",
),
],
outputs=[
IO.Mesh.Output("aligned_mesh"),
IO.Float.Output(display_name="scale"),
],
)
@classmethod
def execute(cls, mesh, proj_feat_pack, moge_geometry, object_mask=None, batch_index=0) -> IO.NodeOutput:
vertices = mesh.vertices
faces = mesh.faces
if vertices.ndim == 3:
vertices_one = vertices[0]
faces_one = faces[0]
else:
vertices_one = vertices
faces_one = faces
T = proj_feat_pack["transform_matrix"][batch_index:batch_index + 1]
cam_angle = proj_feat_pack["camera_angle_x"][batch_index:batch_index + 1]
mesh_scale = proj_feat_pack["mesh_scale"][batch_index]
image_resolution = int(proj_feat_pack.get("image_resolution", 1024))
crop_bbox = proj_feat_pack["crop_bboxes"][batch_index]
pack_scene_size = proj_feat_pack.get("scene_sizes", [None] * (batch_index + 1))[batch_index]
moge_points = moge_geometry["points"]
moge_mask = moge_geometry["mask"]
if moge_points.ndim != 4:
raise ValueError(f"MoGe points expected [B, H, W, 3]; got {tuple(moge_points.shape)}")
scene_H, scene_W = moge_points.shape[1], moge_points.shape[2]
if pack_scene_size is not None and pack_scene_size != (scene_W, scene_H):
raise ValueError(
f"Pixal3DAlignObject: MoGe geometry was computed on a {scene_W}x{scene_H} image, "
f"but the proj_feat_pack's bbox lives in a {pack_scene_size[0]}x{pack_scene_size[1]} "
"image. Run MoGe on the same resized scene image Pixal3DConditioning used."
)
# Compose VaeDecodeShapeTrellis's R_y(180°) inverse with R_proj to map user mesh
# space to ProjGrid world: (X, Y, Z) -> (-X, Z, Y).
v = vertices_one.float()
verts_world = torch.stack([-v[..., 0], v[..., 2], v[..., 1]], dim=-1)
verts_world = verts_world / float(mesh_scale.item())
uv_crop, _depth, valid = _project_vertices_to_image_uv(
verts_world, T[0], cam_angle[0], image_resolution)
scene_pixels = _crop_uv_to_scene_pixels(uv_crop, crop_bbox, (scene_W, scene_H))
in_scene = ((scene_pixels[:, 0] >= 0) & (scene_pixels[:, 0] < scene_W) &
(scene_pixels[:, 1] >= 0) & (scene_pixels[:, 1] < scene_H))
sx = scene_pixels[:, 0].long().clamp(0, scene_W - 1)
sy = scene_pixels[:, 1].long().clamp(0, scene_H - 1)
moge_per_vertex = moge_points[batch_index, sy, sx]
moge_mask_per_vertex = moge_mask[batch_index, sy, sx]
keep = valid & in_scene & moge_mask_per_vertex
if object_mask is not None:
om = object_mask if object_mask.ndim == 2 else object_mask[batch_index]
keep = keep & (om[sy, sx] > 0.5)
finite = torch.isfinite(moge_per_vertex).all(dim=-1)
keep = keep & finite
kept = int(keep.sum().item())
if kept < 8:
scale = 1.0
aligned = vertices_one
else:
P = vertices_one[keep].float()
Q = moge_per_vertex[keep].float()
p_mean = P.mean(dim=0, keepdim=True)
q_mean = Q.mean(dim=0, keepdim=True)
P_c = P - p_mean
Q_c = Q - q_mean
num = (P_c * Q_c).sum()
den = (P_c * P_c).sum().clamp(min=1e-8)
scale = float((num / den).item())
if not (scale > 0):
# Negative scale would mirror the mesh; treat as a camera-convention mismatch.
logging.warning(
f"Pixal3DAlignObject: computed scale={scale:.4f} <= 0; "
"refusing to apply mirroring. Check camera convention alignment.")
scale = 1.0
aligned = vertices_one
else:
t = q_mean - scale * p_mean
aligned = scale * vertices_one + t
if vertices.ndim == 3:
aligned = aligned.unsqueeze(0)
out_mesh = Types.MESH(vertices=aligned, faces=faces)
else:
out_mesh = Types.MESH(vertices=aligned, faces=faces_one)
return IO.NodeOutput(out_mesh, float(scale))
class LoadNAFModel(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="LoadNAFModel",
display_name="Load NAF Model",
category="loaders",
inputs=[
IO.Combo.Input(
"naf_name",
options=folder_paths.get_filename_list("upscale_models"),
tooltip="NAF safetensors checkpoint (e.g. naf_release.safetensors).",
),
],
outputs=[NAFModel.Output(display_name="naf_model")],
)
@classmethod
def execute(cls, naf_name) -> IO.NodeOutput:
path = folder_paths.get_full_path_or_raise("upscale_models", naf_name)
sd = comfy.utils.load_torch_file(path, safe_load=True)
model = build_naf_from_state_dict(sd)
device = comfy.model_management.get_torch_device()
model = model.to(device).eval()
return IO.NodeOutput(model)
class Trellis2Extension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
Trellis2Conditioning,
Pixal3DConditioning,
Pixal3DAlignObject,
LoadNAFModel,
EmptyTrellis2ShapeLatent,
EmptyTrellis2LatentStructure,
EmptyTrellis2LatentTexture,