From 9e4794da5cdfdc7f6a7931613d3e430a28fc21f0 Mon Sep 17 00:00:00 2001 From: kijai Date: Fri, 22 May 2026 01:50:48 +0300 Subject: [PATCH] initial Pixal3D support --- comfy/ldm/trellis2/attention.py | 62 ++- comfy/ldm/trellis2/flexgemm.py | 51 ++- comfy/ldm/trellis2/model.py | 410 ++++++++++++++++++-- comfy/ldm/trellis2/naf/model.py | 301 +++++++++++++++ comfy/ldm/trellis2/vae.py | 29 +- comfy/model_detection.py | 24 +- comfy/supported_models.py | 1 + comfy_extras/nodes_model_advanced.py | 21 +- comfy_extras/nodes_trellis2.py | 540 ++++++++++++++++++++++++++- 9 files changed, 1316 insertions(+), 123 deletions(-) create mode 100644 comfy/ldm/trellis2/naf/model.py diff --git a/comfy/ldm/trellis2/attention.py b/comfy/ldm/trellis2/attention.py index d95b071b5..a09a8fca8 100644 --- a/comfy/ldm/trellis2/attention.py +++ b/comfy/ldm/trellis2/attention.py @@ -5,8 +5,12 @@ from typing import Tuple, Union, List from comfy.ldm.trellis2.vae import VarLenTensor import comfy.ops +try: + from torch.nn.attention.varlen import varlen_attn as _varlen_attn +except ImportError: + _varlen_attn = None + -# replica of the seedvr2 code def var_attn_arg(kwargs): cu_seqlens_q = kwargs.get("cu_seqlens_q", None) max_seqlen_q = kwargs.get("max_seqlen_q", None) @@ -16,42 +20,30 @@ def var_attn_arg(kwargs): return cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): - var_length = True - if var_length: - cu_seqlens_q, cu_seqlens_k, _, _ = var_attn_arg(kwargs) - if not skip_reshape: - # assumes 2D q, k,v [total_tokens, embed_dim] - total_tokens, embed_dim = q.shape - head_dim = embed_dim // heads - q = q.view(total_tokens, heads, head_dim) - k = k.view(k.shape[0], heads, head_dim) - v = v.view(v.shape[0], heads, head_dim) + cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = var_attn_arg(kwargs) + if not skip_reshape: + total_tokens, embed_dim = q.shape + head_dim = embed_dim // heads + q = q.view(total_tokens, heads, head_dim) + k = k.view(k.shape[0], heads, head_dim) + v = v.view(v.shape[0], heads, head_dim) - b = q.size(0) - dim_head = q.shape[-1] - q = torch.nested.nested_tensor_from_jagged(q, offsets=cu_seqlens_q.long()) - k = torch.nested.nested_tensor_from_jagged(k, offsets=cu_seqlens_k.long()) - v = torch.nested.nested_tensor_from_jagged(v, offsets=cu_seqlens_k.long()) - - mask = None - q = q.transpose(1, 2) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - - if mask is not None: - if mask.ndim == 2: - mask = mask.unsqueeze(0) - if mask.ndim == 3: - mask = mask.unsqueeze(1) - - out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) - if var_length: - return out.transpose(1, 2).values() - if not skip_output_reshape: - out = ( - out.transpose(1, 2).reshape(b, -1, heads * dim_head) + if _varlen_attn is not None: + return _varlen_attn( + q, k, v, + cu_seqlens_q, cu_seqlens_k, + int(max_seqlen_q), int(max_seqlen_k), ) - return out + + # Fallback: nested-tensor SDPA (PyTorch < the version that introduced varlen_attn) + q = torch.nested.nested_tensor_from_jagged(q, offsets=cu_seqlens_q.long()) + k = torch.nested.nested_tensor_from_jagged(k, offsets=cu_seqlens_k.long()) + v = torch.nested.nested_tensor_from_jagged(v, offsets=cu_seqlens_k.long()) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) + return out.transpose(1, 2).values() def scaled_dot_product_attention(*args, **kwargs): num_all_args = len(args) + len(kwargs) diff --git a/comfy/ldm/trellis2/flexgemm.py b/comfy/ldm/trellis2/flexgemm.py index 416e322ab..e22f2fe98 100644 --- a/comfy/ldm/trellis2/flexgemm.py +++ b/comfy/ldm/trellis2/flexgemm.py @@ -26,16 +26,26 @@ class TorchHashMap: self.default_value = torch.tensor(default_value, dtype=torch.long, device=device) self._n = self.sorted_keys.numel() + # Chunk size for lookup_flat. At ~530M flat keys (large mesh extraction), + # the unchunked path allocates ~5 full-size int64 temporaries (4 GB each) + + # bool masks + the int32 output. Chunking caps each transient to ~CHUNK rows. + _LOOKUP_CHUNK = 1 << 23 # 8M rows ≈ 64 MB per int64 temp + def lookup_flat(self, flat_keys: torch.Tensor) -> torch.Tensor: - flat = flat_keys.to(torch.long) - if self._n == 0: - return torch.full((flat.shape[0],), -1, device=flat.device, dtype=torch.int32) - idx = torch.searchsorted(self.sorted_keys, flat) - idx_safe = torch.clamp(idx, max=self._n - 1) - found = (idx < self._n) & (self.sorted_keys[idx_safe] == flat) - out = torch.full((flat.shape[0],), -1, device=flat.device, dtype=torch.int32) - if found.any(): - out[found] = self.sorted_vals[idx_safe[found]].to(torch.int32) + N = flat_keys.shape[0] + out = torch.full((N,), -1, device=flat_keys.device, dtype=torch.int32) + if self._n == 0 or N == 0: + return out + for s in range(0, N, self._LOOKUP_CHUNK): + e = min(s + self._LOOKUP_CHUNK, N) + flat_chunk = flat_keys[s:e].to(torch.long) + idx = torch.searchsorted(self.sorted_keys, flat_chunk) + in_range = idx < self._n + idx.clamp_(max=self._n - 1) # reuse idx as the "safe" index + found = in_range & (self.sorted_keys[idx] == flat_chunk) + if found.any(): + found_idx = found.nonzero(as_tuple=True)[0] + out[s + found_idx] = self.sorted_vals[idx[found_idx]].to(torch.int32) return out @@ -212,10 +222,10 @@ def sparse_submanifold_conv3d( if accumulate_f32: weight_T = weight.view(Co, V * Ci).to(torch.float32).T.contiguous() - output = torch.zeros(N_pts, Co, device=device, dtype=torch.float32) else: weight_T = weight.view(Co, V * Ci).to(feats.dtype).T.contiguous() - output = torch.zeros(N_pts, Co, device=device, dtype=feats.dtype) + + output = torch.empty(N_pts, Co, device=device, dtype=feats.dtype) # ------------------------------------------------------------------ # Chunk size from memory budget @@ -226,6 +236,9 @@ def sparse_submanifold_conv3d( chunk_size = max(1, int(max_chunk_mem / mem_per_row)) chunk_size = min(chunk_size, N_pts) + # fp32 matmul scratch — sized to the largest chunk, reused each iteration. + chunk_buf = torch.empty(chunk_size, Co, device=device, dtype=torch.float32) if accumulate_f32 else None + # ------------------------------------------------------------------ # Chunked forward pass # Each iteration: @@ -233,7 +246,8 @@ def sparse_submanifold_conv3d( # 2. mask zero invalids – in-place, no extra alloc # 3. reshape (chunk, V*Ci) # 4. GEMM (chunk, V*Ci) @ (V*Ci, Co) → (chunk, Co) – cuBLAS - # written directly into output slice via out= argument + # written into the scratch buf (fp32) or output slice (fp16) via out= + # 5. (fp32 path) cast scratch chunk to fp16 and copy into output slice # ------------------------------------------------------------------ for start in range(0, N_pts, chunk_size): end = min(start + chunk_size, N_pts) @@ -257,16 +271,13 @@ def sparse_submanifold_conv3d( gathered_flat = gathered.view(actual_chunk, V * Ci) if accumulate_f32: gathered_flat = gathered_flat.to(torch.float32) - - # Single GEMM call per chunk, written directly into output. - # This avoids allocating a temporary (chunk, Co) tensor. - torch.matmul(gathered_flat, weight_T, out=output[start:end]) - - if accumulate_f32: - output = output.to(feats.dtype) + torch.matmul(gathered_flat, weight_T, out=chunk_buf[:actual_chunk]) + output[start:end] = chunk_buf[:actual_chunk].to(feats.dtype) + else: + torch.matmul(gathered_flat, weight_T, out=output[start:end]) if bias is not None: - output = output + bias.unsqueeze(0).to(output.dtype) + output += bias.unsqueeze(0).to(output.dtype) return output, neighbor diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 3c61a5d77..0b4181ad2 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -25,15 +25,12 @@ class SparseFeedForwardNet(nn.Module): def forward(self, x: VarLenTensor) -> VarLenTensor: return self.mlp(x) -def manual_cast(obj, dtype): - return obj.to(dtype=dtype) - class LayerNorm32(nn.LayerNorm): def forward(self, x: torch.Tensor) -> torch.Tensor: x_dtype = x.dtype - x = manual_cast(x, torch.float32) + x = x.to(dtype=torch.float32) o = super().forward(x) - return manual_cast(o, x_dtype) + return o.to(dtype=x_dtype) class SparseMultiHeadRMSNorm(nn.Module): @@ -249,6 +246,51 @@ class SparseMultiHeadAttention(nn.Module): h = self._linear(self.to_out, h) return h +def _split_proj_context(context): + if not isinstance(context, dict): + return context, None + global_ctx = context["global"] + if "proj" in context: + return global_ctx, context["proj"] + if "proj_semantic" in context and "proj_color" in context: + return global_ctx, (context["proj_semantic"], context["proj_color"]) + return global_ctx, None + + +class ProjectAttentionSparse(nn.Module): + def __init__(self, cross_attn_block: nn.Module, channels: int, proj_in_channels: int, + device=None, dtype=None, operations=None): + super().__init__() + self.cross_attn_block = cross_attn_block + self.proj_linear = operations.Linear(proj_in_channels, channels, bias=True, + device=device, dtype=dtype) + + def forward(self, x: SparseTensor, context) -> SparseTensor: + global_ctx, proj_in = _split_proj_context(context) + global_out = self.cross_attn_block(x, global_ctx) + if isinstance(proj_in, tuple): + proj_in = torch.cat([proj_in[0], proj_in[1]], dim=-1) + proj_out = self.proj_linear(proj_in.to(self.proj_linear.weight.dtype)) + return global_out.replace(global_out.feats + proj_out.to(global_out.feats.dtype)) + + +class ProjectAttentionDense(nn.Module): + def __init__(self, cross_attn_block: nn.Module, channels: int, proj_in_channels: int, + device=None, dtype=None, operations=None): + super().__init__() + self.cross_attn_block = cross_attn_block + self.proj_linear = operations.Linear(proj_in_channels, channels, bias=True, + device=device, dtype=dtype) + + def forward(self, x: torch.Tensor, context) -> torch.Tensor: + global_ctx, proj_in = _split_proj_context(context) + global_out = self.cross_attn_block(x, global_ctx) + if isinstance(proj_in, tuple): + proj_in = torch.cat([proj_in[0], proj_in[1]], dim=-1) + proj_out = self.proj_linear(proj_in.to(self.proj_linear.weight.dtype)) + return global_out + proj_out.to(global_out.dtype) + + class ModulatedSparseTransformerCrossBlock(nn.Module): """ Sparse Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning. @@ -269,11 +311,14 @@ class ModulatedSparseTransformerCrossBlock(nn.Module): qk_rms_norm_cross: bool = False, qkv_bias: bool = True, share_mod: bool = False, + image_attn_mode: Literal["global", "proj", "gated_proj"] = "global", + proj_in_channels: Optional[int] = None, device=None, dtype=None, operations=None ): super().__init__() self.use_checkpoint = use_checkpoint self.share_mod = share_mod + self.image_attn_mode = image_attn_mode self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device) self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6, device=device) self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device) @@ -290,7 +335,7 @@ class ModulatedSparseTransformerCrossBlock(nn.Module): qk_rms_norm=qk_rms_norm, device=device, dtype=dtype, operations=operations ) - self.cross_attn = SparseMultiHeadAttention( + cross_inner = SparseMultiHeadAttention( channels, ctx_channels=ctx_channels, num_heads=num_heads, @@ -300,6 +345,15 @@ class ModulatedSparseTransformerCrossBlock(nn.Module): qk_rms_norm=qk_rms_norm_cross, device=device, dtype=dtype, operations=operations ) + if image_attn_mode == "global": + self.cross_attn = cross_inner + else: + if proj_in_channels is None: + raise ValueError("proj_in_channels must be set when image_attn_mode != 'global'") + self.cross_attn = ProjectAttentionSparse( + cross_inner, channels, proj_in_channels, + device=device, dtype=dtype, operations=operations, + ) self.mlp = SparseFeedForwardNet( channels, mlp_ratio=mlp_ratio, @@ -313,7 +367,7 @@ class ModulatedSparseTransformerCrossBlock(nn.Module): else: self.modulation = nn.Parameter(torch.randn(6 * channels, device=device, dtype=dtype) / channels ** 0.5) - def _forward(self, x: SparseTensor, mod: torch.Tensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor: + def _forward(self, x: SparseTensor, mod: torch.Tensor, context) -> SparseTensor: if self.share_mod: shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1) else: @@ -324,7 +378,11 @@ class ModulatedSparseTransformerCrossBlock(nn.Module): h = h * gate_msa x = x + h h = x.replace(self.norm2(x.feats)) - h = self.cross_attn(h, context) + if self.image_attn_mode == "global": + global_ctx, _ = _split_proj_context(context) + h = self.cross_attn(h, global_ctx) + else: + h = self.cross_attn(h, context) x = x + h h = x.replace(self.norm3(x.feats)) h = h * (1 + scale_mlp) + shift_mlp @@ -333,7 +391,7 @@ class ModulatedSparseTransformerCrossBlock(nn.Module): x = x + h return x - def forward(self, x: SparseTensor, mod: torch.Tensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor: + def forward(self, x: SparseTensor, mod: torch.Tensor, context) -> SparseTensor: return self._forward(x, mod, context) @@ -356,6 +414,8 @@ class SLatFlowModel(nn.Module): initialization: str = 'vanilla', qk_rms_norm: bool = False, qk_rms_norm_cross: bool = False, + image_attn_mode: Literal["global", "proj", "gated_proj"] = "global", + proj_in_channels: Optional[int] = None, dtype = None, device = None, operations = None, @@ -375,6 +435,8 @@ class SLatFlowModel(nn.Module): self.initialization = initialization self.qk_rms_norm = qk_rms_norm self.qk_rms_norm_cross = qk_rms_norm_cross + self.image_attn_mode = image_attn_mode + self.proj_in_channels = proj_in_channels self.dtype = dtype self.t_embedder = TimestepEmbedder(model_channels, device=device, dtype=dtype, operations=operations) @@ -399,6 +461,8 @@ class SLatFlowModel(nn.Module): share_mod=self.share_mod, qk_rms_norm=self.qk_rms_norm, qk_rms_norm_cross=self.qk_rms_norm_cross, + image_attn_mode=image_attn_mode, + proj_in_channels=proj_in_channels, device=device, dtype=dtype, operations=operations ) for _ in range(num_blocks) @@ -426,19 +490,15 @@ class SLatFlowModel(nn.Module): dtype = next(self.input_layer.parameters()).dtype x = x.to(dtype) h = self.input_layer(x) - h = manual_cast(h, self.dtype) t = t.to(dtype) t_embedder = self.t_embedder.to(dtype) t_emb = t_embedder(t, out_dtype = t.dtype) if self.share_mod: t_emb = self.adaLN_modulation(t_emb) - t_emb = manual_cast(t_emb, self.dtype) - cond = manual_cast(cond, self.dtype) for block in self.blocks: h = block(h, t_emb, cond) - h = manual_cast(h, x.dtype) h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) h = self.out_layer(h) return h @@ -561,11 +621,14 @@ class ModulatedTransformerCrossBlock(nn.Module): qk_rms_norm_cross: bool = False, qkv_bias: bool = True, share_mod: bool = False, + image_attn_mode: Literal["global", "proj", "gated_proj"] = "global", + proj_in_channels: Optional[int] = None, device=None, dtype=None, operations=None ): super().__init__() self.use_checkpoint = use_checkpoint self.share_mod = share_mod + self.image_attn_mode = image_attn_mode self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device) self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6, device=device) self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device) @@ -582,7 +645,7 @@ class ModulatedTransformerCrossBlock(nn.Module): qk_rms_norm=qk_rms_norm, device=device, dtype=dtype, operations=operations ) - self.cross_attn = MultiHeadAttention( + cross_inner = MultiHeadAttention( channels, ctx_channels=ctx_channels, num_heads=num_heads, @@ -592,6 +655,15 @@ class ModulatedTransformerCrossBlock(nn.Module): qk_rms_norm=qk_rms_norm_cross, device=device, dtype=dtype, operations=operations ) + if image_attn_mode == "global": + self.cross_attn = cross_inner + else: + if proj_in_channels is None: + raise ValueError("proj_in_channels must be set when image_attn_mode != 'global'") + self.cross_attn = ProjectAttentionDense( + cross_inner, channels, proj_in_channels, + device=device, dtype=dtype, operations=operations, + ) self.mlp = FeedForwardNet( channels, mlp_ratio=mlp_ratio, @@ -605,7 +677,7 @@ class ModulatedTransformerCrossBlock(nn.Module): else: self.modulation = nn.Parameter(torch.randn(6 * channels, device=device, dtype=dtype) / channels ** 0.5) - def _forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor: + def _forward(self, x: torch.Tensor, mod: torch.Tensor, context, phases: Optional[torch.Tensor] = None) -> torch.Tensor: if self.share_mod: shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1) else: @@ -616,7 +688,11 @@ class ModulatedTransformerCrossBlock(nn.Module): h = h * gate_msa.unsqueeze(1) x = x + h h = self.norm2(x) - h = self.cross_attn(h, context) + if self.image_attn_mode == "global": + global_ctx, _ = _split_proj_context(context) + h = self.cross_attn(h, global_ctx) + else: + h = self.cross_attn(h, context) x = x + h h = self.norm3(x) h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1) @@ -625,7 +701,7 @@ class ModulatedTransformerCrossBlock(nn.Module): x = x + h return x - def forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward(self, x: torch.Tensor, mod: torch.Tensor, context, phases: Optional[torch.Tensor] = None) -> torch.Tensor: return self._forward(x, mod, context, phases) @@ -648,6 +724,8 @@ class SparseStructureFlowModel(nn.Module): initialization: str = 'vanilla', qk_rms_norm: bool = False, qk_rms_norm_cross: bool = False, + image_attn_mode: Literal["global", "proj", "gated_proj"] = "global", + proj_in_channels: Optional[int] = None, operations=None, device = None, dtype = torch.float32, @@ -669,6 +747,8 @@ class SparseStructureFlowModel(nn.Module): self.initialization = initialization self.qk_rms_norm = qk_rms_norm self.qk_rms_norm_cross = qk_rms_norm_cross + self.image_attn_mode = image_attn_mode + self.proj_in_channels = proj_in_channels self.dtype = dtype self.device = device @@ -703,6 +783,8 @@ class SparseStructureFlowModel(nn.Module): share_mod=share_mod, qk_rms_norm=self.qk_rms_norm, qk_rms_norm_cross=self.qk_rms_norm_cross, + image_attn_mode=image_attn_mode, + proj_in_channels=proj_in_channels, device=device, dtype=dtype, operations=operations ) for _ in range(num_blocks) @@ -720,14 +802,9 @@ class SparseStructureFlowModel(nn.Module): t_emb = self.t_embedder(t, out_dtype = t.dtype) if self.share_mod: t_emb = self.adaLN_modulation(t_emb) - t_emb = manual_cast(t_emb, self.dtype) - h = manual_cast(h, self.dtype) - cond = manual_cast(cond, self.dtype) for block in self.blocks: h = block(h, t_emb, cond, self.rope_phases) - h = manual_cast(h, x.dtype) h = F.layer_norm(h, h.shape[-1:]) - h = h.to(next(self.out_layer.parameters()).dtype) h = self.out_layer(h) h = h.permute(0, 2, 1).view(h.shape[0], h.shape[2], *[self.resolution] * 3).contiguous() @@ -741,6 +818,221 @@ def timestep_reshift(t_shifted, old_shift=3.0, new_shift=5.0): t_new *= 1000.0 return t_new + +# Pixal3D ProjGrid math — port of upstream's ProjGrid + project_points_to_image_batch. +# World frame uses world Y as depth (Blender convention), camera looks along -Z local; +# transform_matrix is camera-to-world (inverted internally). Intrinsics: fx = 16 / tan(fov/2) +# with sensor_width = 32mm. + +_PROJ_GRID_ROTATION = torch.tensor( + [[1.0, 0.0, 0.0], + [0.0, 0.0, -1.0], + [0.0, 1.0, 0.0]] +) + +_PROJ_FRONT_VIEW_TRANSFORM = torch.tensor( + [[1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, -1.0, -2.0], + [0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0]] +) + + +def _build_proj_transform_matrix(distance: torch.Tensor, batch_size: int, + device, dtype=torch.float32) -> torch.Tensor: + T = _PROJ_FRONT_VIEW_TRANSFORM.to(device=device, dtype=dtype) + T = T.unsqueeze(0).expand(batch_size, -1, -1).clone() + if distance.ndim == 0: + distance = distance.expand(batch_size) + T[:, 1, 3] = -distance.to(device=device, dtype=dtype) + return T + + +def _project_points_to_image(points_world: torch.Tensor, transform_matrix: torch.Tensor, + camera_angle_x: torch.Tensor, resolution: int): + B, N, _ = points_world.shape + ones = torch.ones((B, N, 1), device=points_world.device, dtype=points_world.dtype) + homo = torch.cat([points_world, ones], dim=-1) + world_to_camera = torch.linalg.inv(transform_matrix.float()).to(transform_matrix.dtype) + p_cam = torch.bmm(homo, world_to_camera.transpose(-2, -1))[..., :3] + x_cam, y_cam, z_cam = p_cam.unbind(dim=-1) + depth = -z_cam + sensor_width = 32.0 + focal_length = 16.0 / torch.tan(camera_angle_x / 2.0) + focal_px = focal_length * resolution / sensor_width + focal_px = focal_px.to(p_cam.dtype).unsqueeze(1) + denom = (-z_cam + 1e-8) + x_pix = focal_px * x_cam / denom + resolution / 2.0 + y_pix = -focal_px * y_cam / denom + resolution / 2.0 + valid = ((x_pix >= 0) & (x_pix < resolution) & + (y_pix >= 0) & (y_pix < resolution) & (depth > 0)) + return torch.stack([x_pix, y_pix], dim=-1), depth, valid + + +def _sample_features(feature_map: torch.Tensor, uv_ndc: torch.Tensor) -> torch.Tensor: + B, C, _, _ = feature_map.shape + grid = uv_ndc.view(B, -1, 1, 2).to(feature_map.dtype) + feat = F.grid_sample(feature_map, grid, mode="bilinear", + padding_mode="border", align_corners=False) + return feat.squeeze(-1) + + +def _coords_to_proj_world(coords: torch.Tensor, resolution: int, mesh_scale: torch.Tensor): + if resolution < 1: + raise ValueError(f"resolution must be positive, got {resolution}") + batch_ids = coords[:, 0].long() + if resolution == 1: + norm = coords[:, 1:].to(torch.float32) * 0.0 + else: + norm = coords[:, 1:].to(torch.float32) / (resolution - 1) * 2.0 - 1.0 + R = _PROJ_GRID_ROTATION.to(device=coords.device, dtype=torch.float32) + rotated = norm @ R.T + if mesh_scale.ndim == 0: + scale_per_voxel = mesh_scale.expand(coords.shape[0]) + else: + scale_per_voxel = mesh_scale.to(coords.device)[batch_ids] + world = rotated / scale_per_voxel.unsqueeze(-1) / 2.0 + return world, batch_ids + + +def _dense_grid_proj_world(resolution: int, mesh_scale: torch.Tensor, + batch_size: int, device, dtype=torch.float32) -> torch.Tensor: + one = torch.linspace(-1.0, 1.0, resolution, device=device, dtype=dtype) + x, y, z = torch.meshgrid(one, one, one, indexing="ij") + grid = torch.stack([x, y, z], dim=-1).reshape(-1, 3) + R_rot = _PROJ_GRID_ROTATION.to(device=device, dtype=dtype) + grid = grid @ R_rot.T + grid = grid.unsqueeze(0).expand(batch_size, -1, -1).clone() + if mesh_scale.ndim == 0: + mesh_scale = mesh_scale.expand(batch_size) + grid = grid / mesh_scale.to(device=device, dtype=dtype).view(-1, 1, 1) / 2.0 + return grid + + +def _back_project_to_tokens( + coords_world: torch.Tensor, + feature_map: torch.Tensor, + transform_matrix: torch.Tensor, + camera_angle_x: torch.Tensor, + image_resolution: int, + batch_ids: Optional[torch.Tensor] = None, +) -> torch.Tensor: + if coords_world.dim() == 2: + assert batch_ids is not None + B = transform_matrix.shape[0] + out = torch.zeros((coords_world.shape[0], feature_map.shape[1]), + device=feature_map.device, dtype=feature_map.dtype) + for b in range(B): + mask = batch_ids == b + if not mask.any(): + continue + p = coords_world[mask].unsqueeze(0) + uv, depth, valid = _project_points_to_image( + p, transform_matrix[b:b+1], camera_angle_x[b:b+1], image_resolution) + uv_ndc = (uv + 0.5) / image_resolution * 2.0 - 1.0 + # padding_mode='border' is load-bearing: masking out-of-frame voxels confuses + # the SS DiT (~half the voxels go to zero, producing low poly + rotation drift). + sampled = _sample_features(feature_map[b:b+1], uv_ndc) + sampled = sampled.squeeze(0).transpose(0, 1) + out[mask] = sampled + return out + else: + uv, depth, valid = _project_points_to_image( + coords_world, transform_matrix, camera_angle_x, image_resolution) + uv_ndc = (uv + 0.5) / image_resolution * 2.0 - 1.0 + sampled = _sample_features(feature_map, uv_ndc) + out = sampled.transpose(1, 2) + return out + + +def _pack_per_voxel_scalar(proj_pack: Optional[dict], key: str, eval_batch: int, device) -> torch.Tensor: + if proj_pack is None or key not in proj_pack: + return torch.ones((eval_batch,), device=device, dtype=torch.float32) + t = proj_pack[key].to(device=device, dtype=torch.float32) + if t.ndim == 0: + return t.expand(eval_batch).clone() + return _expand_pack(t, eval_batch) + + +def _expand_pack(t: torch.Tensor, eval_batch: int) -> torch.Tensor: + if eval_batch == t.shape[0]: + return t + if eval_batch % t.shape[0] != 0: + raise ValueError(f"eval batch {eval_batch} is not a multiple of pack batch {t.shape[0]}") + return t.repeat((eval_batch // t.shape[0],) + (1,) * (t.ndim - 1)) + + +def _select_stage_entry(proj_pack: dict, stage: Optional[str]): + """Returns (feature_map_lr, feature_map_hr_or_None, image_resolution).""" + stages = proj_pack.get("stages") + if stages is not None and stage is not None and stage in stages: + entry = stages[stage] + return entry["feature_map"], entry.get("feature_map_hr"), int(entry.get("image_resolution", 1024)) + if "feature_map" in proj_pack: + return proj_pack["feature_map"], proj_pack.get("feature_map_hr"), int(proj_pack.get("image_resolution", 1024)) + raise ValueError(f"proj_feat_pack has no usable feature_map (stage={stage!r})") + + +def _build_proj_cond(global_cond: torch.Tensor, image_attn_mode: str, proj_pack: Optional[dict], + coords_world: torch.Tensor, batch_ids: Optional[torch.Tensor] = None, + eval_batch: Optional[int] = None, + proj_in_channels: Optional[int] = None, + stage: Optional[str] = None, + cond_or_uncond: Optional[list] = None): + if image_attn_mode == "global": + return global_cond + if proj_pack is None: + raise ValueError(f"image_attn_mode={image_attn_mode!r} but proj_feat_pack is missing") + device = coords_world.device + T = proj_pack["transform_matrix"].to(device) + cam_angle = proj_pack["camera_angle_x"].to(device) + feat_map_lr, feat_map_hr, image_resolution = _select_stage_entry(proj_pack, stage) + feat_map_lr = feat_map_lr.to(device) + if feat_map_hr is not None: + feat_map_hr = feat_map_hr.to(device) + if eval_batch is not None: + T = _expand_pack(T, eval_batch) + cam_angle = _expand_pack(cam_angle, eval_batch) if cam_angle.ndim >= 1 else cam_angle + feat_map_lr = _expand_pack(feat_map_lr, eval_batch) + if feat_map_hr is not None: + feat_map_hr = _expand_pack(feat_map_hr, eval_batch) + # Channel-count check against the trained proj_linear input. If HR is present, the + # block expects (LR_channels + HR_channels) since we concat the sampled features. + expected_channels = feat_map_lr.shape[1] + (feat_map_hr.shape[1] if feat_map_hr is not None else 0) + if proj_in_channels is not None and expected_channels != proj_in_channels: + hint = "" + if feat_map_hr is None and expected_channels < proj_in_channels: + hint = (" — feature_map_hr is missing for this stage. Connect a NAFModel " + "input to Pixal3DConditioning; the shape/texture stages of this " + "checkpoint need a NAF-upsampled HR feature map.") + raise ValueError( + f"proj_feat_pack[{stage!r}] has LR={feat_map_lr.shape[1]} " + f"+ HR={feat_map_hr.shape[1] if feat_map_hr is not None else 0} " + f"= {expected_channels} channels, sub-model expects {proj_in_channels}.{hint}" + ) + proj_feats_lr = _back_project_to_tokens(coords_world, feat_map_lr, T, cam_angle, + image_resolution=image_resolution, + batch_ids=batch_ids) + if feat_map_hr is not None: + proj_feats_hr = _back_project_to_tokens(coords_world, feat_map_hr, T, cam_angle, + image_resolution=image_resolution, + batch_ids=batch_ids) + proj_feats = torch.cat([proj_feats_lr, proj_feats_hr], dim=-1) + else: + proj_feats = proj_feats_lr + # Mirror upstream's neg_cond by zeroing proj for any uncond batch slot. + if cond_or_uncond is not None and eval_batch is not None: + uncond_slots = [i for i, v in enumerate(cond_or_uncond) if v == 1] + if uncond_slots: + uncond_idx = torch.tensor(uncond_slots, device=proj_feats.device, dtype=torch.long) + if batch_ids is None: + proj_feats = proj_feats.clone() + proj_feats[uncond_idx] = 0 + else: + neg_mask = torch.isin(batch_ids, uncond_idx).unsqueeze(-1).to(proj_feats.dtype) + proj_feats = proj_feats * (1.0 - neg_mask) + return {"global": global_cond, "proj": proj_feats} + class Trellis2(nn.Module): def __init__(self, resolution, in_channels = 32, @@ -754,6 +1046,12 @@ class Trellis2(nn.Module): qk_rms_norm = True, qk_rms_norm_cross = True, init_txt_model=False, # for now + image_attn_mode_structure: str = "global", + proj_in_channels_structure: Optional[int] = None, + image_attn_mode_shape: str = "global", + proj_in_channels_shape: Optional[int] = None, + image_attn_mode_texture: str = "global", + proj_in_channels_texture: Optional[int] = None, dtype=None, device=None, operations=None, **kwargs): super().__init__() @@ -767,22 +1065,29 @@ class Trellis2(nn.Module): "model_channels":model_channels, "num_heads":num_heads, "mlp_ratio": mlp_ratio, "share_mod": share_mod, "qk_rms_norm": qk_rms_norm, "qk_rms_norm_cross": qk_rms_norm_cross, "device": device, "dtype": dtype, "operations": operations } + self.image_attn_mode_structure = image_attn_mode_structure + self.image_attn_mode_shape = image_attn_mode_shape + self.image_attn_mode_texture = image_attn_mode_texture + shape_proj_kwargs = {"image_attn_mode": image_attn_mode_shape, "proj_in_channels": proj_in_channels_shape} + tex_proj_kwargs = {"image_attn_mode": image_attn_mode_texture, "proj_in_channels": proj_in_channels_texture} + struct_proj_kwargs = {"image_attn_mode": image_attn_mode_structure, "proj_in_channels": proj_in_channels_structure} txt_only = kwargs.get("txt_only", False) if not txt_only: - self.img2shape = SLatFlowModel(resolution=resolution, in_channels=in_channels, **args) + self.img2shape = SLatFlowModel(resolution=resolution, in_channels=in_channels, **shape_proj_kwargs, **args) self.shape2txt = None if init_txt_model: - self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args) - self.img2shape_512 = SLatFlowModel(resolution=32, in_channels=in_channels, **args) + self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **tex_proj_kwargs, **args) + self.img2shape_512 = SLatFlowModel(resolution=32, in_channels=in_channels, **shape_proj_kwargs, **args) args.pop("out_channels") - self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **args) + self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **struct_proj_kwargs, **args) else: - self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args) + self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **tex_proj_kwargs, **args) self.guidance_interval = [0.6, 1.0] self.guidance_interval_txt = [0.6, 0.9] def forward(self, x, timestep, context, **kwargs): transformer_options = kwargs.get("transformer_options", {}) + cond_or_uncond = transformer_options.get("cond_or_uncond") model_options = {} if hasattr(self, "meta"): model_options = self.meta @@ -795,6 +1100,8 @@ class Trellis2(nn.Module): coords = model_options.get("coords", None) coord_counts = model_options.get("coord_counts", None) mode = model_options.get("generation_mode", "structure_generation") + proj_feat_pack = model_options.get("proj_feat_pack", None) + coord_resolution = model_options.get("coord_resolution", None) is_512_run = False if mode == "shape_generation_512": @@ -884,6 +1191,20 @@ class Trellis2(nn.Module): x_st = SparseTensor(feats=feats_flat, coords=batched_coords.to(torch.int32)) if mode == "shape_generation": + shape_attn = self.image_attn_mode_shape + if shape_attn != "global": + if coord_resolution is None: + raise ValueError("Pixal3D shape_generation requires coord_resolution in model_options; " + "EmptyTrellis2ShapeLatent should set it from the input voxel.") + mesh_scale = _pack_per_voxel_scalar(proj_feat_pack, "mesh_scale", B, batched_coords.device) + xyz_world, batch_ids = _coords_to_proj_world(batched_coords, coord_resolution, mesh_scale) + sub_model = self.img2shape_512 if is_512_run else self.img2shape + stage_name = "shape_512" if is_512_run else "shape_1024" + c_eval = _build_proj_cond(c_eval, shape_attn, proj_feat_pack, xyz_world, batch_ids, + eval_batch=B, + proj_in_channels=sub_model.proj_in_channels, + stage=stage_name, + cond_or_uncond=cond_or_uncond) if is_512_run: out = self.img2shape_512(x_st, t_eval, c_eval) else: @@ -904,18 +1225,49 @@ class Trellis2(nn.Module): slat_feats = slat_feats[:N].repeat(B, 1) x_st = x_st.replace(feats=torch.cat([x_st.feats, slat_feats.to(x_st.feats.device)], dim=-1)) + tex_attn = self.image_attn_mode_texture + if tex_attn != "global": + if coord_resolution is None: + raise ValueError("Pixal3D texture_generation requires coord_resolution in model_options; " + "EmptyTrellis2LatentTexture should set it from the input voxel.") + mesh_scale = _pack_per_voxel_scalar(proj_feat_pack, "mesh_scale", B, batched_coords.device) + xyz_world, batch_ids = _coords_to_proj_world(batched_coords, coord_resolution, mesh_scale) + c_eval = _build_proj_cond(c_eval, tex_attn, proj_feat_pack, xyz_world, batch_ids, + eval_batch=B, + proj_in_channels=self.shape2txt.proj_in_channels, + stage="tex_1024", + cond_or_uncond=cond_or_uncond) out = self.shape2txt(x_st, t_eval, c_eval) else: # structure orig_bsz = x.shape[0] + struct_attn = self.image_attn_mode_structure if shape_rule and orig_bsz > 1: half = orig_bsz // 2 x_eval = x[half:] t_eval = timestep[half:] if timestep.shape[0] > 1 else timestep - out = self.structure_model(x_eval, t_eval, cond) + struct_cond = cond + if struct_attn != "global": + mesh_scale = _pack_per_voxel_scalar(proj_feat_pack, "mesh_scale", half, x.device) + grid_xyz = _dense_grid_proj_world(16, mesh_scale, half, device=x.device) + struct_cond = _build_proj_cond(cond, struct_attn, proj_feat_pack, grid_xyz, + eval_batch=half, + proj_in_channels=self.structure_model.proj_in_channels, + stage="ss", + cond_or_uncond=cond_or_uncond) + out = self.structure_model(x_eval, t_eval, struct_cond) out = out.repeat(2, 1, 1, 1, 1) else: - out = self.structure_model(x, timestep, context) + struct_cond = context + if struct_attn != "global": + mesh_scale = _pack_per_voxel_scalar(proj_feat_pack, "mesh_scale", orig_bsz, x.device) + grid_xyz = _dense_grid_proj_world(16, mesh_scale, orig_bsz, device=x.device) + struct_cond = _build_proj_cond(context, struct_attn, proj_feat_pack, grid_xyz, + eval_batch=orig_bsz, + proj_in_channels=self.structure_model.proj_in_channels, + stage="ss", + cond_or_uncond=cond_or_uncond) + out = self.structure_model(x, timestep, struct_cond) if not_struct_mode: if mask is not None: diff --git a/comfy/ldm/trellis2/naf/model.py b/comfy/ldm/trellis2/naf/model.py new file mode 100644 index 000000000..1b085e7e7 --- /dev/null +++ b/comfy/ldm/trellis2/naf/model.py @@ -0,0 +1,301 @@ +"""NAF (Neighborhood Attention Filtering) feature upsampler. + +Vendored from valeoai/NAF (Apache-2.0): + https://github.com/valeoai/NAF — src/model/naf.py + src/layers/{convolutions,attentions,rope}.py +Used by Pixal3D's shape/texture conditioning to produce +the 2x-upsampled half of the 2048-channel proj feature map. +""" + +import math +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +# Pure-torch neighborhood attention (replaces natten.na2d / na2d_qk + na2d_av). + +def upsample_lr_slice(src_lr: torch.Tensor, lr_dh: int, lr_dw: int, + hr_h_range: Tuple[int, int], hr_w_range: Tuple[int, int]) -> torch.Tensor: + """Slice a LR-layout tensor [B, h_lr, w_lr, n, C], permute to BCHW, and + nearest-exact upsample only the region covering [hr_h_range, hr_w_range]. + Returns BCHW at hr_h_end-hr_h_start x hr_w_end-hr_w_start (no padding for + out-of-bounds regions).""" + B = src_lr.shape[0] + n = src_lr.shape[-2] + C = src_lr.shape[-1] + h_hr_start, h_hr_end = hr_h_range + w_hr_start, w_hr_end = hr_w_range + # LR positions covering [h_hr_start, h_hr_end). Nearest-exact maps HR p → p // D. + lr_h_start = h_hr_start // lr_dh + lr_h_end = (h_hr_end - 1) // lr_dh + 1 + lr_w_start = w_hr_start // lr_dw + lr_w_end = (w_hr_end - 1) // lr_dw + 1 + lr_slice = src_lr[:, lr_h_start:lr_h_end, lr_w_start:lr_w_end] + lh, lw = lr_slice.shape[1], lr_slice.shape[2] + lr_bcd = lr_slice.permute(0, 3, 4, 1, 2).reshape(B * n, C, lh, lw).contiguous() + up = F.interpolate(lr_bcd, scale_factor=(lr_dh, lr_dw), mode="nearest-exact") + offset_h = h_hr_start - lr_h_start * lr_dh + offset_w = w_hr_start - lr_w_start * lr_dw + return up[:, :, offset_h:offset_h + (h_hr_end - h_hr_start), + offset_w:offset_w + (w_hr_end - w_hr_start)] + + +def na2d_pure( + q: torch.Tensor, # [B, H, W, n_heads, d_qk] at HR. + k_lr: torch.Tensor, # [B, h_lr, w_lr, n_heads, d_qk] at LR + v_lr: torch.Tensor, # [B, h_lr, w_lr, n_heads, d_v] at LR + kernel_size: Tuple[int, int], # (Kh, Kw) attention window. + dilation: Tuple[int, int], # (Dh, Dw) stride within the unrolled K/V grid; also the LR→HR upsample factor. + scale: float, # 1 / sqrt(d_qk) scaling for the Q·K scores. + tile: int = 128, # Spatial tile size (output positions per tile) + v_chunk: int = 64 # Sub-divide d_v into chunks of this size when computing attn·V. None disables chunking. + ) -> torch.Tensor: # [B, H, W, n_heads, d_v] attended features. + """Neighborhood attention in pure torch via F.unfold + per-tile slicing. + + K and V are passed at LR resolution and upsampled (nearest-exact) per-tile only + for the slice the unfold needs. Avoids the [B, n*d, H, W] HR allocations for K + (512 MB) and V (2 GB) at tex_1024 fp16. Spatial tiling bounds the per-tile + F.unfold blob; `v_chunk` further slices d_v so attn·V is computed in C-sized + chunks (attn is reused, computed once from Q/K). + + """ + B, H, W, n, d_qk = q.shape + d_v = v_lr.shape[-1] + Kh, Kw = kernel_size + Dh, Dw = dilation + pad_h, pad_w = (Kh // 2) * Dh, (Kw // 2) * Dw + + q_ = q.permute(0, 3, 4, 1, 2).contiguous() # [B, n, d_qk, H, W] + out = torch.empty((B, n, d_v, H, W), device=q.device, dtype=q.dtype) + + th = min(tile, H) if tile else H + tw = min(tile, W) if tile else W + chunk = v_chunk if (v_chunk and v_chunk < d_v) else d_v + + for h0 in range(0, H, th): + for w0 in range(0, W, tw): + h1, w1 = min(h0 + th, H), min(w0 + tw, W) + t_h, t_w = h1 - h0, w1 - w0 + + # Padded HR region the unfold needs (kernel span = (K-1)*D + 1). + h_src_start = max(0, h0 - pad_h) + h_src_end = min(H, h1 + pad_h) + w_src_start = max(0, w0 - pad_w) + w_src_end = min(W, w1 + pad_w) + pad_top = max(0, pad_h - h0) + pad_bot = max(0, (h1 + pad_h) - H) + pad_lft = max(0, pad_w - w0) + pad_rgt = max(0, (w1 + pad_w) - W) + + # Upsample only the tile region from k_lr / v_lr. + k_tile = upsample_lr_slice(k_lr, Dh, Dw, + (h_src_start, h_src_end), + (w_src_start, w_src_end)) + v_tile = upsample_lr_slice(v_lr, Dh, Dw, + (h_src_start, h_src_end), + (w_src_start, w_src_end)) + if pad_top or pad_bot or pad_lft or pad_rgt: + k_tile = F.pad(k_tile, [pad_lft, pad_rgt, pad_top, pad_bot]) + v_tile = F.pad(v_tile, [pad_lft, pad_rgt, pad_top, pad_bot]) + + # Q·K → attention weights (small: KK=81 per output position). + KK = Kh * Kw + k_w = F.unfold(k_tile, kernel_size=(Kh, Kw), dilation=(Dh, Dw), padding=0) + k_w = k_w.view(B, n, d_qk, KK, t_h * t_w).permute(0, 1, 4, 3, 2) # [B, n, t, KK, d_qk] + q_tile = q_[:, :, :, h0:h1, w0:w1].permute(0, 1, 3, 4, 2).reshape(B, n, t_h * t_w, 1, d_qk) + scores = torch.matmul(q_tile, k_w.transpose(-1, -2)) * scale + attn = scores.softmax(dim=-1) + del k_w, scores, q_tile, k_tile + + # attn · V, chunked over d_v. + for c0 in range(0, d_v, chunk): + c1 = min(c0 + chunk, d_v) + v_w = F.unfold(v_tile[:, c0:c1], kernel_size=(Kh, Kw),dilation=(Dh, Dw), padding=0) # [B*n, (c1-c0)*KK, t] + v_w = v_w.view(B, n, c1 - c0, KK, t_h * t_w).permute(0, 1, 4, 3, 2) + out_chunk = torch.matmul(attn, v_w).squeeze(-2) # [B, n, t, c1-c0] + out_chunk = out_chunk.view(B, n, t_h, t_w, c1 - c0).permute(0, 1, 4, 2, 3) + out[:, :, c0:c1, h0:h1, w0:w1] = out_chunk + del v_w, out_chunk + del attn, v_tile + + return out.permute(0, 3, 4, 1, 2).contiguous() # [B, H, W, n, d_v] + + +class CrossAttention(nn.Module): + """Window-restricted cross-attention. No learnable parameters; the model's + capacity lives entirely in the ImageEncoder convs.""" + + def __init__(self, dim: int, num_heads: int, kernel_size: Tuple[int, int] = (9, 9)): + super().__init__() + assert dim % num_heads == 0, "dim must be divisible by num_heads" + self.num_heads = num_heads + self.kernel_size = kernel_size + self.scale = (dim // num_heads) ** -0.5 + + @staticmethod + def _split_heads_lr(x: torch.Tensor, num_heads: int) -> torch.Tensor: + """[B, n*d, h, w] -> [B, h, w, n, d] at the input resolution (no upsample).""" + B, C, H, W = x.shape + return x.view(B, num_heads, C // num_heads, H, W).permute(0, 3, 4, 1, 2).contiguous() + + def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + # q is [B, C, Hq, Wq] at HR; k and v are at LR (Hk, Wk). We KEEP k and v at LR + # na2d_pure upsamples only the tile slice it needs. + hq, wq = q.shape[-2:] + hk, wk = k.shape[-2:] + dilation = (hq // hk, wq // wk) + B, C, _, _ = q.shape + q = q.view(B, self.num_heads, C // self.num_heads, hq, wq).permute(0, 3, 4, 1, 2).contiguous() + k_lr = self._split_heads_lr(k, self.num_heads).to(q.dtype) + v_lr = self._split_heads_lr(v, self.num_heads).to(q.dtype) + out = na2d_pure(q, k_lr, v_lr, self.kernel_size, dilation, self.scale) + # [B, H, W, n, d] -> [B, n*d, H, W] + return out.permute(0, 3, 4, 1, 2).contiguous().view(B, -1, hq, wq) + + +# RoPE positional embedding + +def rope_rotate_half(x: torch.Tensor) -> torch.Tensor: + x1, x2 = x.chunk(2, dim=-1) + return torch.cat([-x2, x1], dim=-1) + + +class RoPE(nn.Module): + def __init__(self, embed_dim: int, num_heads: int, base: float = 100.0): + super().__init__() + assert embed_dim % (4 * num_heads) == 0 + self.num_heads = num_heads + self.D_head = embed_dim // num_heads + self.base = base + self.register_buffer("periods", torch.empty(self.D_head // 4), persistent=True) # loaded from the checkpoint + self._cached_key = None + self._cached_cos_sin = None + + def _cos_sin(self, H: int, W: int, dtype: torch.dtype): + """cos/sin only depend on (H, W) and the output dtype (periods are fixed + once loaded from the checkpoint), so cache them — saves the meshgrid / + angle / cos / sin / tile / flatten on every forward.""" + key = (H, W, dtype) + if self._cached_key == key and self._cached_cos_sin is not None: + return self._cached_cos_sin + device = self.periods.device + coords_h = torch.arange(0.5, H, device=device, dtype=torch.float32) / H + coords_w = torch.arange(0.5, W, device=device, dtype=torch.float32) / W + coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1) # [H, W, 2] + coords = coords.flatten(0, 1) * 2.0 - 1.0 # [HW, 2] + angles = 2 * math.pi * coords[:, :, None] / self.periods.to(coords.dtype)[None, None, :] # [HW, 2, D//4] + angles = angles.flatten(1, 2).tile(2) # [HW, D] + cos = torch.cos(angles).to(dtype) + sin = torch.sin(angles).to(dtype) + self._cached_cos_sin = (cos, sin) + self._cached_key = key + return cos, sin + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x: [B, n*D_head, H, W] + B, C, H, W = x.shape + n = self.num_heads + D = C // n + x = x.view(B, n, D, H, W).permute(0, 1, 3, 4, 2).reshape(B, n, H * W, D) + cos, sin = self._cos_sin(H, W, x.dtype) + x = (x * cos) + (rope_rotate_half(x) * sin) + x = x.view(B, n, H, W, D).permute(0, 1, 4, 2, 3).reshape(B, n * D, H, W) + return x + + +# Image encoder + +class EncBlock(nn.Module): + def __init__(self, channels: int, kernel_size: int, num_groups: int = 8): + super().__init__() + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=channels) + self.conv1 = nn.Conv2d(channels, channels, kernel_size=kernel_size, + padding=kernel_size // 2, padding_mode="reflect", bias=True) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=channels) + self.conv2 = nn.Conv2d(channels, channels, kernel_size=kernel_size, + padding=kernel_size // 2, padding_mode="reflect", bias=True) + self.activation_fn = nn.SiLU() + + def forward(self, x): + x = self.norm1(x) + x = self.activation_fn(x) + x = self.conv1(x) + x = self.norm2(x) + x = self.activation_fn(x) + x = self.conv2(x) + return x # no skip connection + + +def _encoder(in_dim: int, hidden_dim: int, kernel_size: int = 1, ks_res: int = 1, num_layers: int = 2) -> nn.Sequential: + return nn.Sequential( + nn.Conv2d(in_dim, hidden_dim, kernel_size=kernel_size, padding=kernel_size // 2, padding_mode="reflect", bias=True), + *[EncBlock(hidden_dim, kernel_size=ks_res) for _ in range(num_layers)], + ) + + +class ImageEncoder(nn.Module): + """Two parallel conv stacks (1x1 + 3x3) producing dim/2 channels each, then concat, + spatial average-pool to target size, RoPE-embed positions.""" + + def __init__(self, in_channels: int = 3, out_channels: int = 256, + heads_rope: int = 4, rope_base: float = 100.0, img_layers: int = 2): + super().__init__() + half = out_channels // 2 + self.encoder = _encoder(in_channels, half, kernel_size=1, ks_res=1, num_layers=img_layers) + self.sem_encoder = _encoder(in_channels, half, kernel_size=3, ks_res=3, num_layers=img_layers) + self.rope = RoPE(embed_dim=out_channels, num_heads=heads_rope, base=rope_base) + + def forward(self, x: torch.Tensor, output_size: Tuple[int, int]) -> torch.Tensor: + # Avoid running the conv stacks on >4× the target resolution. + out_h, out_w = output_size + if x.shape[-2] > 4 * out_h or x.shape[-1] > 4 * out_w: + x = F.interpolate(x, size=(min(x.shape[-2], 4 * out_h), + min(x.shape[-1], 4 * out_w)), + mode="bilinear", align_corners=False) + x = torch.cat([self.encoder(x), self.sem_encoder(x)], dim=1) + x = F.adaptive_avg_pool2d(x, output_size=output_size) + x = self.rope(x) + return x + + +# Top-level NAF model. + +class NAF(nn.Module): + """NAF feature upsampler.""" + + def __init__( + self, dim: int = 256, # internal channel dimension of the ImageEncoder + heads_attn: int = 4, # attention heads in the windowed cross-attn + heads_rope: int = 4, # heads for RoPE position encoding (must divide dim) + kernel_size: int = 9, # square kernel for the neighborhood attention window + rope_base: float = 100.0, # base for RoPE frequency periods + img_layers: int = 2 # number of EncBlocks in each conv stack + ): + super().__init__() + self.image_encoder = ImageEncoder(in_channels=3, out_channels=dim, heads_rope=heads_rope, rope_base=rope_base, img_layers=img_layers) + self.upsampler = CrossAttention(dim=dim, num_heads=heads_attn, kernel_size=(kernel_size, kernel_size)) + + def forward( + self, + image: torch.Tensor, # [B, 3, H_img, W_img] in [0, 1]. + features: torch.Tensor, # [B, C, H_feat, W_feat] low-resolution features (any C). + output_size: Tuple[int, int] # (H_out, W_out) target spatial resolution for the upsampled features. + ) -> torch.Tensor: # [B, C, H_out, W_out] upsampled features. + """Upsample low-res feature map to output_size, guided by the image.""" + q = self.image_encoder(image, output_size=output_size) + k = F.adaptive_avg_pool2d(q, output_size=features.shape[-2:]) + return self.upsampler(q, k, features) + + +def build_naf_from_state_dict(state_dict: dict) -> NAF: + """Instantiate NAF with the default hyperparams and load the given state_dict. + + The published NAF release uses the default constructor (dim=256, heads_attn=4, + heads_rope=4, kernel_size=9, rope_base=100, img_layers=2).""" + model = NAF() + missing, unexpected = model.load_state_dict(state_dict, strict=False) + if unexpected: + raise ValueError(f"Unexpected keys in NAF state_dict: {sorted(unexpected)[:8]}...") + return model diff --git a/comfy/ldm/trellis2/vae.py b/comfy/ldm/trellis2/vae.py index 5336f4dc7..d1d482814 100644 --- a/comfy/ldm/trellis2/vae.py +++ b/comfy/ldm/trellis2/vae.py @@ -75,13 +75,9 @@ def sparse_conv3d_forward(self, x): class LayerNorm32(nn.LayerNorm): def forward(self, x: torch.Tensor) -> torch.Tensor: - x_dtype = x.dtype - x = x.to(torch.float32) - w = self.weight.to(torch.float32) if self.weight is not None else None - b = self.bias.to(torch.float32) if self.bias is not None else None - - o = F.layer_norm(x, self.normalized_shape, w, b, self.eps) - return o.to(x_dtype) + w = self.weight.to(x.dtype) if self.weight is not None else None + b = self.bias.to(x.dtype) if self.bias is not None else None + return F.layer_norm(x, self.normalized_shape, w, b, self.eps) class SparseConvNeXtBlock3d(nn.Module): def __init__( @@ -204,7 +200,6 @@ class SparseResBlockC2S3d(nn.Module): self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6) self.conv1 = SparseConv3d(channels, self.out_channels * 8, 3) self.conv2 = SparseConv3d(self.out_channels, self.out_channels, 3) - self.skip_connection = lambda x: x.replace(x.feats.repeat_interleave(out_channels // (channels // 8), dim=1)) if pred_subdiv: self.to_subdiv = SparseLinear(channels, 8) self.updown = SparseChannel2Spatial(2) @@ -215,15 +210,16 @@ class SparseResBlockC2S3d(nn.Module): x = x.to(dtype) subdiv = self.to_subdiv(x) h = x.replace(self.norm1(x.feats)) - h = h.replace(F.silu(h.feats)) + h = h.replace(F.silu(h.feats, inplace=True)) h = self.conv1(h) subdiv_binarized = subdiv.replace(subdiv.feats > 0) if subdiv is not None else None h = self.updown(h, subdiv_binarized) x = self.updown(x, subdiv_binarized) h = h.replace(self.norm2(h.feats)) - h = h.replace(F.silu(h.feats)) + h = h.replace(F.silu(h.feats, inplace=True)) h = self.conv2(h) - h = h + self.skip_connection(x) + skip_repeat = self.out_channels // (self.channels // 8) + h.feats.view(h.feats.shape[0], x.feats.shape[1], skip_repeat).add_(x.feats.unsqueeze(-1)) if self.pred_subdiv: return h, subdiv else: @@ -1211,13 +1207,12 @@ def flexible_dual_grid_to_mesh( edge_neighbor_voxel = coords.reshape(N, 1, 1, 3) + flexible_dual_grid_to_mesh.edge_neighbor_voxel_offset # (N, 3, 4, 3) connected_voxel = edge_neighbor_voxel[intersected_flag] # (M, 4, 3) M = connected_voxel.shape[0] - # flatten connected voxel coords and lookup - conn_flat_b = torch.zeros((M * 4,), dtype=torch.long, device=coords.device) - conn_x = connected_voxel.reshape(-1, 3)[:, 0].to(torch.int32) - conn_y = connected_voxel.reshape(-1, 3)[:, 1].to(torch.int32) - conn_z = connected_voxel.reshape(-1, 3)[:, 2].to(torch.int32) + # flatten connected voxel coords and lookup. In-place to avoid extra memory allocation. W, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item()) - conn_flat = conn_flat_b * (W * H * D) + conn_x * (H * D) + conn_y * D + conn_z + cv = connected_voxel.reshape(-1, 3) + conn_flat = cv[:, 0].long() * (H * D) + conn_flat.add_(cv[:, 1].long() * D) + conn_flat.add_(cv[:, 2].long()) conn_indices = torch_hashmap.lookup_flat(conn_flat).reshape(M, 4).int() connected_voxel_valid = (conn_indices != 0xffffffff).all(dim=1) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 3baf5c501..327134d68 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -113,13 +113,21 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): unet_config['block_repeat'] = [[1, 1, 1, 1], [2, 2, 2, 2]] return unet_config - if '{}img2shape.blocks.1.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys: + def _detect_proj(sub_prefix: str, name: str): + key = '{}{}.blocks.0.cross_attn.proj_linear.weight'.format(key_prefix, sub_prefix) + if key in state_dict_keys: + unet_config["image_attn_mode_{}".format(name)] = "proj" + unet_config["proj_in_channels_{}".format(name)] = int(state_dict[key].shape[1]) + + if '{}img2shape.blocks.0.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys or \ + '{}img2shape.blocks.0.cross_attn.cross_attn_block.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys: unet_config = {} unet_config["image_model"] = "trellis2" - unet_config["init_txt_model"] = False - if '{}shape2txt.blocks.29.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys: - unet_config["init_txt_model"] = True + unet_config["init_txt_model"] = ( + '{}shape2txt.blocks.29.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys or + '{}shape2txt.blocks.29.cross_attn.cross_attn_block.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys + ) unet_config["resolution"] = 64 if metadata is not None: @@ -127,14 +135,20 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): unet_config["resolution"] = 32 unet_config["num_heads"] = 12 + + _detect_proj("img2shape", "shape") + _detect_proj("shape2txt", "texture") + _detect_proj("structure_model", "structure") return unet_config - if '{}shape2txt.blocks.29.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys: # trellis2 texture + if '{}shape2txt.blocks.29.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys or \ + '{}shape2txt.blocks.29.cross_attn.cross_attn_block.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys: # trellis2 texture unet_config = {} unet_config["image_model"] = "trellis2" unet_config["resolution"] = 64 unet_config["num_heads"] = 12 unet_config["txt_only"] = True + _detect_proj("shape2txt", "texture") return unet_config if '{}transformer.rotary_pos_emb.inv_freq'.format(key_prefix) in state_dict_keys: #stable audio dit diff --git a/comfy/supported_models.py b/comfy/supported_models.py index f02a370f0..86338f1d7 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1325,6 +1325,7 @@ class Trellis2(supported_models_base.BASE): sampling_settings = { "shift": 3.0, + "multiplier": 1.0 } memory_usage_factor = 3.5 diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py index b27ac1296..16919d149 100644 --- a/comfy_extras/nodes_model_advanced.py +++ b/comfy_extras/nodes_model_advanced.py @@ -276,13 +276,30 @@ class RescaleCFG: CATEGORY = "advanced/model" def patch(self, model, multiplier): + model_sampling = model.get_model_object("model_sampling") + is_x0_space = not isinstance(model_sampling, comfy.model_sampling.EPS) + def rescale_cfg(args): + x_orig = args["input"] + cond_scale = args["cond_scale"] + + if is_x0_space: + # Flow-matching / X0 models: cond_denoised/uncond_denoised are x_0 estimates, + # so the eps↔v conversion below would be wrong. Rescale directly in x_0 space. + x_0_cond = args["cond_denoised"] + x_0_uncond = args["uncond_denoised"] + x_0_cfg = x_0_uncond + cond_scale * (x_0_cond - x_0_uncond) + dims = tuple(range(1, x_0_cond.ndim)) + ro_pos = x_0_cond.std(dim=dims, keepdim=True) + ro_cfg = x_0_cfg.std(dim=dims, keepdim=True).clamp(min=1e-8) + x_0_rescaled = x_0_cfg * (ro_pos / ro_cfg) + x_0_final = multiplier * x_0_rescaled + (1.0 - multiplier) * x_0_cfg + return x_orig - x_0_final + cond = args["cond"] uncond = args["uncond"] - cond_scale = args["cond_scale"] sigma = args["sigma"] sigma = sigma.view(sigma.shape[:1] + (1,) * (cond.ndim - 1)) - x_orig = args["input"] #rescale cfg has to be done on v-pred model output x = x_orig / (sigma * sigma + 1.0) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 698b6b128..3484d6d6d 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -1,13 +1,60 @@ from typing_extensions import override from comfy_api.latest import ComfyExtension, IO, Types, io from comfy.ldm.trellis2.vae import SparseTensor +from comfy.ldm.trellis2.model import _build_proj_transform_matrix, _project_points_to_image +from comfy.ldm.trellis2.naf.model import build_naf_from_state_dict from comfy_extras.nodes_mesh_postprocess import pack_variable_mesh_batch import comfy.model_management +import comfy.utils +import folder_paths from PIL import Image +import logging import numpy as np +import math import torch ShapeSubdivides = io.Custom("SHAPE_SUBDIVIDES") +Pixal3DProjPack = io.Custom("PIXAL3D_PROJ_PACK") +NAFModel = io.Custom("NAF_MODEL") + + +# Pixal3D trains in a 90°-X-rotated grid frame (F_p). We un-rotate decoder outputs for +# user-facing previews/meshes, then re-rotate before feeding coords back to the shape DiT. + +def _pixal3d_unrotate_voxel_data(data: torch.Tensor) -> torch.Tensor: + if data.ndim == 4: + return data.flip(-1).permute(0, 1, 3, 2).contiguous() + if data.ndim == 5: + return data.flip(-1).permute(0, 1, 2, 4, 3).contiguous() + raise ValueError(f"unexpected voxel shape {tuple(data.shape)}") + + +def _pixal3d_rerotate_voxel_data(data: torch.Tensor) -> torch.Tensor: + if data.ndim == 4: + return data.permute(0, 1, 3, 2).flip(-1).contiguous() + if data.ndim == 5: + return data.permute(0, 1, 2, 4, 3).flip(-1).contiguous() + raise ValueError(f"unexpected voxel shape {tuple(data.shape)}") + + +def _pixal3d_unrotate_vertices(vertices: torch.Tensor) -> torch.Tensor: + if vertices.numel() == 0: + return vertices + x, y, z = vertices.unbind(-1) + return torch.stack([-x, y, -z], dim=-1).contiguous() + + +def _pixal3d_unrotate_sparse_coords(coords: torch.Tensor, resolution: int) -> torch.Tensor: + if coords.numel() == 0: + return coords + R1 = resolution - 1 + if coords.shape[-1] == 4: + b, i, j, k = coords.unbind(-1) + return torch.stack([b, R1 - i, j, R1 - k], dim=-1).contiguous() + if coords.shape[-1] == 3: + i, j, k = coords.unbind(-1) + return torch.stack([R1 - i, j, R1 - k], dim=-1).contiguous() + raise ValueError(f"unexpected coord shape {tuple(coords.shape)}") def prepare_trellis_vae_for_decode(vae, sample_shape): memory_required = vae.memory_used_decode(sample_shape, vae.vae_dtype) @@ -163,6 +210,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): prepare_trellis_vae_for_decode(vae, sample_tensor.shape) trellis_vae = vae.first_stage_model coord_counts = samples.get("coord_counts") + pixal3d_mode = samples.get("model_options", {}).get("proj_feat_pack") is not None samples = samples["samples"] if coord_counts is None: @@ -188,6 +236,10 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): coords_list = [stage_tensor.coords for stage_tensor in stage_tensors] subs.append(SparseTensor.from_tensor_list(feats_list, coords_list)) + if pixal3d_mode: + for m in mesh: + m.vertices = _pixal3d_unrotate_vertices(m.vertices) + face_list = [m.faces for m in mesh] vert_list = [m.vertices for m in mesh] if all(v.shape == vert_list[0].shape for v in vert_list) and all(f.shape == face_list[0].shape for f in face_list): @@ -224,6 +276,7 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): prepare_trellis_vae_for_decode(vae, sample_tensor.shape) trellis_vae = vae.first_stage_model coord_counts = samples.get("coord_counts") + pixal3d_mode = samples.get("model_options", {}).get("proj_feat_pack") is not None samples = samples["samples"] samples, coords = flatten_batched_sparse_latent(samples, coords, coord_counts) @@ -237,7 +290,17 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): color_feats = voxel.feats[:, :3] voxel_coords = voxel.coords#[:, 1:] - voxel = Types.VOXEL(voxel_coords, color_feats, 1024) + if voxel_coords.numel() > 0 and voxel_coords.shape[-1] >= 3: + spatial = voxel_coords[:, -3:] if voxel_coords.shape[-1] == 4 else voxel_coords + max_idx = int(spatial.max().item()) + 1 + tex_resolution = next((r for r in (256, 512, 1024, 1536, 2048) if r >= max_idx), max_idx) + else: + tex_resolution = 1024 + + if pixal3d_mode: + voxel_coords = _pixal3d_unrotate_sparse_coords(voxel_coords, resolution=tex_resolution) + + voxel = Types.VOXEL(voxel_coords, color_feats, tex_resolution) return IO.NodeOutput(voxel) class VaeDecodeStructureTrellis2(IO.ComfyNode): @@ -274,7 +337,10 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode): if current_res != resolution: ratio = current_res // resolution decoded = torch.nn.functional.max_pool3d(decoded.float(), ratio, ratio, 0) > 0.5 - out = Types.VOXEL(decoded.squeeze(1).float()) + voxel_data = decoded.squeeze(1).float() + if samples.get("model_options", {}).get("proj_feat_pack") is not None: + voxel_data = _pixal3d_unrotate_voxel_data(voxel_data) + out = Types.VOXEL(voxel_data) return IO.NodeOutput(out) class Trellis2UpsampleCascade(IO.ComfyNode): @@ -540,7 +606,6 @@ class Trellis2Conditioning(IO.ComfyNode): cropped_rgba = rgba_pil.crop((crop_x1, crop_y1, crop_x2, crop_y2)) cropped_np = np.array(cropped_rgba).astype(np.float32) / 255.0 else: - import logging logging.warning("Mask for the image is empty. Trellis2 requires an image with a mask for the best mesh quality.") cropped_np = rgba_np.astype(np.float32) / 255.0 @@ -587,7 +652,12 @@ class EmptyTrellis2ShapeLatent(IO.ComfyNode): "Shape structure input. Accepts either a voxel structure " "or upsampled voxel coordinates from a previous cascade stage." ) - ) + ), + Pixal3DProjPack.Input( + "proj_feat_pack", + optional=True, + tooltip="Pixal3D pixel-aligned projection pack from Pixal3DConditioning. Leave empty for vanilla Trellis2.", + ), ], outputs=[ IO.Latent.Output(), @@ -595,21 +665,26 @@ class EmptyTrellis2ShapeLatent(IO.ComfyNode): ) @classmethod - def execute(cls, voxel): - # to accept the upscaled coords + def execute(cls, voxel, proj_feat_pack=None): is_512_pass = False + coord_resolution = None upsampled = hasattr(voxel, "upsampled") if upsampled: + if hasattr(voxel, "resolutions") and voxel.resolutions is not None: + coord_resolution = int(voxel.resolutions[0].item()) // 16 voxel = voxel.data if not upsampled: - decoded = voxel.data.unsqueeze(1) + voxel_data = voxel.data + if proj_feat_pack is not None: + voxel_data = _pixal3d_rerotate_voxel_data(voxel_data) + decoded = voxel_data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() is_512_pass = True + coord_resolution = int(decoded.shape[-1]) else: coords = voxel.int() - is_512_pass = False batch_size, counts, max_tokens = infer_batched_coord_layout(coords) in_channels = 32 @@ -620,8 +695,13 @@ class EmptyTrellis2ShapeLatent(IO.ComfyNode): generation_mode = "shape_generation_512" else: generation_mode = "shape_generation" + model_options = {"generation_mode": generation_mode, "coords": coords, "coord_counts": counts} + if coord_resolution is not None: + model_options["coord_resolution"] = coord_resolution + if proj_feat_pack is not None: + model_options["proj_feat_pack"] = proj_feat_pack return IO.NodeOutput({"samples": latent, "coords": coords, "coord_counts": counts, "type": "trellis2", - "model_options": {"generation_mode": generation_mode, "coords": coords, "coord_counts": counts}}) + "model_options": model_options}) class EmptyTrellis2LatentTexture(IO.ComfyNode): @classmethod @@ -638,6 +718,11 @@ class EmptyTrellis2LatentTexture(IO.ComfyNode): ) ), IO.Latent.Input("shape_latent"), + Pixal3DProjPack.Input( + "proj_feat_pack", + optional=True, + tooltip="Pixal3D pixel-aligned projection pack from Pixal3DConditioning. Leave empty for vanilla Trellis2.", + ), ], outputs=[ IO.Latent.Output(), @@ -645,15 +730,22 @@ class EmptyTrellis2LatentTexture(IO.ComfyNode): ) @classmethod - def execute(cls, voxel, shape_latent): + def execute(cls, voxel, shape_latent, proj_feat_pack=None): channels = 32 + coord_resolution = None upsampled = hasattr(voxel, "upsampled") if upsampled: + if hasattr(voxel, "resolutions") and voxel.resolutions is not None: + coord_resolution = int(voxel.resolutions[0].item()) // 16 voxel = voxel.data if not upsampled: - decoded = voxel.data.unsqueeze(1) + voxel_data = voxel.data + if proj_feat_pack is not None: + voxel_data = _pixal3d_rerotate_voxel_data(voxel_data) + decoded = voxel_data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() + coord_resolution = int(decoded.shape[-1]) else: coords = voxel.int() @@ -664,9 +756,13 @@ class EmptyTrellis2LatentTexture(IO.ComfyNode): shape_latent = shape_latent.squeeze(-1).transpose(1, 2).reshape(-1, channels) latent = torch.zeros(batch_size, channels, max_tokens, 1) + model_options = {"generation_mode": "texture_generation", "coords": coords, "coord_counts": counts, "shape_slat": shape_latent} + if coord_resolution is not None: + model_options["coord_resolution"] = coord_resolution + if proj_feat_pack is not None: + model_options["proj_feat_pack"] = proj_feat_pack return IO.NodeOutput({"samples": latent, "type": "trellis2", "coords": coords, "coord_counts": counts, - "model_options": {"generation_mode": "texture_generation", - "coords": coords, "coord_counts": counts, "shape_slat": shape_latent}}) + "model_options": model_options}) class EmptyTrellis2LatentStructure(IO.ComfyNode): @@ -677,27 +773,441 @@ class EmptyTrellis2LatentStructure(IO.ComfyNode): category="latent/3d", inputs=[ IO.Int.Input("batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."), + Pixal3DProjPack.Input( + "proj_feat_pack", + optional=True, + tooltip="Pixal3D pixel-aligned projection pack. Leave empty for vanilla Trellis2.", + ), ], outputs=[ IO.Latent.Output(), ] ) @classmethod - def execute(cls, batch_size): - in_channels = 8 + def execute(cls, batch_size, proj_feat_pack=None): + # Trellis2.forward slices x[:, :8] and pads out to 32; KSampler residual math + # needs the empty latent to match latent_format (32-channel). + in_channels = 32 resolution = 16 latent = torch.zeros(batch_size, in_channels, resolution, resolution, resolution) output = { "samples": latent, "type": "trellis2", } + if proj_feat_pack is not None: + output["model_options"] = {"proj_feat_pack": proj_feat_pack} return IO.NodeOutput(output) +def _dinov3_patches_to_2d(tokens, image_size, patch_size=16): + h_p = w_p = image_size // patch_size + n_patches = h_p * w_p + n_reg = tokens.shape[1] - 1 - n_patches + if n_reg < 0 or tokens.shape[1] != 1 + n_reg + n_patches: + raise ValueError( + f"_dinov3_patches_to_2d: got {tokens.shape[1]} tokens, expected " + f"1 (CLS) + N_reg + {h_p}*{w_p}={n_patches} patches at image_size={image_size}, " + f"patch_size={patch_size}. Inferred N_reg={n_reg} which is invalid." + ) + start = 1 + n_reg + patches = tokens[:, start:start + n_patches] + return patches.transpose(1, 2).reshape(tokens.shape[0], -1, h_p, w_p).contiguous() + + +def _fov_from_moge_intrinsics(moge_intrinsics: torch.Tensor) -> float: + fx = moge_intrinsics[..., 0, 0].float() + fov = 2.0 * torch.atan(0.5 / fx.clamp(min=1e-4)) + return float(fov.mean().item()) + + +def _run_dinov3_with_patches(model, cropped_pil, image_size): + # Pixal3D's cross-attn was trained against CLS + registers only (~5 tokens), not the + # full patch grid. The patch grid goes to the proj branch via patches_2d. + model_internal = model.model + torch_device = comfy.model_management.get_torch_device() + resized = cropped_pil.resize((image_size, image_size), Image.Resampling.LANCZOS) + img_np = np.array(resized).astype(np.float32) / 255.0 + img_t = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(torch_device) + img_t = (img_t - dino_mean.to(torch_device)) / dino_std.to(torch_device) + model_internal.image_size = image_size + tokens = model_internal(img_t, skip_norm_elementwise=True)[0] + patches = _dinov3_patches_to_2d(tokens, image_size) + h_p = w_p = image_size // 16 + n_reg = tokens.shape[1] - 1 - h_p * w_p + global_tokens = tokens[:, :1 + n_reg] + return {"tokens": global_tokens, "patches_2d": patches} + + +def _crop_image_with_mask(item_image, item_mask, max_image_size=1024): + img_np = (item_image.cpu().numpy() * 255).clip(0, 255).astype(np.uint8) + mask_np = (item_mask.cpu().numpy() * 255).clip(0, 255).astype(np.uint8) + pil_img = Image.fromarray(img_np) + pil_mask = Image.fromarray(mask_np) + max_size = max(pil_img.size) + scale = min(1.0, max_image_size / max_size) + if scale < 1.0: + new_w, new_h = int(pil_img.width * scale), int(pil_img.height * scale) + pil_img = pil_img.resize((new_w, new_h), Image.Resampling.LANCZOS) + pil_mask = pil_mask.resize((new_w, new_h), Image.Resampling.NEAREST) + scene_size = (pil_img.width, pil_img.height) + rgba_np = np.zeros((pil_img.height, pil_img.width, 4), dtype=np.uint8) + rgba_np[:, :, :3] = np.array(pil_img) + rgba_np[:, :, 3] = np.array(pil_mask) + alpha = rgba_np[:, :, 3] + bbox_coords = np.argwhere(alpha > 0.8 * 255) + if len(bbox_coords) > 0: + y_min, x_min = np.min(bbox_coords[:, 0]), np.min(bbox_coords[:, 1]) + y_max, x_max = np.max(bbox_coords[:, 0]), np.max(bbox_coords[:, 1]) + center_y, center_x = (y_min + y_max) / 2.0, (x_min + x_max) / 2.0 + # Upstream pads the bbox by 10% — encoders were trained with that breathing room. + size = max(y_max - y_min, x_max - x_min) + size = int(size * 1.1) + half = size // 2 + crop_x1 = int(center_x - half) + crop_y1 = int(center_y - half) + crop_x2 = crop_x1 + 2 * half + crop_y2 = crop_y1 + 2 * half + crop_bbox = (crop_x1, crop_y1, crop_x2, crop_y2) + rgba_pil = Image.fromarray(rgba_np) + cropped_rgba = rgba_pil.crop(crop_bbox) + cropped_np = np.array(cropped_rgba).astype(np.float32) / 255.0 + else: + logging.warning("Mask for the image is empty. Pixal3D requires a clean foreground mask.") + cropped_np = rgba_np.astype(np.float32) / 255.0 + crop_bbox = (0, 0, scene_size[0], scene_size[1]) + fg = cropped_np[:, :, :3] + alpha_float = cropped_np[:, :, 3:4] + composite_np = fg * alpha_float + composite_uint8 = (composite_np * 255.0).round().clip(0, 255).astype(np.uint8) + return Image.fromarray(composite_uint8), crop_bbox, scene_size + + +class Pixal3DConditioning(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="Pixal3DConditioning", + category="conditioning/video_models", + inputs=[ + IO.ClipVision.Input("clip_vision_model", tooltip="DINOv3 ViT-L/16 ClipVision."), + IO.Image.Input("image"), + IO.Mask.Input("mask"), + IO.Float.Input( + "camera_angle_x", default=0.2, min=0.0175, max=2.9671, step=0.001, + tooltip="Horizontal FOV in radians (upstream demo default 0.2). " + "Overridden by moge_geometry if connected.", + ), + IO.Float.Input( + "mesh_scale", default=1.0, min=0.1, max=4.0, step=0.01, + tooltip="Mesh scale; 1.0 means unit cube.", + ), + IO.Float.Input( + "distance_override", default=0.0, min=0.0, max=10.0, step=0.001, + tooltip="Override camera distance directly. 0 = auto-derive from FOV.", + ), + io.Custom("MOGE_GEOMETRY").Input( + "moge_geometry", + optional=True, + tooltip="If connected, camera_angle_x is recovered from MoGe.", + ), + NAFModel.Input( + "naf_model", + optional=True, + tooltip="Optional NAF feature upsampler. Required for shape/texture stages " + "to match upstream's trained feature distribution.", + ), + ], + outputs=[ + IO.Conditioning.Output(display_name="positive"), + IO.Conditioning.Output(display_name="negative"), + Pixal3DProjPack.Output(display_name="proj_feat_pack"), + ], + ) + + @classmethod + def execute(cls, clip_vision_model, image, mask, camera_angle_x, mesh_scale, + distance_override=0.0, + moge_geometry=None, naf_model=None) -> IO.NodeOutput: + if image.ndim == 3: + image = image.unsqueeze(0) + if mask.ndim == 2: + mask = mask.unsqueeze(0) + batch_size = image.shape[0] + if mask.shape[0] == 1 and batch_size > 1: + mask = mask.expand(batch_size, -1, -1) + elif mask.shape[0] != batch_size: + raise ValueError(f"Pixal3DConditioning mask batch {mask.shape[0]} != image batch {batch_size}") + + if moge_geometry is not None and "intrinsics" in moge_geometry: + camera_angle_x = _fov_from_moge_intrinsics(moge_geometry["intrinsics"]) + + device = comfy.model_management.intermediate_device() + + cond_512_list, cond_1024_list = [], [] + patches_512_list, patches_1024_list = [], [] + cropped_pil_list = [] + crop_bbox_list, scene_size_list = [], [] + + torch_device = comfy.model_management.get_torch_device() + for b in range(batch_size): + item_image = image[b] + item_mask = mask[b] if mask.size(0) > 1 else mask[0] + cropped_pil, crop_bbox, scene_size = _crop_image_with_mask( + item_image, item_mask, max_image_size=1024) + crop_bbox_list.append(crop_bbox) + scene_size_list.append(scene_size) + cropped_pil_list.append(cropped_pil) + + cond_512 = _run_dinov3_with_patches(clip_vision_model, cropped_pil, 512) + cond_1024 = _run_dinov3_with_patches(clip_vision_model, cropped_pil, 1024) + cond_512_list.append(cond_512["tokens"].to(device)) + cond_1024_list.append(cond_1024["tokens"].to(device)) + patches_512_list.append(cond_512["patches_2d"].to(device)) + patches_1024_list.append(cond_1024["patches_2d"].to(device)) + + global_512 = torch.cat(cond_512_list, dim=0) + global_1024 = torch.cat(cond_1024_list, dim=0) + + fm_512_dino = torch.cat(patches_512_list, dim=0) + fm_1024_dino = torch.cat(patches_1024_list, dim=0) + + # Upstream samples the LR DINO grid AND the NAF HR grid separately at projected + # 3D points, then cats sampled features along channels. Back-projection (in model.py) + # mirrors that — here we just stash LR + optional HR per stage. + # NAF targets per stage: shape_512=512, shape_1024=512, tex_1024=1024. + def _naf_hr(lr_feat, image_pil_list, image_size, naf_target): + if naf_model is None or naf_target is None: + return None + # Run NAF in the input feature dtype (typically fp16 since DINO/ClipVision + # loads that way). The previous .float() cast doubled NAF memory by forcing + # full fp32 — at tex_1024/target=1024 that's ~10 GB on its own. Model + # weights need to match input dtype since PyTorch conv ops error out on + # mixed fp16-input/fp32-weight. + target_dtype = lr_feat.dtype + if next(naf_model.parameters()).dtype != target_dtype: + naf_model.to(dtype=target_dtype) + imgs = torch.stack([ + torch.from_numpy( + np.array(p.resize((image_size, image_size), Image.Resampling.LANCZOS)) + .astype(np.float32) / 255.0 + ).permute(2, 0, 1) + for p in image_pil_list + ], dim=0).to(torch_device).to(target_dtype) + + hr = naf_model(imgs, lr_feat.to(torch_device).to(target_dtype), naf_target) + return hr.to(device) + + hr_shape_512 = _naf_hr(fm_512_dino, cropped_pil_list, 512, (512, 512)) + hr_shape_1024 = _naf_hr(fm_1024_dino, cropped_pil_list, 1024, (512, 512)) + hr_tex_1024 = _naf_hr(fm_1024_dino, cropped_pil_list, 1024, (1024, 1024)) + + # distance_from_fov: grid_point (-1, 0, 0) projects to pixel (0, image_resolution-1). + camera_angle_x = float(camera_angle_x) + if distance_override > 0: + distance = float(distance_override) + else: + distance = 0.5 / math.tan(camera_angle_x / 2.0) / float(mesh_scale) + cam_angle_t = torch.tensor([camera_angle_x] * batch_size, device=device, dtype=torch.float32) + dist_t = torch.tensor([distance] * batch_size, device=device, dtype=torch.float32) + scale_t = torch.tensor([float(mesh_scale)] * batch_size, device=device, dtype=torch.float32) + T = _build_proj_transform_matrix(dist_t, batch_size, device=device, dtype=torch.float32) + + proj_pack = { + "stages": { + "ss": {"feature_map": fm_512_dino, "feature_map_hr": None, "image_resolution": 512}, + "shape_512": {"feature_map": fm_512_dino, "feature_map_hr": hr_shape_512, "image_resolution": 512}, + "shape_1024": {"feature_map": fm_1024_dino, "feature_map_hr": hr_shape_1024,"image_resolution": 1024}, + "tex_1024": {"feature_map": fm_1024_dino, "feature_map_hr": hr_tex_1024, "image_resolution": 1024}, + }, + "transform_matrix": T, + "camera_angle_x": cam_angle_t, + "mesh_scale": scale_t, + "distance": dist_t, + "patch_size": 16, + "crop_bboxes": crop_bbox_list, + "scene_sizes": scene_size_list, + } + + # global_512 → SS/shape_512 cross-attn; global_1024 → shape_1024/tex_1024 + # (Trellis2.forward swaps context↔embeds for non-structure HR stages). + neg_global = torch.zeros_like(global_512) + neg_embeds = torch.zeros_like(global_1024) + positive = [[global_512, {"embeds": global_1024}]] + negative = [[neg_global, {"embeds": neg_embeds}]] + return IO.NodeOutput(positive, negative, proj_pack) + + +def _project_vertices_to_image_uv(vertices_world, transform_matrix, camera_angle_x, image_resolution): + points = vertices_world.unsqueeze(0).float() + T = transform_matrix.unsqueeze(0).float() if transform_matrix.ndim == 2 else transform_matrix.float() + cam = camera_angle_x.unsqueeze(0) if camera_angle_x.ndim == 0 else camera_angle_x + uv_pix, depth, valid = _project_points_to_image(points, T, cam.float(), image_resolution) + uv = uv_pix.squeeze(0) / image_resolution + return uv, depth.squeeze(0), valid.squeeze(0) + + +def _crop_uv_to_scene_pixels(uv_crop, crop_bbox, scene_image_size): + crop_x1, crop_y1, crop_x2, crop_y2 = crop_bbox + crop_w = max(1, crop_x2 - crop_x1) + crop_h = max(1, crop_y2 - crop_y1) + px = uv_crop[:, 0] * crop_w + crop_x1 + py = uv_crop[:, 1] * crop_h + crop_y1 + W, H = scene_image_size + return torch.stack([px.clamp(0, W - 1), py.clamp(0, H - 1)], dim=-1) + + +class Pixal3DAlignObject(IO.ComfyNode): + """Pixal3D paper §3.3 Global Alignment for a single object. + + Solves (scale, translation) aligning the mesh to MoGe's per-pixel point map. Requires + MoGe to have been computed on the same resized scene image as Pixal3DConditioning.""" + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="Pixal3DAlignObject", + category="latent/3d", + inputs=[ + IO.Mesh.Input("mesh"), + Pixal3DProjPack.Input("proj_feat_pack", tooltip="The proj pack produced by Pixal3DConditioning for this object."), + io.Custom("MOGE_GEOMETRY").Input("moge_geometry", tooltip="MoGe geometry computed on the original scene image."), + IO.Mask.Input( + "object_mask", + optional=True, + tooltip="Optional per-object scene-space mask. If connected, only vertices whose projected pixel falls inside the mask contribute to the alignment solve.", + ), + IO.Int.Input( + "batch_index", + default=0, min=0, max=1024, + tooltip="Which batch slot of the proj_feat_pack/MoGe geometry corresponds to this object.", + ), + ], + outputs=[ + IO.Mesh.Output("aligned_mesh"), + IO.Float.Output(display_name="scale"), + ], + ) + + @classmethod + def execute(cls, mesh, proj_feat_pack, moge_geometry, object_mask=None, batch_index=0) -> IO.NodeOutput: + vertices = mesh.vertices + faces = mesh.faces + if vertices.ndim == 3: + vertices_one = vertices[0] + faces_one = faces[0] + else: + vertices_one = vertices + faces_one = faces + + T = proj_feat_pack["transform_matrix"][batch_index:batch_index + 1] + cam_angle = proj_feat_pack["camera_angle_x"][batch_index:batch_index + 1] + mesh_scale = proj_feat_pack["mesh_scale"][batch_index] + image_resolution = int(proj_feat_pack.get("image_resolution", 1024)) + crop_bbox = proj_feat_pack["crop_bboxes"][batch_index] + pack_scene_size = proj_feat_pack.get("scene_sizes", [None] * (batch_index + 1))[batch_index] + moge_points = moge_geometry["points"] + moge_mask = moge_geometry["mask"] + if moge_points.ndim != 4: + raise ValueError(f"MoGe points expected [B, H, W, 3]; got {tuple(moge_points.shape)}") + scene_H, scene_W = moge_points.shape[1], moge_points.shape[2] + if pack_scene_size is not None and pack_scene_size != (scene_W, scene_H): + raise ValueError( + f"Pixal3DAlignObject: MoGe geometry was computed on a {scene_W}x{scene_H} image, " + f"but the proj_feat_pack's bbox lives in a {pack_scene_size[0]}x{pack_scene_size[1]} " + "image. Run MoGe on the same resized scene image Pixal3DConditioning used." + ) + + # Compose VaeDecodeShapeTrellis's R_y(180°) inverse with R_proj to map user mesh + # space to ProjGrid world: (X, Y, Z) -> (-X, Z, Y). + v = vertices_one.float() + verts_world = torch.stack([-v[..., 0], v[..., 2], v[..., 1]], dim=-1) + verts_world = verts_world / float(mesh_scale.item()) + uv_crop, _depth, valid = _project_vertices_to_image_uv( + verts_world, T[0], cam_angle[0], image_resolution) + scene_pixels = _crop_uv_to_scene_pixels(uv_crop, crop_bbox, (scene_W, scene_H)) + in_scene = ((scene_pixels[:, 0] >= 0) & (scene_pixels[:, 0] < scene_W) & + (scene_pixels[:, 1] >= 0) & (scene_pixels[:, 1] < scene_H)) + sx = scene_pixels[:, 0].long().clamp(0, scene_W - 1) + sy = scene_pixels[:, 1].long().clamp(0, scene_H - 1) + moge_per_vertex = moge_points[batch_index, sy, sx] + moge_mask_per_vertex = moge_mask[batch_index, sy, sx] + keep = valid & in_scene & moge_mask_per_vertex + if object_mask is not None: + om = object_mask if object_mask.ndim == 2 else object_mask[batch_index] + keep = keep & (om[sy, sx] > 0.5) + + finite = torch.isfinite(moge_per_vertex).all(dim=-1) + keep = keep & finite + + kept = int(keep.sum().item()) + if kept < 8: + scale = 1.0 + aligned = vertices_one + else: + P = vertices_one[keep].float() + Q = moge_per_vertex[keep].float() + p_mean = P.mean(dim=0, keepdim=True) + q_mean = Q.mean(dim=0, keepdim=True) + P_c = P - p_mean + Q_c = Q - q_mean + num = (P_c * Q_c).sum() + den = (P_c * P_c).sum().clamp(min=1e-8) + scale = float((num / den).item()) + if not (scale > 0): + # Negative scale would mirror the mesh; treat as a camera-convention mismatch. + logging.warning( + f"Pixal3DAlignObject: computed scale={scale:.4f} <= 0; " + "refusing to apply mirroring. Check camera convention alignment.") + scale = 1.0 + aligned = vertices_one + else: + t = q_mean - scale * p_mean + aligned = scale * vertices_one + t + + if vertices.ndim == 3: + aligned = aligned.unsqueeze(0) + out_mesh = Types.MESH(vertices=aligned, faces=faces) + else: + out_mesh = Types.MESH(vertices=aligned, faces=faces_one) + return IO.NodeOutput(out_mesh, float(scale)) + + +class LoadNAFModel(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="LoadNAFModel", + display_name="Load NAF Model", + category="loaders", + inputs=[ + IO.Combo.Input( + "naf_name", + options=folder_paths.get_filename_list("upscale_models"), + tooltip="NAF safetensors checkpoint (e.g. naf_release.safetensors).", + ), + ], + outputs=[NAFModel.Output(display_name="naf_model")], + ) + + @classmethod + def execute(cls, naf_name) -> IO.NodeOutput: + path = folder_paths.get_full_path_or_raise("upscale_models", naf_name) + sd = comfy.utils.load_torch_file(path, safe_load=True) + model = build_naf_from_state_dict(sd) + device = comfy.model_management.get_torch_device() + model = model.to(device).eval() + return IO.NodeOutput(model) + + class Trellis2Extension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: return [ Trellis2Conditioning, + Pixal3DConditioning, + Pixal3DAlignObject, + LoadNAFModel, EmptyTrellis2ShapeLatent, EmptyTrellis2LatentStructure, EmptyTrellis2LatentTexture,