mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 21:20:49 +08:00
initial Pixal3D support
This commit is contained in:
parent
5b981ed295
commit
9e4794da5c
@ -5,8 +5,12 @@ from typing import Tuple, Union, List
|
||||
from comfy.ldm.trellis2.vae import VarLenTensor
|
||||
import comfy.ops
|
||||
|
||||
try:
|
||||
from torch.nn.attention.varlen import varlen_attn as _varlen_attn
|
||||
except ImportError:
|
||||
_varlen_attn = None
|
||||
|
||||
|
||||
# replica of the seedvr2 code
|
||||
def var_attn_arg(kwargs):
|
||||
cu_seqlens_q = kwargs.get("cu_seqlens_q", None)
|
||||
max_seqlen_q = kwargs.get("max_seqlen_q", None)
|
||||
@ -16,42 +20,30 @@ def var_attn_arg(kwargs):
|
||||
return cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k
|
||||
|
||||
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
var_length = True
|
||||
if var_length:
|
||||
cu_seqlens_q, cu_seqlens_k, _, _ = var_attn_arg(kwargs)
|
||||
if not skip_reshape:
|
||||
# assumes 2D q, k,v [total_tokens, embed_dim]
|
||||
total_tokens, embed_dim = q.shape
|
||||
head_dim = embed_dim // heads
|
||||
q = q.view(total_tokens, heads, head_dim)
|
||||
k = k.view(k.shape[0], heads, head_dim)
|
||||
v = v.view(v.shape[0], heads, head_dim)
|
||||
cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = var_attn_arg(kwargs)
|
||||
if not skip_reshape:
|
||||
total_tokens, embed_dim = q.shape
|
||||
head_dim = embed_dim // heads
|
||||
q = q.view(total_tokens, heads, head_dim)
|
||||
k = k.view(k.shape[0], heads, head_dim)
|
||||
v = v.view(v.shape[0], heads, head_dim)
|
||||
|
||||
b = q.size(0)
|
||||
dim_head = q.shape[-1]
|
||||
q = torch.nested.nested_tensor_from_jagged(q, offsets=cu_seqlens_q.long())
|
||||
k = torch.nested.nested_tensor_from_jagged(k, offsets=cu_seqlens_k.long())
|
||||
v = torch.nested.nested_tensor_from_jagged(v, offsets=cu_seqlens_k.long())
|
||||
|
||||
mask = None
|
||||
q = q.transpose(1, 2)
|
||||
k = k.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
|
||||
if mask is not None:
|
||||
if mask.ndim == 2:
|
||||
mask = mask.unsqueeze(0)
|
||||
if mask.ndim == 3:
|
||||
mask = mask.unsqueeze(1)
|
||||
|
||||
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||
if var_length:
|
||||
return out.transpose(1, 2).values()
|
||||
if not skip_output_reshape:
|
||||
out = (
|
||||
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||
if _varlen_attn is not None:
|
||||
return _varlen_attn(
|
||||
q, k, v,
|
||||
cu_seqlens_q, cu_seqlens_k,
|
||||
int(max_seqlen_q), int(max_seqlen_k),
|
||||
)
|
||||
return out
|
||||
|
||||
# Fallback: nested-tensor SDPA (PyTorch < the version that introduced varlen_attn)
|
||||
q = torch.nested.nested_tensor_from_jagged(q, offsets=cu_seqlens_q.long())
|
||||
k = torch.nested.nested_tensor_from_jagged(k, offsets=cu_seqlens_k.long())
|
||||
v = torch.nested.nested_tensor_from_jagged(v, offsets=cu_seqlens_k.long())
|
||||
q = q.transpose(1, 2)
|
||||
k = k.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
|
||||
return out.transpose(1, 2).values()
|
||||
|
||||
def scaled_dot_product_attention(*args, **kwargs):
|
||||
num_all_args = len(args) + len(kwargs)
|
||||
|
||||
@ -26,16 +26,26 @@ class TorchHashMap:
|
||||
self.default_value = torch.tensor(default_value, dtype=torch.long, device=device)
|
||||
self._n = self.sorted_keys.numel()
|
||||
|
||||
# Chunk size for lookup_flat. At ~530M flat keys (large mesh extraction),
|
||||
# the unchunked path allocates ~5 full-size int64 temporaries (4 GB each) +
|
||||
# bool masks + the int32 output. Chunking caps each transient to ~CHUNK rows.
|
||||
_LOOKUP_CHUNK = 1 << 23 # 8M rows ≈ 64 MB per int64 temp
|
||||
|
||||
def lookup_flat(self, flat_keys: torch.Tensor) -> torch.Tensor:
|
||||
flat = flat_keys.to(torch.long)
|
||||
if self._n == 0:
|
||||
return torch.full((flat.shape[0],), -1, device=flat.device, dtype=torch.int32)
|
||||
idx = torch.searchsorted(self.sorted_keys, flat)
|
||||
idx_safe = torch.clamp(idx, max=self._n - 1)
|
||||
found = (idx < self._n) & (self.sorted_keys[idx_safe] == flat)
|
||||
out = torch.full((flat.shape[0],), -1, device=flat.device, dtype=torch.int32)
|
||||
if found.any():
|
||||
out[found] = self.sorted_vals[idx_safe[found]].to(torch.int32)
|
||||
N = flat_keys.shape[0]
|
||||
out = torch.full((N,), -1, device=flat_keys.device, dtype=torch.int32)
|
||||
if self._n == 0 or N == 0:
|
||||
return out
|
||||
for s in range(0, N, self._LOOKUP_CHUNK):
|
||||
e = min(s + self._LOOKUP_CHUNK, N)
|
||||
flat_chunk = flat_keys[s:e].to(torch.long)
|
||||
idx = torch.searchsorted(self.sorted_keys, flat_chunk)
|
||||
in_range = idx < self._n
|
||||
idx.clamp_(max=self._n - 1) # reuse idx as the "safe" index
|
||||
found = in_range & (self.sorted_keys[idx] == flat_chunk)
|
||||
if found.any():
|
||||
found_idx = found.nonzero(as_tuple=True)[0]
|
||||
out[s + found_idx] = self.sorted_vals[idx[found_idx]].to(torch.int32)
|
||||
return out
|
||||
|
||||
|
||||
@ -212,10 +222,10 @@ def sparse_submanifold_conv3d(
|
||||
|
||||
if accumulate_f32:
|
||||
weight_T = weight.view(Co, V * Ci).to(torch.float32).T.contiguous()
|
||||
output = torch.zeros(N_pts, Co, device=device, dtype=torch.float32)
|
||||
else:
|
||||
weight_T = weight.view(Co, V * Ci).to(feats.dtype).T.contiguous()
|
||||
output = torch.zeros(N_pts, Co, device=device, dtype=feats.dtype)
|
||||
|
||||
output = torch.empty(N_pts, Co, device=device, dtype=feats.dtype)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Chunk size from memory budget
|
||||
@ -226,6 +236,9 @@ def sparse_submanifold_conv3d(
|
||||
chunk_size = max(1, int(max_chunk_mem / mem_per_row))
|
||||
chunk_size = min(chunk_size, N_pts)
|
||||
|
||||
# fp32 matmul scratch — sized to the largest chunk, reused each iteration.
|
||||
chunk_buf = torch.empty(chunk_size, Co, device=device, dtype=torch.float32) if accumulate_f32 else None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Chunked forward pass
|
||||
# Each iteration:
|
||||
@ -233,7 +246,8 @@ def sparse_submanifold_conv3d(
|
||||
# 2. mask zero invalids – in-place, no extra alloc
|
||||
# 3. reshape (chunk, V*Ci)
|
||||
# 4. GEMM (chunk, V*Ci) @ (V*Ci, Co) → (chunk, Co) – cuBLAS
|
||||
# written directly into output slice via out= argument
|
||||
# written into the scratch buf (fp32) or output slice (fp16) via out=
|
||||
# 5. (fp32 path) cast scratch chunk to fp16 and copy into output slice
|
||||
# ------------------------------------------------------------------
|
||||
for start in range(0, N_pts, chunk_size):
|
||||
end = min(start + chunk_size, N_pts)
|
||||
@ -257,16 +271,13 @@ def sparse_submanifold_conv3d(
|
||||
gathered_flat = gathered.view(actual_chunk, V * Ci)
|
||||
if accumulate_f32:
|
||||
gathered_flat = gathered_flat.to(torch.float32)
|
||||
|
||||
# Single GEMM call per chunk, written directly into output.
|
||||
# This avoids allocating a temporary (chunk, Co) tensor.
|
||||
torch.matmul(gathered_flat, weight_T, out=output[start:end])
|
||||
|
||||
if accumulate_f32:
|
||||
output = output.to(feats.dtype)
|
||||
torch.matmul(gathered_flat, weight_T, out=chunk_buf[:actual_chunk])
|
||||
output[start:end] = chunk_buf[:actual_chunk].to(feats.dtype)
|
||||
else:
|
||||
torch.matmul(gathered_flat, weight_T, out=output[start:end])
|
||||
|
||||
if bias is not None:
|
||||
output = output + bias.unsqueeze(0).to(output.dtype)
|
||||
output += bias.unsqueeze(0).to(output.dtype)
|
||||
|
||||
return output, neighbor
|
||||
|
||||
|
||||
@ -25,15 +25,12 @@ class SparseFeedForwardNet(nn.Module):
|
||||
def forward(self, x: VarLenTensor) -> VarLenTensor:
|
||||
return self.mlp(x)
|
||||
|
||||
def manual_cast(obj, dtype):
|
||||
return obj.to(dtype=dtype)
|
||||
|
||||
class LayerNorm32(nn.LayerNorm):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x_dtype = x.dtype
|
||||
x = manual_cast(x, torch.float32)
|
||||
x = x.to(dtype=torch.float32)
|
||||
o = super().forward(x)
|
||||
return manual_cast(o, x_dtype)
|
||||
return o.to(dtype=x_dtype)
|
||||
|
||||
|
||||
class SparseMultiHeadRMSNorm(nn.Module):
|
||||
@ -249,6 +246,51 @@ class SparseMultiHeadAttention(nn.Module):
|
||||
h = self._linear(self.to_out, h)
|
||||
return h
|
||||
|
||||
def _split_proj_context(context):
|
||||
if not isinstance(context, dict):
|
||||
return context, None
|
||||
global_ctx = context["global"]
|
||||
if "proj" in context:
|
||||
return global_ctx, context["proj"]
|
||||
if "proj_semantic" in context and "proj_color" in context:
|
||||
return global_ctx, (context["proj_semantic"], context["proj_color"])
|
||||
return global_ctx, None
|
||||
|
||||
|
||||
class ProjectAttentionSparse(nn.Module):
|
||||
def __init__(self, cross_attn_block: nn.Module, channels: int, proj_in_channels: int,
|
||||
device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.cross_attn_block = cross_attn_block
|
||||
self.proj_linear = operations.Linear(proj_in_channels, channels, bias=True,
|
||||
device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x: SparseTensor, context) -> SparseTensor:
|
||||
global_ctx, proj_in = _split_proj_context(context)
|
||||
global_out = self.cross_attn_block(x, global_ctx)
|
||||
if isinstance(proj_in, tuple):
|
||||
proj_in = torch.cat([proj_in[0], proj_in[1]], dim=-1)
|
||||
proj_out = self.proj_linear(proj_in.to(self.proj_linear.weight.dtype))
|
||||
return global_out.replace(global_out.feats + proj_out.to(global_out.feats.dtype))
|
||||
|
||||
|
||||
class ProjectAttentionDense(nn.Module):
|
||||
def __init__(self, cross_attn_block: nn.Module, channels: int, proj_in_channels: int,
|
||||
device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.cross_attn_block = cross_attn_block
|
||||
self.proj_linear = operations.Linear(proj_in_channels, channels, bias=True,
|
||||
device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x: torch.Tensor, context) -> torch.Tensor:
|
||||
global_ctx, proj_in = _split_proj_context(context)
|
||||
global_out = self.cross_attn_block(x, global_ctx)
|
||||
if isinstance(proj_in, tuple):
|
||||
proj_in = torch.cat([proj_in[0], proj_in[1]], dim=-1)
|
||||
proj_out = self.proj_linear(proj_in.to(self.proj_linear.weight.dtype))
|
||||
return global_out + proj_out.to(global_out.dtype)
|
||||
|
||||
|
||||
class ModulatedSparseTransformerCrossBlock(nn.Module):
|
||||
"""
|
||||
Sparse Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning.
|
||||
@ -269,11 +311,14 @@ class ModulatedSparseTransformerCrossBlock(nn.Module):
|
||||
qk_rms_norm_cross: bool = False,
|
||||
qkv_bias: bool = True,
|
||||
share_mod: bool = False,
|
||||
image_attn_mode: Literal["global", "proj", "gated_proj"] = "global",
|
||||
proj_in_channels: Optional[int] = None,
|
||||
device=None, dtype=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.share_mod = share_mod
|
||||
self.image_attn_mode = image_attn_mode
|
||||
self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device)
|
||||
self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6, device=device)
|
||||
self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device)
|
||||
@ -290,7 +335,7 @@ class ModulatedSparseTransformerCrossBlock(nn.Module):
|
||||
qk_rms_norm=qk_rms_norm,
|
||||
device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
self.cross_attn = SparseMultiHeadAttention(
|
||||
cross_inner = SparseMultiHeadAttention(
|
||||
channels,
|
||||
ctx_channels=ctx_channels,
|
||||
num_heads=num_heads,
|
||||
@ -300,6 +345,15 @@ class ModulatedSparseTransformerCrossBlock(nn.Module):
|
||||
qk_rms_norm=qk_rms_norm_cross,
|
||||
device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
if image_attn_mode == "global":
|
||||
self.cross_attn = cross_inner
|
||||
else:
|
||||
if proj_in_channels is None:
|
||||
raise ValueError("proj_in_channels must be set when image_attn_mode != 'global'")
|
||||
self.cross_attn = ProjectAttentionSparse(
|
||||
cross_inner, channels, proj_in_channels,
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
self.mlp = SparseFeedForwardNet(
|
||||
channels,
|
||||
mlp_ratio=mlp_ratio,
|
||||
@ -313,7 +367,7 @@ class ModulatedSparseTransformerCrossBlock(nn.Module):
|
||||
else:
|
||||
self.modulation = nn.Parameter(torch.randn(6 * channels, device=device, dtype=dtype) / channels ** 0.5)
|
||||
|
||||
def _forward(self, x: SparseTensor, mod: torch.Tensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor:
|
||||
def _forward(self, x: SparseTensor, mod: torch.Tensor, context) -> SparseTensor:
|
||||
if self.share_mod:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1)
|
||||
else:
|
||||
@ -324,7 +378,11 @@ class ModulatedSparseTransformerCrossBlock(nn.Module):
|
||||
h = h * gate_msa
|
||||
x = x + h
|
||||
h = x.replace(self.norm2(x.feats))
|
||||
h = self.cross_attn(h, context)
|
||||
if self.image_attn_mode == "global":
|
||||
global_ctx, _ = _split_proj_context(context)
|
||||
h = self.cross_attn(h, global_ctx)
|
||||
else:
|
||||
h = self.cross_attn(h, context)
|
||||
x = x + h
|
||||
h = x.replace(self.norm3(x.feats))
|
||||
h = h * (1 + scale_mlp) + shift_mlp
|
||||
@ -333,7 +391,7 @@ class ModulatedSparseTransformerCrossBlock(nn.Module):
|
||||
x = x + h
|
||||
return x
|
||||
|
||||
def forward(self, x: SparseTensor, mod: torch.Tensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor:
|
||||
def forward(self, x: SparseTensor, mod: torch.Tensor, context) -> SparseTensor:
|
||||
return self._forward(x, mod, context)
|
||||
|
||||
|
||||
@ -356,6 +414,8 @@ class SLatFlowModel(nn.Module):
|
||||
initialization: str = 'vanilla',
|
||||
qk_rms_norm: bool = False,
|
||||
qk_rms_norm_cross: bool = False,
|
||||
image_attn_mode: Literal["global", "proj", "gated_proj"] = "global",
|
||||
proj_in_channels: Optional[int] = None,
|
||||
dtype = None,
|
||||
device = None,
|
||||
operations = None,
|
||||
@ -375,6 +435,8 @@ class SLatFlowModel(nn.Module):
|
||||
self.initialization = initialization
|
||||
self.qk_rms_norm = qk_rms_norm
|
||||
self.qk_rms_norm_cross = qk_rms_norm_cross
|
||||
self.image_attn_mode = image_attn_mode
|
||||
self.proj_in_channels = proj_in_channels
|
||||
self.dtype = dtype
|
||||
|
||||
self.t_embedder = TimestepEmbedder(model_channels, device=device, dtype=dtype, operations=operations)
|
||||
@ -399,6 +461,8 @@ class SLatFlowModel(nn.Module):
|
||||
share_mod=self.share_mod,
|
||||
qk_rms_norm=self.qk_rms_norm,
|
||||
qk_rms_norm_cross=self.qk_rms_norm_cross,
|
||||
image_attn_mode=image_attn_mode,
|
||||
proj_in_channels=proj_in_channels,
|
||||
device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
for _ in range(num_blocks)
|
||||
@ -426,19 +490,15 @@ class SLatFlowModel(nn.Module):
|
||||
dtype = next(self.input_layer.parameters()).dtype
|
||||
x = x.to(dtype)
|
||||
h = self.input_layer(x)
|
||||
h = manual_cast(h, self.dtype)
|
||||
t = t.to(dtype)
|
||||
t_embedder = self.t_embedder.to(dtype)
|
||||
t_emb = t_embedder(t, out_dtype = t.dtype)
|
||||
if self.share_mod:
|
||||
t_emb = self.adaLN_modulation(t_emb)
|
||||
t_emb = manual_cast(t_emb, self.dtype)
|
||||
cond = manual_cast(cond, self.dtype)
|
||||
|
||||
for block in self.blocks:
|
||||
h = block(h, t_emb, cond)
|
||||
|
||||
h = manual_cast(h, x.dtype)
|
||||
h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
|
||||
h = self.out_layer(h)
|
||||
return h
|
||||
@ -561,11 +621,14 @@ class ModulatedTransformerCrossBlock(nn.Module):
|
||||
qk_rms_norm_cross: bool = False,
|
||||
qkv_bias: bool = True,
|
||||
share_mod: bool = False,
|
||||
image_attn_mode: Literal["global", "proj", "gated_proj"] = "global",
|
||||
proj_in_channels: Optional[int] = None,
|
||||
device=None, dtype=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.share_mod = share_mod
|
||||
self.image_attn_mode = image_attn_mode
|
||||
self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device)
|
||||
self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6, device=device)
|
||||
self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device)
|
||||
@ -582,7 +645,7 @@ class ModulatedTransformerCrossBlock(nn.Module):
|
||||
qk_rms_norm=qk_rms_norm,
|
||||
device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
self.cross_attn = MultiHeadAttention(
|
||||
cross_inner = MultiHeadAttention(
|
||||
channels,
|
||||
ctx_channels=ctx_channels,
|
||||
num_heads=num_heads,
|
||||
@ -592,6 +655,15 @@ class ModulatedTransformerCrossBlock(nn.Module):
|
||||
qk_rms_norm=qk_rms_norm_cross,
|
||||
device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
if image_attn_mode == "global":
|
||||
self.cross_attn = cross_inner
|
||||
else:
|
||||
if proj_in_channels is None:
|
||||
raise ValueError("proj_in_channels must be set when image_attn_mode != 'global'")
|
||||
self.cross_attn = ProjectAttentionDense(
|
||||
cross_inner, channels, proj_in_channels,
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
self.mlp = FeedForwardNet(
|
||||
channels,
|
||||
mlp_ratio=mlp_ratio,
|
||||
@ -605,7 +677,7 @@ class ModulatedTransformerCrossBlock(nn.Module):
|
||||
else:
|
||||
self.modulation = nn.Parameter(torch.randn(6 * channels, device=device, dtype=dtype) / channels ** 0.5)
|
||||
|
||||
def _forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
def _forward(self, x: torch.Tensor, mod: torch.Tensor, context, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
if self.share_mod:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1)
|
||||
else:
|
||||
@ -616,7 +688,11 @@ class ModulatedTransformerCrossBlock(nn.Module):
|
||||
h = h * gate_msa.unsqueeze(1)
|
||||
x = x + h
|
||||
h = self.norm2(x)
|
||||
h = self.cross_attn(h, context)
|
||||
if self.image_attn_mode == "global":
|
||||
global_ctx, _ = _split_proj_context(context)
|
||||
h = self.cross_attn(h, global_ctx)
|
||||
else:
|
||||
h = self.cross_attn(h, context)
|
||||
x = x + h
|
||||
h = self.norm3(x)
|
||||
h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
|
||||
@ -625,7 +701,7 @@ class ModulatedTransformerCrossBlock(nn.Module):
|
||||
x = x + h
|
||||
return x
|
||||
|
||||
def forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
def forward(self, x: torch.Tensor, mod: torch.Tensor, context, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
return self._forward(x, mod, context, phases)
|
||||
|
||||
|
||||
@ -648,6 +724,8 @@ class SparseStructureFlowModel(nn.Module):
|
||||
initialization: str = 'vanilla',
|
||||
qk_rms_norm: bool = False,
|
||||
qk_rms_norm_cross: bool = False,
|
||||
image_attn_mode: Literal["global", "proj", "gated_proj"] = "global",
|
||||
proj_in_channels: Optional[int] = None,
|
||||
operations=None,
|
||||
device = None,
|
||||
dtype = torch.float32,
|
||||
@ -669,6 +747,8 @@ class SparseStructureFlowModel(nn.Module):
|
||||
self.initialization = initialization
|
||||
self.qk_rms_norm = qk_rms_norm
|
||||
self.qk_rms_norm_cross = qk_rms_norm_cross
|
||||
self.image_attn_mode = image_attn_mode
|
||||
self.proj_in_channels = proj_in_channels
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
|
||||
@ -703,6 +783,8 @@ class SparseStructureFlowModel(nn.Module):
|
||||
share_mod=share_mod,
|
||||
qk_rms_norm=self.qk_rms_norm,
|
||||
qk_rms_norm_cross=self.qk_rms_norm_cross,
|
||||
image_attn_mode=image_attn_mode,
|
||||
proj_in_channels=proj_in_channels,
|
||||
device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
for _ in range(num_blocks)
|
||||
@ -720,14 +802,9 @@ class SparseStructureFlowModel(nn.Module):
|
||||
t_emb = self.t_embedder(t, out_dtype = t.dtype)
|
||||
if self.share_mod:
|
||||
t_emb = self.adaLN_modulation(t_emb)
|
||||
t_emb = manual_cast(t_emb, self.dtype)
|
||||
h = manual_cast(h, self.dtype)
|
||||
cond = manual_cast(cond, self.dtype)
|
||||
for block in self.blocks:
|
||||
h = block(h, t_emb, cond, self.rope_phases)
|
||||
h = manual_cast(h, x.dtype)
|
||||
h = F.layer_norm(h, h.shape[-1:])
|
||||
h = h.to(next(self.out_layer.parameters()).dtype)
|
||||
h = self.out_layer(h)
|
||||
|
||||
h = h.permute(0, 2, 1).view(h.shape[0], h.shape[2], *[self.resolution] * 3).contiguous()
|
||||
@ -741,6 +818,221 @@ def timestep_reshift(t_shifted, old_shift=3.0, new_shift=5.0):
|
||||
t_new *= 1000.0
|
||||
return t_new
|
||||
|
||||
|
||||
# Pixal3D ProjGrid math — port of upstream's ProjGrid + project_points_to_image_batch.
|
||||
# World frame uses world Y as depth (Blender convention), camera looks along -Z local;
|
||||
# transform_matrix is camera-to-world (inverted internally). Intrinsics: fx = 16 / tan(fov/2)
|
||||
# with sensor_width = 32mm.
|
||||
|
||||
_PROJ_GRID_ROTATION = torch.tensor(
|
||||
[[1.0, 0.0, 0.0],
|
||||
[0.0, 0.0, -1.0],
|
||||
[0.0, 1.0, 0.0]]
|
||||
)
|
||||
|
||||
_PROJ_FRONT_VIEW_TRANSFORM = torch.tensor(
|
||||
[[1.0, 0.0, 0.0, 0.0],
|
||||
[0.0, 0.0, -1.0, -2.0],
|
||||
[0.0, 1.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0, 1.0]]
|
||||
)
|
||||
|
||||
|
||||
def _build_proj_transform_matrix(distance: torch.Tensor, batch_size: int,
|
||||
device, dtype=torch.float32) -> torch.Tensor:
|
||||
T = _PROJ_FRONT_VIEW_TRANSFORM.to(device=device, dtype=dtype)
|
||||
T = T.unsqueeze(0).expand(batch_size, -1, -1).clone()
|
||||
if distance.ndim == 0:
|
||||
distance = distance.expand(batch_size)
|
||||
T[:, 1, 3] = -distance.to(device=device, dtype=dtype)
|
||||
return T
|
||||
|
||||
|
||||
def _project_points_to_image(points_world: torch.Tensor, transform_matrix: torch.Tensor,
|
||||
camera_angle_x: torch.Tensor, resolution: int):
|
||||
B, N, _ = points_world.shape
|
||||
ones = torch.ones((B, N, 1), device=points_world.device, dtype=points_world.dtype)
|
||||
homo = torch.cat([points_world, ones], dim=-1)
|
||||
world_to_camera = torch.linalg.inv(transform_matrix.float()).to(transform_matrix.dtype)
|
||||
p_cam = torch.bmm(homo, world_to_camera.transpose(-2, -1))[..., :3]
|
||||
x_cam, y_cam, z_cam = p_cam.unbind(dim=-1)
|
||||
depth = -z_cam
|
||||
sensor_width = 32.0
|
||||
focal_length = 16.0 / torch.tan(camera_angle_x / 2.0)
|
||||
focal_px = focal_length * resolution / sensor_width
|
||||
focal_px = focal_px.to(p_cam.dtype).unsqueeze(1)
|
||||
denom = (-z_cam + 1e-8)
|
||||
x_pix = focal_px * x_cam / denom + resolution / 2.0
|
||||
y_pix = -focal_px * y_cam / denom + resolution / 2.0
|
||||
valid = ((x_pix >= 0) & (x_pix < resolution) &
|
||||
(y_pix >= 0) & (y_pix < resolution) & (depth > 0))
|
||||
return torch.stack([x_pix, y_pix], dim=-1), depth, valid
|
||||
|
||||
|
||||
def _sample_features(feature_map: torch.Tensor, uv_ndc: torch.Tensor) -> torch.Tensor:
|
||||
B, C, _, _ = feature_map.shape
|
||||
grid = uv_ndc.view(B, -1, 1, 2).to(feature_map.dtype)
|
||||
feat = F.grid_sample(feature_map, grid, mode="bilinear",
|
||||
padding_mode="border", align_corners=False)
|
||||
return feat.squeeze(-1)
|
||||
|
||||
|
||||
def _coords_to_proj_world(coords: torch.Tensor, resolution: int, mesh_scale: torch.Tensor):
|
||||
if resolution < 1:
|
||||
raise ValueError(f"resolution must be positive, got {resolution}")
|
||||
batch_ids = coords[:, 0].long()
|
||||
if resolution == 1:
|
||||
norm = coords[:, 1:].to(torch.float32) * 0.0
|
||||
else:
|
||||
norm = coords[:, 1:].to(torch.float32) / (resolution - 1) * 2.0 - 1.0
|
||||
R = _PROJ_GRID_ROTATION.to(device=coords.device, dtype=torch.float32)
|
||||
rotated = norm @ R.T
|
||||
if mesh_scale.ndim == 0:
|
||||
scale_per_voxel = mesh_scale.expand(coords.shape[0])
|
||||
else:
|
||||
scale_per_voxel = mesh_scale.to(coords.device)[batch_ids]
|
||||
world = rotated / scale_per_voxel.unsqueeze(-1) / 2.0
|
||||
return world, batch_ids
|
||||
|
||||
|
||||
def _dense_grid_proj_world(resolution: int, mesh_scale: torch.Tensor,
|
||||
batch_size: int, device, dtype=torch.float32) -> torch.Tensor:
|
||||
one = torch.linspace(-1.0, 1.0, resolution, device=device, dtype=dtype)
|
||||
x, y, z = torch.meshgrid(one, one, one, indexing="ij")
|
||||
grid = torch.stack([x, y, z], dim=-1).reshape(-1, 3)
|
||||
R_rot = _PROJ_GRID_ROTATION.to(device=device, dtype=dtype)
|
||||
grid = grid @ R_rot.T
|
||||
grid = grid.unsqueeze(0).expand(batch_size, -1, -1).clone()
|
||||
if mesh_scale.ndim == 0:
|
||||
mesh_scale = mesh_scale.expand(batch_size)
|
||||
grid = grid / mesh_scale.to(device=device, dtype=dtype).view(-1, 1, 1) / 2.0
|
||||
return grid
|
||||
|
||||
|
||||
def _back_project_to_tokens(
|
||||
coords_world: torch.Tensor,
|
||||
feature_map: torch.Tensor,
|
||||
transform_matrix: torch.Tensor,
|
||||
camera_angle_x: torch.Tensor,
|
||||
image_resolution: int,
|
||||
batch_ids: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if coords_world.dim() == 2:
|
||||
assert batch_ids is not None
|
||||
B = transform_matrix.shape[0]
|
||||
out = torch.zeros((coords_world.shape[0], feature_map.shape[1]),
|
||||
device=feature_map.device, dtype=feature_map.dtype)
|
||||
for b in range(B):
|
||||
mask = batch_ids == b
|
||||
if not mask.any():
|
||||
continue
|
||||
p = coords_world[mask].unsqueeze(0)
|
||||
uv, depth, valid = _project_points_to_image(
|
||||
p, transform_matrix[b:b+1], camera_angle_x[b:b+1], image_resolution)
|
||||
uv_ndc = (uv + 0.5) / image_resolution * 2.0 - 1.0
|
||||
# padding_mode='border' is load-bearing: masking out-of-frame voxels confuses
|
||||
# the SS DiT (~half the voxels go to zero, producing low poly + rotation drift).
|
||||
sampled = _sample_features(feature_map[b:b+1], uv_ndc)
|
||||
sampled = sampled.squeeze(0).transpose(0, 1)
|
||||
out[mask] = sampled
|
||||
return out
|
||||
else:
|
||||
uv, depth, valid = _project_points_to_image(
|
||||
coords_world, transform_matrix, camera_angle_x, image_resolution)
|
||||
uv_ndc = (uv + 0.5) / image_resolution * 2.0 - 1.0
|
||||
sampled = _sample_features(feature_map, uv_ndc)
|
||||
out = sampled.transpose(1, 2)
|
||||
return out
|
||||
|
||||
|
||||
def _pack_per_voxel_scalar(proj_pack: Optional[dict], key: str, eval_batch: int, device) -> torch.Tensor:
|
||||
if proj_pack is None or key not in proj_pack:
|
||||
return torch.ones((eval_batch,), device=device, dtype=torch.float32)
|
||||
t = proj_pack[key].to(device=device, dtype=torch.float32)
|
||||
if t.ndim == 0:
|
||||
return t.expand(eval_batch).clone()
|
||||
return _expand_pack(t, eval_batch)
|
||||
|
||||
|
||||
def _expand_pack(t: torch.Tensor, eval_batch: int) -> torch.Tensor:
|
||||
if eval_batch == t.shape[0]:
|
||||
return t
|
||||
if eval_batch % t.shape[0] != 0:
|
||||
raise ValueError(f"eval batch {eval_batch} is not a multiple of pack batch {t.shape[0]}")
|
||||
return t.repeat((eval_batch // t.shape[0],) + (1,) * (t.ndim - 1))
|
||||
|
||||
|
||||
def _select_stage_entry(proj_pack: dict, stage: Optional[str]):
|
||||
"""Returns (feature_map_lr, feature_map_hr_or_None, image_resolution)."""
|
||||
stages = proj_pack.get("stages")
|
||||
if stages is not None and stage is not None and stage in stages:
|
||||
entry = stages[stage]
|
||||
return entry["feature_map"], entry.get("feature_map_hr"), int(entry.get("image_resolution", 1024))
|
||||
if "feature_map" in proj_pack:
|
||||
return proj_pack["feature_map"], proj_pack.get("feature_map_hr"), int(proj_pack.get("image_resolution", 1024))
|
||||
raise ValueError(f"proj_feat_pack has no usable feature_map (stage={stage!r})")
|
||||
|
||||
|
||||
def _build_proj_cond(global_cond: torch.Tensor, image_attn_mode: str, proj_pack: Optional[dict],
|
||||
coords_world: torch.Tensor, batch_ids: Optional[torch.Tensor] = None,
|
||||
eval_batch: Optional[int] = None,
|
||||
proj_in_channels: Optional[int] = None,
|
||||
stage: Optional[str] = None,
|
||||
cond_or_uncond: Optional[list] = None):
|
||||
if image_attn_mode == "global":
|
||||
return global_cond
|
||||
if proj_pack is None:
|
||||
raise ValueError(f"image_attn_mode={image_attn_mode!r} but proj_feat_pack is missing")
|
||||
device = coords_world.device
|
||||
T = proj_pack["transform_matrix"].to(device)
|
||||
cam_angle = proj_pack["camera_angle_x"].to(device)
|
||||
feat_map_lr, feat_map_hr, image_resolution = _select_stage_entry(proj_pack, stage)
|
||||
feat_map_lr = feat_map_lr.to(device)
|
||||
if feat_map_hr is not None:
|
||||
feat_map_hr = feat_map_hr.to(device)
|
||||
if eval_batch is not None:
|
||||
T = _expand_pack(T, eval_batch)
|
||||
cam_angle = _expand_pack(cam_angle, eval_batch) if cam_angle.ndim >= 1 else cam_angle
|
||||
feat_map_lr = _expand_pack(feat_map_lr, eval_batch)
|
||||
if feat_map_hr is not None:
|
||||
feat_map_hr = _expand_pack(feat_map_hr, eval_batch)
|
||||
# Channel-count check against the trained proj_linear input. If HR is present, the
|
||||
# block expects (LR_channels + HR_channels) since we concat the sampled features.
|
||||
expected_channels = feat_map_lr.shape[1] + (feat_map_hr.shape[1] if feat_map_hr is not None else 0)
|
||||
if proj_in_channels is not None and expected_channels != proj_in_channels:
|
||||
hint = ""
|
||||
if feat_map_hr is None and expected_channels < proj_in_channels:
|
||||
hint = (" — feature_map_hr is missing for this stage. Connect a NAFModel "
|
||||
"input to Pixal3DConditioning; the shape/texture stages of this "
|
||||
"checkpoint need a NAF-upsampled HR feature map.")
|
||||
raise ValueError(
|
||||
f"proj_feat_pack[{stage!r}] has LR={feat_map_lr.shape[1]} "
|
||||
f"+ HR={feat_map_hr.shape[1] if feat_map_hr is not None else 0} "
|
||||
f"= {expected_channels} channels, sub-model expects {proj_in_channels}.{hint}"
|
||||
)
|
||||
proj_feats_lr = _back_project_to_tokens(coords_world, feat_map_lr, T, cam_angle,
|
||||
image_resolution=image_resolution,
|
||||
batch_ids=batch_ids)
|
||||
if feat_map_hr is not None:
|
||||
proj_feats_hr = _back_project_to_tokens(coords_world, feat_map_hr, T, cam_angle,
|
||||
image_resolution=image_resolution,
|
||||
batch_ids=batch_ids)
|
||||
proj_feats = torch.cat([proj_feats_lr, proj_feats_hr], dim=-1)
|
||||
else:
|
||||
proj_feats = proj_feats_lr
|
||||
# Mirror upstream's neg_cond by zeroing proj for any uncond batch slot.
|
||||
if cond_or_uncond is not None and eval_batch is not None:
|
||||
uncond_slots = [i for i, v in enumerate(cond_or_uncond) if v == 1]
|
||||
if uncond_slots:
|
||||
uncond_idx = torch.tensor(uncond_slots, device=proj_feats.device, dtype=torch.long)
|
||||
if batch_ids is None:
|
||||
proj_feats = proj_feats.clone()
|
||||
proj_feats[uncond_idx] = 0
|
||||
else:
|
||||
neg_mask = torch.isin(batch_ids, uncond_idx).unsqueeze(-1).to(proj_feats.dtype)
|
||||
proj_feats = proj_feats * (1.0 - neg_mask)
|
||||
return {"global": global_cond, "proj": proj_feats}
|
||||
|
||||
class Trellis2(nn.Module):
|
||||
def __init__(self, resolution,
|
||||
in_channels = 32,
|
||||
@ -754,6 +1046,12 @@ class Trellis2(nn.Module):
|
||||
qk_rms_norm = True,
|
||||
qk_rms_norm_cross = True,
|
||||
init_txt_model=False, # for now
|
||||
image_attn_mode_structure: str = "global",
|
||||
proj_in_channels_structure: Optional[int] = None,
|
||||
image_attn_mode_shape: str = "global",
|
||||
proj_in_channels_shape: Optional[int] = None,
|
||||
image_attn_mode_texture: str = "global",
|
||||
proj_in_channels_texture: Optional[int] = None,
|
||||
dtype=None, device=None, operations=None, **kwargs):
|
||||
|
||||
super().__init__()
|
||||
@ -767,22 +1065,29 @@ class Trellis2(nn.Module):
|
||||
"model_channels":model_channels, "num_heads":num_heads, "mlp_ratio": mlp_ratio, "share_mod": share_mod,
|
||||
"qk_rms_norm": qk_rms_norm, "qk_rms_norm_cross": qk_rms_norm_cross, "device": device, "dtype": dtype, "operations": operations
|
||||
}
|
||||
self.image_attn_mode_structure = image_attn_mode_structure
|
||||
self.image_attn_mode_shape = image_attn_mode_shape
|
||||
self.image_attn_mode_texture = image_attn_mode_texture
|
||||
shape_proj_kwargs = {"image_attn_mode": image_attn_mode_shape, "proj_in_channels": proj_in_channels_shape}
|
||||
tex_proj_kwargs = {"image_attn_mode": image_attn_mode_texture, "proj_in_channels": proj_in_channels_texture}
|
||||
struct_proj_kwargs = {"image_attn_mode": image_attn_mode_structure, "proj_in_channels": proj_in_channels_structure}
|
||||
txt_only = kwargs.get("txt_only", False)
|
||||
if not txt_only:
|
||||
self.img2shape = SLatFlowModel(resolution=resolution, in_channels=in_channels, **args)
|
||||
self.img2shape = SLatFlowModel(resolution=resolution, in_channels=in_channels, **shape_proj_kwargs, **args)
|
||||
self.shape2txt = None
|
||||
if init_txt_model:
|
||||
self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args)
|
||||
self.img2shape_512 = SLatFlowModel(resolution=32, in_channels=in_channels, **args)
|
||||
self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **tex_proj_kwargs, **args)
|
||||
self.img2shape_512 = SLatFlowModel(resolution=32, in_channels=in_channels, **shape_proj_kwargs, **args)
|
||||
args.pop("out_channels")
|
||||
self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **args)
|
||||
self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **struct_proj_kwargs, **args)
|
||||
else:
|
||||
self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args)
|
||||
self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **tex_proj_kwargs, **args)
|
||||
self.guidance_interval = [0.6, 1.0]
|
||||
self.guidance_interval_txt = [0.6, 0.9]
|
||||
|
||||
def forward(self, x, timestep, context, **kwargs):
|
||||
transformer_options = kwargs.get("transformer_options", {})
|
||||
cond_or_uncond = transformer_options.get("cond_or_uncond")
|
||||
model_options = {}
|
||||
if hasattr(self, "meta"):
|
||||
model_options = self.meta
|
||||
@ -795,6 +1100,8 @@ class Trellis2(nn.Module):
|
||||
coords = model_options.get("coords", None)
|
||||
coord_counts = model_options.get("coord_counts", None)
|
||||
mode = model_options.get("generation_mode", "structure_generation")
|
||||
proj_feat_pack = model_options.get("proj_feat_pack", None)
|
||||
coord_resolution = model_options.get("coord_resolution", None)
|
||||
|
||||
is_512_run = False
|
||||
if mode == "shape_generation_512":
|
||||
@ -884,6 +1191,20 @@ class Trellis2(nn.Module):
|
||||
x_st = SparseTensor(feats=feats_flat, coords=batched_coords.to(torch.int32))
|
||||
|
||||
if mode == "shape_generation":
|
||||
shape_attn = self.image_attn_mode_shape
|
||||
if shape_attn != "global":
|
||||
if coord_resolution is None:
|
||||
raise ValueError("Pixal3D shape_generation requires coord_resolution in model_options; "
|
||||
"EmptyTrellis2ShapeLatent should set it from the input voxel.")
|
||||
mesh_scale = _pack_per_voxel_scalar(proj_feat_pack, "mesh_scale", B, batched_coords.device)
|
||||
xyz_world, batch_ids = _coords_to_proj_world(batched_coords, coord_resolution, mesh_scale)
|
||||
sub_model = self.img2shape_512 if is_512_run else self.img2shape
|
||||
stage_name = "shape_512" if is_512_run else "shape_1024"
|
||||
c_eval = _build_proj_cond(c_eval, shape_attn, proj_feat_pack, xyz_world, batch_ids,
|
||||
eval_batch=B,
|
||||
proj_in_channels=sub_model.proj_in_channels,
|
||||
stage=stage_name,
|
||||
cond_or_uncond=cond_or_uncond)
|
||||
if is_512_run:
|
||||
out = self.img2shape_512(x_st, t_eval, c_eval)
|
||||
else:
|
||||
@ -904,18 +1225,49 @@ class Trellis2(nn.Module):
|
||||
slat_feats = slat_feats[:N].repeat(B, 1)
|
||||
|
||||
x_st = x_st.replace(feats=torch.cat([x_st.feats, slat_feats.to(x_st.feats.device)], dim=-1))
|
||||
tex_attn = self.image_attn_mode_texture
|
||||
if tex_attn != "global":
|
||||
if coord_resolution is None:
|
||||
raise ValueError("Pixal3D texture_generation requires coord_resolution in model_options; "
|
||||
"EmptyTrellis2LatentTexture should set it from the input voxel.")
|
||||
mesh_scale = _pack_per_voxel_scalar(proj_feat_pack, "mesh_scale", B, batched_coords.device)
|
||||
xyz_world, batch_ids = _coords_to_proj_world(batched_coords, coord_resolution, mesh_scale)
|
||||
c_eval = _build_proj_cond(c_eval, tex_attn, proj_feat_pack, xyz_world, batch_ids,
|
||||
eval_batch=B,
|
||||
proj_in_channels=self.shape2txt.proj_in_channels,
|
||||
stage="tex_1024",
|
||||
cond_or_uncond=cond_or_uncond)
|
||||
out = self.shape2txt(x_st, t_eval, c_eval)
|
||||
|
||||
else: # structure
|
||||
orig_bsz = x.shape[0]
|
||||
struct_attn = self.image_attn_mode_structure
|
||||
if shape_rule and orig_bsz > 1:
|
||||
half = orig_bsz // 2
|
||||
x_eval = x[half:]
|
||||
t_eval = timestep[half:] if timestep.shape[0] > 1 else timestep
|
||||
out = self.structure_model(x_eval, t_eval, cond)
|
||||
struct_cond = cond
|
||||
if struct_attn != "global":
|
||||
mesh_scale = _pack_per_voxel_scalar(proj_feat_pack, "mesh_scale", half, x.device)
|
||||
grid_xyz = _dense_grid_proj_world(16, mesh_scale, half, device=x.device)
|
||||
struct_cond = _build_proj_cond(cond, struct_attn, proj_feat_pack, grid_xyz,
|
||||
eval_batch=half,
|
||||
proj_in_channels=self.structure_model.proj_in_channels,
|
||||
stage="ss",
|
||||
cond_or_uncond=cond_or_uncond)
|
||||
out = self.structure_model(x_eval, t_eval, struct_cond)
|
||||
out = out.repeat(2, 1, 1, 1, 1)
|
||||
else:
|
||||
out = self.structure_model(x, timestep, context)
|
||||
struct_cond = context
|
||||
if struct_attn != "global":
|
||||
mesh_scale = _pack_per_voxel_scalar(proj_feat_pack, "mesh_scale", orig_bsz, x.device)
|
||||
grid_xyz = _dense_grid_proj_world(16, mesh_scale, orig_bsz, device=x.device)
|
||||
struct_cond = _build_proj_cond(context, struct_attn, proj_feat_pack, grid_xyz,
|
||||
eval_batch=orig_bsz,
|
||||
proj_in_channels=self.structure_model.proj_in_channels,
|
||||
stage="ss",
|
||||
cond_or_uncond=cond_or_uncond)
|
||||
out = self.structure_model(x, timestep, struct_cond)
|
||||
|
||||
if not_struct_mode:
|
||||
if mask is not None:
|
||||
|
||||
301
comfy/ldm/trellis2/naf/model.py
Normal file
301
comfy/ldm/trellis2/naf/model.py
Normal file
@ -0,0 +1,301 @@
|
||||
"""NAF (Neighborhood Attention Filtering) feature upsampler.
|
||||
|
||||
Vendored from valeoai/NAF (Apache-2.0):
|
||||
https://github.com/valeoai/NAF — src/model/naf.py + src/layers/{convolutions,attentions,rope}.py
|
||||
Used by Pixal3D's shape/texture conditioning to produce
|
||||
the 2x-upsampled half of the 2048-channel proj feature map.
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
# Pure-torch neighborhood attention (replaces natten.na2d / na2d_qk + na2d_av).
|
||||
|
||||
def upsample_lr_slice(src_lr: torch.Tensor, lr_dh: int, lr_dw: int,
|
||||
hr_h_range: Tuple[int, int], hr_w_range: Tuple[int, int]) -> torch.Tensor:
|
||||
"""Slice a LR-layout tensor [B, h_lr, w_lr, n, C], permute to BCHW, and
|
||||
nearest-exact upsample only the region covering [hr_h_range, hr_w_range].
|
||||
Returns BCHW at hr_h_end-hr_h_start x hr_w_end-hr_w_start (no padding for
|
||||
out-of-bounds regions)."""
|
||||
B = src_lr.shape[0]
|
||||
n = src_lr.shape[-2]
|
||||
C = src_lr.shape[-1]
|
||||
h_hr_start, h_hr_end = hr_h_range
|
||||
w_hr_start, w_hr_end = hr_w_range
|
||||
# LR positions covering [h_hr_start, h_hr_end). Nearest-exact maps HR p → p // D.
|
||||
lr_h_start = h_hr_start // lr_dh
|
||||
lr_h_end = (h_hr_end - 1) // lr_dh + 1
|
||||
lr_w_start = w_hr_start // lr_dw
|
||||
lr_w_end = (w_hr_end - 1) // lr_dw + 1
|
||||
lr_slice = src_lr[:, lr_h_start:lr_h_end, lr_w_start:lr_w_end]
|
||||
lh, lw = lr_slice.shape[1], lr_slice.shape[2]
|
||||
lr_bcd = lr_slice.permute(0, 3, 4, 1, 2).reshape(B * n, C, lh, lw).contiguous()
|
||||
up = F.interpolate(lr_bcd, scale_factor=(lr_dh, lr_dw), mode="nearest-exact")
|
||||
offset_h = h_hr_start - lr_h_start * lr_dh
|
||||
offset_w = w_hr_start - lr_w_start * lr_dw
|
||||
return up[:, :, offset_h:offset_h + (h_hr_end - h_hr_start),
|
||||
offset_w:offset_w + (w_hr_end - w_hr_start)]
|
||||
|
||||
|
||||
def na2d_pure(
|
||||
q: torch.Tensor, # [B, H, W, n_heads, d_qk] at HR.
|
||||
k_lr: torch.Tensor, # [B, h_lr, w_lr, n_heads, d_qk] at LR
|
||||
v_lr: torch.Tensor, # [B, h_lr, w_lr, n_heads, d_v] at LR
|
||||
kernel_size: Tuple[int, int], # (Kh, Kw) attention window.
|
||||
dilation: Tuple[int, int], # (Dh, Dw) stride within the unrolled K/V grid; also the LR→HR upsample factor.
|
||||
scale: float, # 1 / sqrt(d_qk) scaling for the Q·K scores.
|
||||
tile: int = 128, # Spatial tile size (output positions per tile)
|
||||
v_chunk: int = 64 # Sub-divide d_v into chunks of this size when computing attn·V. None disables chunking.
|
||||
) -> torch.Tensor: # [B, H, W, n_heads, d_v] attended features.
|
||||
"""Neighborhood attention in pure torch via F.unfold + per-tile slicing.
|
||||
|
||||
K and V are passed at LR resolution and upsampled (nearest-exact) per-tile only
|
||||
for the slice the unfold needs. Avoids the [B, n*d, H, W] HR allocations for K
|
||||
(512 MB) and V (2 GB) at tex_1024 fp16. Spatial tiling bounds the per-tile
|
||||
F.unfold blob; `v_chunk` further slices d_v so attn·V is computed in C-sized
|
||||
chunks (attn is reused, computed once from Q/K).
|
||||
|
||||
"""
|
||||
B, H, W, n, d_qk = q.shape
|
||||
d_v = v_lr.shape[-1]
|
||||
Kh, Kw = kernel_size
|
||||
Dh, Dw = dilation
|
||||
pad_h, pad_w = (Kh // 2) * Dh, (Kw // 2) * Dw
|
||||
|
||||
q_ = q.permute(0, 3, 4, 1, 2).contiguous() # [B, n, d_qk, H, W]
|
||||
out = torch.empty((B, n, d_v, H, W), device=q.device, dtype=q.dtype)
|
||||
|
||||
th = min(tile, H) if tile else H
|
||||
tw = min(tile, W) if tile else W
|
||||
chunk = v_chunk if (v_chunk and v_chunk < d_v) else d_v
|
||||
|
||||
for h0 in range(0, H, th):
|
||||
for w0 in range(0, W, tw):
|
||||
h1, w1 = min(h0 + th, H), min(w0 + tw, W)
|
||||
t_h, t_w = h1 - h0, w1 - w0
|
||||
|
||||
# Padded HR region the unfold needs (kernel span = (K-1)*D + 1).
|
||||
h_src_start = max(0, h0 - pad_h)
|
||||
h_src_end = min(H, h1 + pad_h)
|
||||
w_src_start = max(0, w0 - pad_w)
|
||||
w_src_end = min(W, w1 + pad_w)
|
||||
pad_top = max(0, pad_h - h0)
|
||||
pad_bot = max(0, (h1 + pad_h) - H)
|
||||
pad_lft = max(0, pad_w - w0)
|
||||
pad_rgt = max(0, (w1 + pad_w) - W)
|
||||
|
||||
# Upsample only the tile region from k_lr / v_lr.
|
||||
k_tile = upsample_lr_slice(k_lr, Dh, Dw,
|
||||
(h_src_start, h_src_end),
|
||||
(w_src_start, w_src_end))
|
||||
v_tile = upsample_lr_slice(v_lr, Dh, Dw,
|
||||
(h_src_start, h_src_end),
|
||||
(w_src_start, w_src_end))
|
||||
if pad_top or pad_bot or pad_lft or pad_rgt:
|
||||
k_tile = F.pad(k_tile, [pad_lft, pad_rgt, pad_top, pad_bot])
|
||||
v_tile = F.pad(v_tile, [pad_lft, pad_rgt, pad_top, pad_bot])
|
||||
|
||||
# Q·K → attention weights (small: KK=81 per output position).
|
||||
KK = Kh * Kw
|
||||
k_w = F.unfold(k_tile, kernel_size=(Kh, Kw), dilation=(Dh, Dw), padding=0)
|
||||
k_w = k_w.view(B, n, d_qk, KK, t_h * t_w).permute(0, 1, 4, 3, 2) # [B, n, t, KK, d_qk]
|
||||
q_tile = q_[:, :, :, h0:h1, w0:w1].permute(0, 1, 3, 4, 2).reshape(B, n, t_h * t_w, 1, d_qk)
|
||||
scores = torch.matmul(q_tile, k_w.transpose(-1, -2)) * scale
|
||||
attn = scores.softmax(dim=-1)
|
||||
del k_w, scores, q_tile, k_tile
|
||||
|
||||
# attn · V, chunked over d_v.
|
||||
for c0 in range(0, d_v, chunk):
|
||||
c1 = min(c0 + chunk, d_v)
|
||||
v_w = F.unfold(v_tile[:, c0:c1], kernel_size=(Kh, Kw),dilation=(Dh, Dw), padding=0) # [B*n, (c1-c0)*KK, t]
|
||||
v_w = v_w.view(B, n, c1 - c0, KK, t_h * t_w).permute(0, 1, 4, 3, 2)
|
||||
out_chunk = torch.matmul(attn, v_w).squeeze(-2) # [B, n, t, c1-c0]
|
||||
out_chunk = out_chunk.view(B, n, t_h, t_w, c1 - c0).permute(0, 1, 4, 2, 3)
|
||||
out[:, :, c0:c1, h0:h1, w0:w1] = out_chunk
|
||||
del v_w, out_chunk
|
||||
del attn, v_tile
|
||||
|
||||
return out.permute(0, 3, 4, 1, 2).contiguous() # [B, H, W, n, d_v]
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
"""Window-restricted cross-attention. No learnable parameters; the model's
|
||||
capacity lives entirely in the ImageEncoder convs."""
|
||||
|
||||
def __init__(self, dim: int, num_heads: int, kernel_size: Tuple[int, int] = (9, 9)):
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0, "dim must be divisible by num_heads"
|
||||
self.num_heads = num_heads
|
||||
self.kernel_size = kernel_size
|
||||
self.scale = (dim // num_heads) ** -0.5
|
||||
|
||||
@staticmethod
|
||||
def _split_heads_lr(x: torch.Tensor, num_heads: int) -> torch.Tensor:
|
||||
"""[B, n*d, h, w] -> [B, h, w, n, d] at the input resolution (no upsample)."""
|
||||
B, C, H, W = x.shape
|
||||
return x.view(B, num_heads, C // num_heads, H, W).permute(0, 3, 4, 1, 2).contiguous()
|
||||
|
||||
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
|
||||
# q is [B, C, Hq, Wq] at HR; k and v are at LR (Hk, Wk). We KEEP k and v at LR
|
||||
# na2d_pure upsamples only the tile slice it needs.
|
||||
hq, wq = q.shape[-2:]
|
||||
hk, wk = k.shape[-2:]
|
||||
dilation = (hq // hk, wq // wk)
|
||||
B, C, _, _ = q.shape
|
||||
q = q.view(B, self.num_heads, C // self.num_heads, hq, wq).permute(0, 3, 4, 1, 2).contiguous()
|
||||
k_lr = self._split_heads_lr(k, self.num_heads).to(q.dtype)
|
||||
v_lr = self._split_heads_lr(v, self.num_heads).to(q.dtype)
|
||||
out = na2d_pure(q, k_lr, v_lr, self.kernel_size, dilation, self.scale)
|
||||
# [B, H, W, n, d] -> [B, n*d, H, W]
|
||||
return out.permute(0, 3, 4, 1, 2).contiguous().view(B, -1, hq, wq)
|
||||
|
||||
|
||||
# RoPE positional embedding
|
||||
|
||||
def rope_rotate_half(x: torch.Tensor) -> torch.Tensor:
|
||||
x1, x2 = x.chunk(2, dim=-1)
|
||||
return torch.cat([-x2, x1], dim=-1)
|
||||
|
||||
|
||||
class RoPE(nn.Module):
|
||||
def __init__(self, embed_dim: int, num_heads: int, base: float = 100.0):
|
||||
super().__init__()
|
||||
assert embed_dim % (4 * num_heads) == 0
|
||||
self.num_heads = num_heads
|
||||
self.D_head = embed_dim // num_heads
|
||||
self.base = base
|
||||
self.register_buffer("periods", torch.empty(self.D_head // 4), persistent=True) # loaded from the checkpoint
|
||||
self._cached_key = None
|
||||
self._cached_cos_sin = None
|
||||
|
||||
def _cos_sin(self, H: int, W: int, dtype: torch.dtype):
|
||||
"""cos/sin only depend on (H, W) and the output dtype (periods are fixed
|
||||
once loaded from the checkpoint), so cache them — saves the meshgrid /
|
||||
angle / cos / sin / tile / flatten on every forward."""
|
||||
key = (H, W, dtype)
|
||||
if self._cached_key == key and self._cached_cos_sin is not None:
|
||||
return self._cached_cos_sin
|
||||
device = self.periods.device
|
||||
coords_h = torch.arange(0.5, H, device=device, dtype=torch.float32) / H
|
||||
coords_w = torch.arange(0.5, W, device=device, dtype=torch.float32) / W
|
||||
coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1) # [H, W, 2]
|
||||
coords = coords.flatten(0, 1) * 2.0 - 1.0 # [HW, 2]
|
||||
angles = 2 * math.pi * coords[:, :, None] / self.periods.to(coords.dtype)[None, None, :] # [HW, 2, D//4]
|
||||
angles = angles.flatten(1, 2).tile(2) # [HW, D]
|
||||
cos = torch.cos(angles).to(dtype)
|
||||
sin = torch.sin(angles).to(dtype)
|
||||
self._cached_cos_sin = (cos, sin)
|
||||
self._cached_key = key
|
||||
return cos, sin
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# x: [B, n*D_head, H, W]
|
||||
B, C, H, W = x.shape
|
||||
n = self.num_heads
|
||||
D = C // n
|
||||
x = x.view(B, n, D, H, W).permute(0, 1, 3, 4, 2).reshape(B, n, H * W, D)
|
||||
cos, sin = self._cos_sin(H, W, x.dtype)
|
||||
x = (x * cos) + (rope_rotate_half(x) * sin)
|
||||
x = x.view(B, n, H, W, D).permute(0, 1, 4, 2, 3).reshape(B, n * D, H, W)
|
||||
return x
|
||||
|
||||
|
||||
# Image encoder
|
||||
|
||||
class EncBlock(nn.Module):
|
||||
def __init__(self, channels: int, kernel_size: int, num_groups: int = 8):
|
||||
super().__init__()
|
||||
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=channels)
|
||||
self.conv1 = nn.Conv2d(channels, channels, kernel_size=kernel_size,
|
||||
padding=kernel_size // 2, padding_mode="reflect", bias=True)
|
||||
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=channels)
|
||||
self.conv2 = nn.Conv2d(channels, channels, kernel_size=kernel_size,
|
||||
padding=kernel_size // 2, padding_mode="reflect", bias=True)
|
||||
self.activation_fn = nn.SiLU()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm1(x)
|
||||
x = self.activation_fn(x)
|
||||
x = self.conv1(x)
|
||||
x = self.norm2(x)
|
||||
x = self.activation_fn(x)
|
||||
x = self.conv2(x)
|
||||
return x # no skip connection
|
||||
|
||||
|
||||
def _encoder(in_dim: int, hidden_dim: int, kernel_size: int = 1, ks_res: int = 1, num_layers: int = 2) -> nn.Sequential:
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(in_dim, hidden_dim, kernel_size=kernel_size, padding=kernel_size // 2, padding_mode="reflect", bias=True),
|
||||
*[EncBlock(hidden_dim, kernel_size=ks_res) for _ in range(num_layers)],
|
||||
)
|
||||
|
||||
|
||||
class ImageEncoder(nn.Module):
|
||||
"""Two parallel conv stacks (1x1 + 3x3) producing dim/2 channels each, then concat,
|
||||
spatial average-pool to target size, RoPE-embed positions."""
|
||||
|
||||
def __init__(self, in_channels: int = 3, out_channels: int = 256,
|
||||
heads_rope: int = 4, rope_base: float = 100.0, img_layers: int = 2):
|
||||
super().__init__()
|
||||
half = out_channels // 2
|
||||
self.encoder = _encoder(in_channels, half, kernel_size=1, ks_res=1, num_layers=img_layers)
|
||||
self.sem_encoder = _encoder(in_channels, half, kernel_size=3, ks_res=3, num_layers=img_layers)
|
||||
self.rope = RoPE(embed_dim=out_channels, num_heads=heads_rope, base=rope_base)
|
||||
|
||||
def forward(self, x: torch.Tensor, output_size: Tuple[int, int]) -> torch.Tensor:
|
||||
# Avoid running the conv stacks on >4× the target resolution.
|
||||
out_h, out_w = output_size
|
||||
if x.shape[-2] > 4 * out_h or x.shape[-1] > 4 * out_w:
|
||||
x = F.interpolate(x, size=(min(x.shape[-2], 4 * out_h),
|
||||
min(x.shape[-1], 4 * out_w)),
|
||||
mode="bilinear", align_corners=False)
|
||||
x = torch.cat([self.encoder(x), self.sem_encoder(x)], dim=1)
|
||||
x = F.adaptive_avg_pool2d(x, output_size=output_size)
|
||||
x = self.rope(x)
|
||||
return x
|
||||
|
||||
|
||||
# Top-level NAF model.
|
||||
|
||||
class NAF(nn.Module):
|
||||
"""NAF feature upsampler."""
|
||||
|
||||
def __init__(
|
||||
self, dim: int = 256, # internal channel dimension of the ImageEncoder
|
||||
heads_attn: int = 4, # attention heads in the windowed cross-attn
|
||||
heads_rope: int = 4, # heads for RoPE position encoding (must divide dim)
|
||||
kernel_size: int = 9, # square kernel for the neighborhood attention window
|
||||
rope_base: float = 100.0, # base for RoPE frequency periods
|
||||
img_layers: int = 2 # number of EncBlocks in each conv stack
|
||||
):
|
||||
super().__init__()
|
||||
self.image_encoder = ImageEncoder(in_channels=3, out_channels=dim, heads_rope=heads_rope, rope_base=rope_base, img_layers=img_layers)
|
||||
self.upsampler = CrossAttention(dim=dim, num_heads=heads_attn, kernel_size=(kernel_size, kernel_size))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
image: torch.Tensor, # [B, 3, H_img, W_img] in [0, 1].
|
||||
features: torch.Tensor, # [B, C, H_feat, W_feat] low-resolution features (any C).
|
||||
output_size: Tuple[int, int] # (H_out, W_out) target spatial resolution for the upsampled features.
|
||||
) -> torch.Tensor: # [B, C, H_out, W_out] upsampled features.
|
||||
"""Upsample low-res feature map to output_size, guided by the image."""
|
||||
q = self.image_encoder(image, output_size=output_size)
|
||||
k = F.adaptive_avg_pool2d(q, output_size=features.shape[-2:])
|
||||
return self.upsampler(q, k, features)
|
||||
|
||||
|
||||
def build_naf_from_state_dict(state_dict: dict) -> NAF:
|
||||
"""Instantiate NAF with the default hyperparams and load the given state_dict.
|
||||
|
||||
The published NAF release uses the default constructor (dim=256, heads_attn=4,
|
||||
heads_rope=4, kernel_size=9, rope_base=100, img_layers=2)."""
|
||||
model = NAF()
|
||||
missing, unexpected = model.load_state_dict(state_dict, strict=False)
|
||||
if unexpected:
|
||||
raise ValueError(f"Unexpected keys in NAF state_dict: {sorted(unexpected)[:8]}...")
|
||||
return model
|
||||
@ -75,13 +75,9 @@ def sparse_conv3d_forward(self, x):
|
||||
|
||||
class LayerNorm32(nn.LayerNorm):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x_dtype = x.dtype
|
||||
x = x.to(torch.float32)
|
||||
w = self.weight.to(torch.float32) if self.weight is not None else None
|
||||
b = self.bias.to(torch.float32) if self.bias is not None else None
|
||||
|
||||
o = F.layer_norm(x, self.normalized_shape, w, b, self.eps)
|
||||
return o.to(x_dtype)
|
||||
w = self.weight.to(x.dtype) if self.weight is not None else None
|
||||
b = self.bias.to(x.dtype) if self.bias is not None else None
|
||||
return F.layer_norm(x, self.normalized_shape, w, b, self.eps)
|
||||
|
||||
class SparseConvNeXtBlock3d(nn.Module):
|
||||
def __init__(
|
||||
@ -204,7 +200,6 @@ class SparseResBlockC2S3d(nn.Module):
|
||||
self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6)
|
||||
self.conv1 = SparseConv3d(channels, self.out_channels * 8, 3)
|
||||
self.conv2 = SparseConv3d(self.out_channels, self.out_channels, 3)
|
||||
self.skip_connection = lambda x: x.replace(x.feats.repeat_interleave(out_channels // (channels // 8), dim=1))
|
||||
if pred_subdiv:
|
||||
self.to_subdiv = SparseLinear(channels, 8)
|
||||
self.updown = SparseChannel2Spatial(2)
|
||||
@ -215,15 +210,16 @@ class SparseResBlockC2S3d(nn.Module):
|
||||
x = x.to(dtype)
|
||||
subdiv = self.to_subdiv(x)
|
||||
h = x.replace(self.norm1(x.feats))
|
||||
h = h.replace(F.silu(h.feats))
|
||||
h = h.replace(F.silu(h.feats, inplace=True))
|
||||
h = self.conv1(h)
|
||||
subdiv_binarized = subdiv.replace(subdiv.feats > 0) if subdiv is not None else None
|
||||
h = self.updown(h, subdiv_binarized)
|
||||
x = self.updown(x, subdiv_binarized)
|
||||
h = h.replace(self.norm2(h.feats))
|
||||
h = h.replace(F.silu(h.feats))
|
||||
h = h.replace(F.silu(h.feats, inplace=True))
|
||||
h = self.conv2(h)
|
||||
h = h + self.skip_connection(x)
|
||||
skip_repeat = self.out_channels // (self.channels // 8)
|
||||
h.feats.view(h.feats.shape[0], x.feats.shape[1], skip_repeat).add_(x.feats.unsqueeze(-1))
|
||||
if self.pred_subdiv:
|
||||
return h, subdiv
|
||||
else:
|
||||
@ -1211,13 +1207,12 @@ def flexible_dual_grid_to_mesh(
|
||||
edge_neighbor_voxel = coords.reshape(N, 1, 1, 3) + flexible_dual_grid_to_mesh.edge_neighbor_voxel_offset # (N, 3, 4, 3)
|
||||
connected_voxel = edge_neighbor_voxel[intersected_flag] # (M, 4, 3)
|
||||
M = connected_voxel.shape[0]
|
||||
# flatten connected voxel coords and lookup
|
||||
conn_flat_b = torch.zeros((M * 4,), dtype=torch.long, device=coords.device)
|
||||
conn_x = connected_voxel.reshape(-1, 3)[:, 0].to(torch.int32)
|
||||
conn_y = connected_voxel.reshape(-1, 3)[:, 1].to(torch.int32)
|
||||
conn_z = connected_voxel.reshape(-1, 3)[:, 2].to(torch.int32)
|
||||
# flatten connected voxel coords and lookup. In-place to avoid extra memory allocation.
|
||||
W, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item())
|
||||
conn_flat = conn_flat_b * (W * H * D) + conn_x * (H * D) + conn_y * D + conn_z
|
||||
cv = connected_voxel.reshape(-1, 3)
|
||||
conn_flat = cv[:, 0].long() * (H * D)
|
||||
conn_flat.add_(cv[:, 1].long() * D)
|
||||
conn_flat.add_(cv[:, 2].long())
|
||||
|
||||
conn_indices = torch_hashmap.lookup_flat(conn_flat).reshape(M, 4).int()
|
||||
connected_voxel_valid = (conn_indices != 0xffffffff).all(dim=1)
|
||||
|
||||
@ -113,13 +113,21 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
unet_config['block_repeat'] = [[1, 1, 1, 1], [2, 2, 2, 2]]
|
||||
return unet_config
|
||||
|
||||
if '{}img2shape.blocks.1.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys:
|
||||
def _detect_proj(sub_prefix: str, name: str):
|
||||
key = '{}{}.blocks.0.cross_attn.proj_linear.weight'.format(key_prefix, sub_prefix)
|
||||
if key in state_dict_keys:
|
||||
unet_config["image_attn_mode_{}".format(name)] = "proj"
|
||||
unet_config["proj_in_channels_{}".format(name)] = int(state_dict[key].shape[1])
|
||||
|
||||
if '{}img2shape.blocks.0.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys or \
|
||||
'{}img2shape.blocks.0.cross_attn.cross_attn_block.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys:
|
||||
unet_config = {}
|
||||
unet_config["image_model"] = "trellis2"
|
||||
|
||||
unet_config["init_txt_model"] = False
|
||||
if '{}shape2txt.blocks.29.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys:
|
||||
unet_config["init_txt_model"] = True
|
||||
unet_config["init_txt_model"] = (
|
||||
'{}shape2txt.blocks.29.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys or
|
||||
'{}shape2txt.blocks.29.cross_attn.cross_attn_block.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys
|
||||
)
|
||||
|
||||
unet_config["resolution"] = 64
|
||||
if metadata is not None:
|
||||
@ -127,14 +135,20 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
unet_config["resolution"] = 32
|
||||
|
||||
unet_config["num_heads"] = 12
|
||||
|
||||
_detect_proj("img2shape", "shape")
|
||||
_detect_proj("shape2txt", "texture")
|
||||
_detect_proj("structure_model", "structure")
|
||||
return unet_config
|
||||
|
||||
if '{}shape2txt.blocks.29.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys: # trellis2 texture
|
||||
if '{}shape2txt.blocks.29.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys or \
|
||||
'{}shape2txt.blocks.29.cross_attn.cross_attn_block.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys: # trellis2 texture
|
||||
unet_config = {}
|
||||
unet_config["image_model"] = "trellis2"
|
||||
unet_config["resolution"] = 64
|
||||
unet_config["num_heads"] = 12
|
||||
unet_config["txt_only"] = True
|
||||
_detect_proj("shape2txt", "texture")
|
||||
return unet_config
|
||||
|
||||
if '{}transformer.rotary_pos_emb.inv_freq'.format(key_prefix) in state_dict_keys: #stable audio dit
|
||||
|
||||
@ -1325,6 +1325,7 @@ class Trellis2(supported_models_base.BASE):
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 3.0,
|
||||
"multiplier": 1.0
|
||||
}
|
||||
|
||||
memory_usage_factor = 3.5
|
||||
|
||||
@ -276,13 +276,30 @@ class RescaleCFG:
|
||||
CATEGORY = "advanced/model"
|
||||
|
||||
def patch(self, model, multiplier):
|
||||
model_sampling = model.get_model_object("model_sampling")
|
||||
is_x0_space = not isinstance(model_sampling, comfy.model_sampling.EPS)
|
||||
|
||||
def rescale_cfg(args):
|
||||
x_orig = args["input"]
|
||||
cond_scale = args["cond_scale"]
|
||||
|
||||
if is_x0_space:
|
||||
# Flow-matching / X0 models: cond_denoised/uncond_denoised are x_0 estimates,
|
||||
# so the eps↔v conversion below would be wrong. Rescale directly in x_0 space.
|
||||
x_0_cond = args["cond_denoised"]
|
||||
x_0_uncond = args["uncond_denoised"]
|
||||
x_0_cfg = x_0_uncond + cond_scale * (x_0_cond - x_0_uncond)
|
||||
dims = tuple(range(1, x_0_cond.ndim))
|
||||
ro_pos = x_0_cond.std(dim=dims, keepdim=True)
|
||||
ro_cfg = x_0_cfg.std(dim=dims, keepdim=True).clamp(min=1e-8)
|
||||
x_0_rescaled = x_0_cfg * (ro_pos / ro_cfg)
|
||||
x_0_final = multiplier * x_0_rescaled + (1.0 - multiplier) * x_0_cfg
|
||||
return x_orig - x_0_final
|
||||
|
||||
cond = args["cond"]
|
||||
uncond = args["uncond"]
|
||||
cond_scale = args["cond_scale"]
|
||||
sigma = args["sigma"]
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (cond.ndim - 1))
|
||||
x_orig = args["input"]
|
||||
|
||||
#rescale cfg has to be done on v-pred model output
|
||||
x = x_orig / (sigma * sigma + 1.0)
|
||||
|
||||
@ -1,13 +1,60 @@
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, IO, Types, io
|
||||
from comfy.ldm.trellis2.vae import SparseTensor
|
||||
from comfy.ldm.trellis2.model import _build_proj_transform_matrix, _project_points_to_image
|
||||
from comfy.ldm.trellis2.naf.model import build_naf_from_state_dict
|
||||
from comfy_extras.nodes_mesh_postprocess import pack_variable_mesh_batch
|
||||
import comfy.model_management
|
||||
import comfy.utils
|
||||
import folder_paths
|
||||
from PIL import Image
|
||||
import logging
|
||||
import numpy as np
|
||||
import math
|
||||
import torch
|
||||
|
||||
ShapeSubdivides = io.Custom("SHAPE_SUBDIVIDES")
|
||||
Pixal3DProjPack = io.Custom("PIXAL3D_PROJ_PACK")
|
||||
NAFModel = io.Custom("NAF_MODEL")
|
||||
|
||||
|
||||
# Pixal3D trains in a 90°-X-rotated grid frame (F_p). We un-rotate decoder outputs for
|
||||
# user-facing previews/meshes, then re-rotate before feeding coords back to the shape DiT.
|
||||
|
||||
def _pixal3d_unrotate_voxel_data(data: torch.Tensor) -> torch.Tensor:
|
||||
if data.ndim == 4:
|
||||
return data.flip(-1).permute(0, 1, 3, 2).contiguous()
|
||||
if data.ndim == 5:
|
||||
return data.flip(-1).permute(0, 1, 2, 4, 3).contiguous()
|
||||
raise ValueError(f"unexpected voxel shape {tuple(data.shape)}")
|
||||
|
||||
|
||||
def _pixal3d_rerotate_voxel_data(data: torch.Tensor) -> torch.Tensor:
|
||||
if data.ndim == 4:
|
||||
return data.permute(0, 1, 3, 2).flip(-1).contiguous()
|
||||
if data.ndim == 5:
|
||||
return data.permute(0, 1, 2, 4, 3).flip(-1).contiguous()
|
||||
raise ValueError(f"unexpected voxel shape {tuple(data.shape)}")
|
||||
|
||||
|
||||
def _pixal3d_unrotate_vertices(vertices: torch.Tensor) -> torch.Tensor:
|
||||
if vertices.numel() == 0:
|
||||
return vertices
|
||||
x, y, z = vertices.unbind(-1)
|
||||
return torch.stack([-x, y, -z], dim=-1).contiguous()
|
||||
|
||||
|
||||
def _pixal3d_unrotate_sparse_coords(coords: torch.Tensor, resolution: int) -> torch.Tensor:
|
||||
if coords.numel() == 0:
|
||||
return coords
|
||||
R1 = resolution - 1
|
||||
if coords.shape[-1] == 4:
|
||||
b, i, j, k = coords.unbind(-1)
|
||||
return torch.stack([b, R1 - i, j, R1 - k], dim=-1).contiguous()
|
||||
if coords.shape[-1] == 3:
|
||||
i, j, k = coords.unbind(-1)
|
||||
return torch.stack([R1 - i, j, R1 - k], dim=-1).contiguous()
|
||||
raise ValueError(f"unexpected coord shape {tuple(coords.shape)}")
|
||||
|
||||
def prepare_trellis_vae_for_decode(vae, sample_shape):
|
||||
memory_required = vae.memory_used_decode(sample_shape, vae.vae_dtype)
|
||||
@ -163,6 +210,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
|
||||
prepare_trellis_vae_for_decode(vae, sample_tensor.shape)
|
||||
trellis_vae = vae.first_stage_model
|
||||
coord_counts = samples.get("coord_counts")
|
||||
pixal3d_mode = samples.get("model_options", {}).get("proj_feat_pack") is not None
|
||||
|
||||
samples = samples["samples"]
|
||||
if coord_counts is None:
|
||||
@ -188,6 +236,10 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
|
||||
coords_list = [stage_tensor.coords for stage_tensor in stage_tensors]
|
||||
subs.append(SparseTensor.from_tensor_list(feats_list, coords_list))
|
||||
|
||||
if pixal3d_mode:
|
||||
for m in mesh:
|
||||
m.vertices = _pixal3d_unrotate_vertices(m.vertices)
|
||||
|
||||
face_list = [m.faces for m in mesh]
|
||||
vert_list = [m.vertices for m in mesh]
|
||||
if all(v.shape == vert_list[0].shape for v in vert_list) and all(f.shape == face_list[0].shape for f in face_list):
|
||||
@ -224,6 +276,7 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
|
||||
prepare_trellis_vae_for_decode(vae, sample_tensor.shape)
|
||||
trellis_vae = vae.first_stage_model
|
||||
coord_counts = samples.get("coord_counts")
|
||||
pixal3d_mode = samples.get("model_options", {}).get("proj_feat_pack") is not None
|
||||
|
||||
samples = samples["samples"]
|
||||
samples, coords = flatten_batched_sparse_latent(samples, coords, coord_counts)
|
||||
@ -237,7 +290,17 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
|
||||
color_feats = voxel.feats[:, :3]
|
||||
voxel_coords = voxel.coords#[:, 1:]
|
||||
|
||||
voxel = Types.VOXEL(voxel_coords, color_feats, 1024)
|
||||
if voxel_coords.numel() > 0 and voxel_coords.shape[-1] >= 3:
|
||||
spatial = voxel_coords[:, -3:] if voxel_coords.shape[-1] == 4 else voxel_coords
|
||||
max_idx = int(spatial.max().item()) + 1
|
||||
tex_resolution = next((r for r in (256, 512, 1024, 1536, 2048) if r >= max_idx), max_idx)
|
||||
else:
|
||||
tex_resolution = 1024
|
||||
|
||||
if pixal3d_mode:
|
||||
voxel_coords = _pixal3d_unrotate_sparse_coords(voxel_coords, resolution=tex_resolution)
|
||||
|
||||
voxel = Types.VOXEL(voxel_coords, color_feats, tex_resolution)
|
||||
return IO.NodeOutput(voxel)
|
||||
|
||||
class VaeDecodeStructureTrellis2(IO.ComfyNode):
|
||||
@ -274,7 +337,10 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode):
|
||||
if current_res != resolution:
|
||||
ratio = current_res // resolution
|
||||
decoded = torch.nn.functional.max_pool3d(decoded.float(), ratio, ratio, 0) > 0.5
|
||||
out = Types.VOXEL(decoded.squeeze(1).float())
|
||||
voxel_data = decoded.squeeze(1).float()
|
||||
if samples.get("model_options", {}).get("proj_feat_pack") is not None:
|
||||
voxel_data = _pixal3d_unrotate_voxel_data(voxel_data)
|
||||
out = Types.VOXEL(voxel_data)
|
||||
return IO.NodeOutput(out)
|
||||
|
||||
class Trellis2UpsampleCascade(IO.ComfyNode):
|
||||
@ -540,7 +606,6 @@ class Trellis2Conditioning(IO.ComfyNode):
|
||||
cropped_rgba = rgba_pil.crop((crop_x1, crop_y1, crop_x2, crop_y2))
|
||||
cropped_np = np.array(cropped_rgba).astype(np.float32) / 255.0
|
||||
else:
|
||||
import logging
|
||||
logging.warning("Mask for the image is empty. Trellis2 requires an image with a mask for the best mesh quality.")
|
||||
cropped_np = rgba_np.astype(np.float32) / 255.0
|
||||
|
||||
@ -587,7 +652,12 @@ class EmptyTrellis2ShapeLatent(IO.ComfyNode):
|
||||
"Shape structure input. Accepts either a voxel structure "
|
||||
"or upsampled voxel coordinates from a previous cascade stage."
|
||||
)
|
||||
)
|
||||
),
|
||||
Pixal3DProjPack.Input(
|
||||
"proj_feat_pack",
|
||||
optional=True,
|
||||
tooltip="Pixal3D pixel-aligned projection pack from Pixal3DConditioning. Leave empty for vanilla Trellis2.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Latent.Output(),
|
||||
@ -595,21 +665,26 @@ class EmptyTrellis2ShapeLatent(IO.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, voxel):
|
||||
# to accept the upscaled coords
|
||||
def execute(cls, voxel, proj_feat_pack=None):
|
||||
is_512_pass = False
|
||||
coord_resolution = None
|
||||
upsampled = hasattr(voxel, "upsampled")
|
||||
if upsampled:
|
||||
if hasattr(voxel, "resolutions") and voxel.resolutions is not None:
|
||||
coord_resolution = int(voxel.resolutions[0].item()) // 16
|
||||
voxel = voxel.data
|
||||
|
||||
if not upsampled:
|
||||
decoded = voxel.data.unsqueeze(1)
|
||||
voxel_data = voxel.data
|
||||
if proj_feat_pack is not None:
|
||||
voxel_data = _pixal3d_rerotate_voxel_data(voxel_data)
|
||||
decoded = voxel_data.unsqueeze(1)
|
||||
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
||||
is_512_pass = True
|
||||
coord_resolution = int(decoded.shape[-1])
|
||||
|
||||
else:
|
||||
coords = voxel.int()
|
||||
is_512_pass = False
|
||||
|
||||
batch_size, counts, max_tokens = infer_batched_coord_layout(coords)
|
||||
in_channels = 32
|
||||
@ -620,8 +695,13 @@ class EmptyTrellis2ShapeLatent(IO.ComfyNode):
|
||||
generation_mode = "shape_generation_512"
|
||||
else:
|
||||
generation_mode = "shape_generation"
|
||||
model_options = {"generation_mode": generation_mode, "coords": coords, "coord_counts": counts}
|
||||
if coord_resolution is not None:
|
||||
model_options["coord_resolution"] = coord_resolution
|
||||
if proj_feat_pack is not None:
|
||||
model_options["proj_feat_pack"] = proj_feat_pack
|
||||
return IO.NodeOutput({"samples": latent, "coords": coords, "coord_counts": counts, "type": "trellis2",
|
||||
"model_options": {"generation_mode": generation_mode, "coords": coords, "coord_counts": counts}})
|
||||
"model_options": model_options})
|
||||
|
||||
class EmptyTrellis2LatentTexture(IO.ComfyNode):
|
||||
@classmethod
|
||||
@ -638,6 +718,11 @@ class EmptyTrellis2LatentTexture(IO.ComfyNode):
|
||||
)
|
||||
),
|
||||
IO.Latent.Input("shape_latent"),
|
||||
Pixal3DProjPack.Input(
|
||||
"proj_feat_pack",
|
||||
optional=True,
|
||||
tooltip="Pixal3D pixel-aligned projection pack from Pixal3DConditioning. Leave empty for vanilla Trellis2.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Latent.Output(),
|
||||
@ -645,15 +730,22 @@ class EmptyTrellis2LatentTexture(IO.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, voxel, shape_latent):
|
||||
def execute(cls, voxel, shape_latent, proj_feat_pack=None):
|
||||
channels = 32
|
||||
coord_resolution = None
|
||||
upsampled = hasattr(voxel, "upsampled")
|
||||
if upsampled:
|
||||
if hasattr(voxel, "resolutions") and voxel.resolutions is not None:
|
||||
coord_resolution = int(voxel.resolutions[0].item()) // 16
|
||||
voxel = voxel.data
|
||||
|
||||
if not upsampled:
|
||||
decoded = voxel.data.unsqueeze(1)
|
||||
voxel_data = voxel.data
|
||||
if proj_feat_pack is not None:
|
||||
voxel_data = _pixal3d_rerotate_voxel_data(voxel_data)
|
||||
decoded = voxel_data.unsqueeze(1)
|
||||
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
||||
coord_resolution = int(decoded.shape[-1])
|
||||
else:
|
||||
coords = voxel.int()
|
||||
|
||||
@ -664,9 +756,13 @@ class EmptyTrellis2LatentTexture(IO.ComfyNode):
|
||||
shape_latent = shape_latent.squeeze(-1).transpose(1, 2).reshape(-1, channels)
|
||||
|
||||
latent = torch.zeros(batch_size, channels, max_tokens, 1)
|
||||
model_options = {"generation_mode": "texture_generation", "coords": coords, "coord_counts": counts, "shape_slat": shape_latent}
|
||||
if coord_resolution is not None:
|
||||
model_options["coord_resolution"] = coord_resolution
|
||||
if proj_feat_pack is not None:
|
||||
model_options["proj_feat_pack"] = proj_feat_pack
|
||||
return IO.NodeOutput({"samples": latent, "type": "trellis2", "coords": coords, "coord_counts": counts,
|
||||
"model_options": {"generation_mode": "texture_generation",
|
||||
"coords": coords, "coord_counts": counts, "shape_slat": shape_latent}})
|
||||
"model_options": model_options})
|
||||
|
||||
|
||||
class EmptyTrellis2LatentStructure(IO.ComfyNode):
|
||||
@ -677,27 +773,441 @@ class EmptyTrellis2LatentStructure(IO.ComfyNode):
|
||||
category="latent/3d",
|
||||
inputs=[
|
||||
IO.Int.Input("batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."),
|
||||
Pixal3DProjPack.Input(
|
||||
"proj_feat_pack",
|
||||
optional=True,
|
||||
tooltip="Pixal3D pixel-aligned projection pack. Leave empty for vanilla Trellis2.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Latent.Output(),
|
||||
]
|
||||
)
|
||||
@classmethod
|
||||
def execute(cls, batch_size):
|
||||
in_channels = 8
|
||||
def execute(cls, batch_size, proj_feat_pack=None):
|
||||
# Trellis2.forward slices x[:, :8] and pads out to 32; KSampler residual math
|
||||
# needs the empty latent to match latent_format (32-channel).
|
||||
in_channels = 32
|
||||
resolution = 16
|
||||
latent = torch.zeros(batch_size, in_channels, resolution, resolution, resolution)
|
||||
output = {
|
||||
"samples": latent,
|
||||
"type": "trellis2",
|
||||
}
|
||||
if proj_feat_pack is not None:
|
||||
output["model_options"] = {"proj_feat_pack": proj_feat_pack}
|
||||
return IO.NodeOutput(output)
|
||||
|
||||
def _dinov3_patches_to_2d(tokens, image_size, patch_size=16):
|
||||
h_p = w_p = image_size // patch_size
|
||||
n_patches = h_p * w_p
|
||||
n_reg = tokens.shape[1] - 1 - n_patches
|
||||
if n_reg < 0 or tokens.shape[1] != 1 + n_reg + n_patches:
|
||||
raise ValueError(
|
||||
f"_dinov3_patches_to_2d: got {tokens.shape[1]} tokens, expected "
|
||||
f"1 (CLS) + N_reg + {h_p}*{w_p}={n_patches} patches at image_size={image_size}, "
|
||||
f"patch_size={patch_size}. Inferred N_reg={n_reg} which is invalid."
|
||||
)
|
||||
start = 1 + n_reg
|
||||
patches = tokens[:, start:start + n_patches]
|
||||
return patches.transpose(1, 2).reshape(tokens.shape[0], -1, h_p, w_p).contiguous()
|
||||
|
||||
|
||||
def _fov_from_moge_intrinsics(moge_intrinsics: torch.Tensor) -> float:
|
||||
fx = moge_intrinsics[..., 0, 0].float()
|
||||
fov = 2.0 * torch.atan(0.5 / fx.clamp(min=1e-4))
|
||||
return float(fov.mean().item())
|
||||
|
||||
|
||||
def _run_dinov3_with_patches(model, cropped_pil, image_size):
|
||||
# Pixal3D's cross-attn was trained against CLS + registers only (~5 tokens), not the
|
||||
# full patch grid. The patch grid goes to the proj branch via patches_2d.
|
||||
model_internal = model.model
|
||||
torch_device = comfy.model_management.get_torch_device()
|
||||
resized = cropped_pil.resize((image_size, image_size), Image.Resampling.LANCZOS)
|
||||
img_np = np.array(resized).astype(np.float32) / 255.0
|
||||
img_t = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(torch_device)
|
||||
img_t = (img_t - dino_mean.to(torch_device)) / dino_std.to(torch_device)
|
||||
model_internal.image_size = image_size
|
||||
tokens = model_internal(img_t, skip_norm_elementwise=True)[0]
|
||||
patches = _dinov3_patches_to_2d(tokens, image_size)
|
||||
h_p = w_p = image_size // 16
|
||||
n_reg = tokens.shape[1] - 1 - h_p * w_p
|
||||
global_tokens = tokens[:, :1 + n_reg]
|
||||
return {"tokens": global_tokens, "patches_2d": patches}
|
||||
|
||||
|
||||
def _crop_image_with_mask(item_image, item_mask, max_image_size=1024):
|
||||
img_np = (item_image.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
|
||||
mask_np = (item_mask.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
|
||||
pil_img = Image.fromarray(img_np)
|
||||
pil_mask = Image.fromarray(mask_np)
|
||||
max_size = max(pil_img.size)
|
||||
scale = min(1.0, max_image_size / max_size)
|
||||
if scale < 1.0:
|
||||
new_w, new_h = int(pil_img.width * scale), int(pil_img.height * scale)
|
||||
pil_img = pil_img.resize((new_w, new_h), Image.Resampling.LANCZOS)
|
||||
pil_mask = pil_mask.resize((new_w, new_h), Image.Resampling.NEAREST)
|
||||
scene_size = (pil_img.width, pil_img.height)
|
||||
rgba_np = np.zeros((pil_img.height, pil_img.width, 4), dtype=np.uint8)
|
||||
rgba_np[:, :, :3] = np.array(pil_img)
|
||||
rgba_np[:, :, 3] = np.array(pil_mask)
|
||||
alpha = rgba_np[:, :, 3]
|
||||
bbox_coords = np.argwhere(alpha > 0.8 * 255)
|
||||
if len(bbox_coords) > 0:
|
||||
y_min, x_min = np.min(bbox_coords[:, 0]), np.min(bbox_coords[:, 1])
|
||||
y_max, x_max = np.max(bbox_coords[:, 0]), np.max(bbox_coords[:, 1])
|
||||
center_y, center_x = (y_min + y_max) / 2.0, (x_min + x_max) / 2.0
|
||||
# Upstream pads the bbox by 10% — encoders were trained with that breathing room.
|
||||
size = max(y_max - y_min, x_max - x_min)
|
||||
size = int(size * 1.1)
|
||||
half = size // 2
|
||||
crop_x1 = int(center_x - half)
|
||||
crop_y1 = int(center_y - half)
|
||||
crop_x2 = crop_x1 + 2 * half
|
||||
crop_y2 = crop_y1 + 2 * half
|
||||
crop_bbox = (crop_x1, crop_y1, crop_x2, crop_y2)
|
||||
rgba_pil = Image.fromarray(rgba_np)
|
||||
cropped_rgba = rgba_pil.crop(crop_bbox)
|
||||
cropped_np = np.array(cropped_rgba).astype(np.float32) / 255.0
|
||||
else:
|
||||
logging.warning("Mask for the image is empty. Pixal3D requires a clean foreground mask.")
|
||||
cropped_np = rgba_np.astype(np.float32) / 255.0
|
||||
crop_bbox = (0, 0, scene_size[0], scene_size[1])
|
||||
fg = cropped_np[:, :, :3]
|
||||
alpha_float = cropped_np[:, :, 3:4]
|
||||
composite_np = fg * alpha_float
|
||||
composite_uint8 = (composite_np * 255.0).round().clip(0, 255).astype(np.uint8)
|
||||
return Image.fromarray(composite_uint8), crop_bbox, scene_size
|
||||
|
||||
|
||||
class Pixal3DConditioning(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="Pixal3DConditioning",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
IO.ClipVision.Input("clip_vision_model", tooltip="DINOv3 ViT-L/16 ClipVision."),
|
||||
IO.Image.Input("image"),
|
||||
IO.Mask.Input("mask"),
|
||||
IO.Float.Input(
|
||||
"camera_angle_x", default=0.2, min=0.0175, max=2.9671, step=0.001,
|
||||
tooltip="Horizontal FOV in radians (upstream demo default 0.2). "
|
||||
"Overridden by moge_geometry if connected.",
|
||||
),
|
||||
IO.Float.Input(
|
||||
"mesh_scale", default=1.0, min=0.1, max=4.0, step=0.01,
|
||||
tooltip="Mesh scale; 1.0 means unit cube.",
|
||||
),
|
||||
IO.Float.Input(
|
||||
"distance_override", default=0.0, min=0.0, max=10.0, step=0.001,
|
||||
tooltip="Override camera distance directly. 0 = auto-derive from FOV.",
|
||||
),
|
||||
io.Custom("MOGE_GEOMETRY").Input(
|
||||
"moge_geometry",
|
||||
optional=True,
|
||||
tooltip="If connected, camera_angle_x is recovered from MoGe.",
|
||||
),
|
||||
NAFModel.Input(
|
||||
"naf_model",
|
||||
optional=True,
|
||||
tooltip="Optional NAF feature upsampler. Required for shape/texture stages "
|
||||
"to match upstream's trained feature distribution.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Conditioning.Output(display_name="positive"),
|
||||
IO.Conditioning.Output(display_name="negative"),
|
||||
Pixal3DProjPack.Output(display_name="proj_feat_pack"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip_vision_model, image, mask, camera_angle_x, mesh_scale,
|
||||
distance_override=0.0,
|
||||
moge_geometry=None, naf_model=None) -> IO.NodeOutput:
|
||||
if image.ndim == 3:
|
||||
image = image.unsqueeze(0)
|
||||
if mask.ndim == 2:
|
||||
mask = mask.unsqueeze(0)
|
||||
batch_size = image.shape[0]
|
||||
if mask.shape[0] == 1 and batch_size > 1:
|
||||
mask = mask.expand(batch_size, -1, -1)
|
||||
elif mask.shape[0] != batch_size:
|
||||
raise ValueError(f"Pixal3DConditioning mask batch {mask.shape[0]} != image batch {batch_size}")
|
||||
|
||||
if moge_geometry is not None and "intrinsics" in moge_geometry:
|
||||
camera_angle_x = _fov_from_moge_intrinsics(moge_geometry["intrinsics"])
|
||||
|
||||
device = comfy.model_management.intermediate_device()
|
||||
|
||||
cond_512_list, cond_1024_list = [], []
|
||||
patches_512_list, patches_1024_list = [], []
|
||||
cropped_pil_list = []
|
||||
crop_bbox_list, scene_size_list = [], []
|
||||
|
||||
torch_device = comfy.model_management.get_torch_device()
|
||||
for b in range(batch_size):
|
||||
item_image = image[b]
|
||||
item_mask = mask[b] if mask.size(0) > 1 else mask[0]
|
||||
cropped_pil, crop_bbox, scene_size = _crop_image_with_mask(
|
||||
item_image, item_mask, max_image_size=1024)
|
||||
crop_bbox_list.append(crop_bbox)
|
||||
scene_size_list.append(scene_size)
|
||||
cropped_pil_list.append(cropped_pil)
|
||||
|
||||
cond_512 = _run_dinov3_with_patches(clip_vision_model, cropped_pil, 512)
|
||||
cond_1024 = _run_dinov3_with_patches(clip_vision_model, cropped_pil, 1024)
|
||||
cond_512_list.append(cond_512["tokens"].to(device))
|
||||
cond_1024_list.append(cond_1024["tokens"].to(device))
|
||||
patches_512_list.append(cond_512["patches_2d"].to(device))
|
||||
patches_1024_list.append(cond_1024["patches_2d"].to(device))
|
||||
|
||||
global_512 = torch.cat(cond_512_list, dim=0)
|
||||
global_1024 = torch.cat(cond_1024_list, dim=0)
|
||||
|
||||
fm_512_dino = torch.cat(patches_512_list, dim=0)
|
||||
fm_1024_dino = torch.cat(patches_1024_list, dim=0)
|
||||
|
||||
# Upstream samples the LR DINO grid AND the NAF HR grid separately at projected
|
||||
# 3D points, then cats sampled features along channels. Back-projection (in model.py)
|
||||
# mirrors that — here we just stash LR + optional HR per stage.
|
||||
# NAF targets per stage: shape_512=512, shape_1024=512, tex_1024=1024.
|
||||
def _naf_hr(lr_feat, image_pil_list, image_size, naf_target):
|
||||
if naf_model is None or naf_target is None:
|
||||
return None
|
||||
# Run NAF in the input feature dtype (typically fp16 since DINO/ClipVision
|
||||
# loads that way). The previous .float() cast doubled NAF memory by forcing
|
||||
# full fp32 — at tex_1024/target=1024 that's ~10 GB on its own. Model
|
||||
# weights need to match input dtype since PyTorch conv ops error out on
|
||||
# mixed fp16-input/fp32-weight.
|
||||
target_dtype = lr_feat.dtype
|
||||
if next(naf_model.parameters()).dtype != target_dtype:
|
||||
naf_model.to(dtype=target_dtype)
|
||||
imgs = torch.stack([
|
||||
torch.from_numpy(
|
||||
np.array(p.resize((image_size, image_size), Image.Resampling.LANCZOS))
|
||||
.astype(np.float32) / 255.0
|
||||
).permute(2, 0, 1)
|
||||
for p in image_pil_list
|
||||
], dim=0).to(torch_device).to(target_dtype)
|
||||
|
||||
hr = naf_model(imgs, lr_feat.to(torch_device).to(target_dtype), naf_target)
|
||||
return hr.to(device)
|
||||
|
||||
hr_shape_512 = _naf_hr(fm_512_dino, cropped_pil_list, 512, (512, 512))
|
||||
hr_shape_1024 = _naf_hr(fm_1024_dino, cropped_pil_list, 1024, (512, 512))
|
||||
hr_tex_1024 = _naf_hr(fm_1024_dino, cropped_pil_list, 1024, (1024, 1024))
|
||||
|
||||
# distance_from_fov: grid_point (-1, 0, 0) projects to pixel (0, image_resolution-1).
|
||||
camera_angle_x = float(camera_angle_x)
|
||||
if distance_override > 0:
|
||||
distance = float(distance_override)
|
||||
else:
|
||||
distance = 0.5 / math.tan(camera_angle_x / 2.0) / float(mesh_scale)
|
||||
cam_angle_t = torch.tensor([camera_angle_x] * batch_size, device=device, dtype=torch.float32)
|
||||
dist_t = torch.tensor([distance] * batch_size, device=device, dtype=torch.float32)
|
||||
scale_t = torch.tensor([float(mesh_scale)] * batch_size, device=device, dtype=torch.float32)
|
||||
T = _build_proj_transform_matrix(dist_t, batch_size, device=device, dtype=torch.float32)
|
||||
|
||||
proj_pack = {
|
||||
"stages": {
|
||||
"ss": {"feature_map": fm_512_dino, "feature_map_hr": None, "image_resolution": 512},
|
||||
"shape_512": {"feature_map": fm_512_dino, "feature_map_hr": hr_shape_512, "image_resolution": 512},
|
||||
"shape_1024": {"feature_map": fm_1024_dino, "feature_map_hr": hr_shape_1024,"image_resolution": 1024},
|
||||
"tex_1024": {"feature_map": fm_1024_dino, "feature_map_hr": hr_tex_1024, "image_resolution": 1024},
|
||||
},
|
||||
"transform_matrix": T,
|
||||
"camera_angle_x": cam_angle_t,
|
||||
"mesh_scale": scale_t,
|
||||
"distance": dist_t,
|
||||
"patch_size": 16,
|
||||
"crop_bboxes": crop_bbox_list,
|
||||
"scene_sizes": scene_size_list,
|
||||
}
|
||||
|
||||
# global_512 → SS/shape_512 cross-attn; global_1024 → shape_1024/tex_1024
|
||||
# (Trellis2.forward swaps context↔embeds for non-structure HR stages).
|
||||
neg_global = torch.zeros_like(global_512)
|
||||
neg_embeds = torch.zeros_like(global_1024)
|
||||
positive = [[global_512, {"embeds": global_1024}]]
|
||||
negative = [[neg_global, {"embeds": neg_embeds}]]
|
||||
return IO.NodeOutput(positive, negative, proj_pack)
|
||||
|
||||
|
||||
def _project_vertices_to_image_uv(vertices_world, transform_matrix, camera_angle_x, image_resolution):
|
||||
points = vertices_world.unsqueeze(0).float()
|
||||
T = transform_matrix.unsqueeze(0).float() if transform_matrix.ndim == 2 else transform_matrix.float()
|
||||
cam = camera_angle_x.unsqueeze(0) if camera_angle_x.ndim == 0 else camera_angle_x
|
||||
uv_pix, depth, valid = _project_points_to_image(points, T, cam.float(), image_resolution)
|
||||
uv = uv_pix.squeeze(0) / image_resolution
|
||||
return uv, depth.squeeze(0), valid.squeeze(0)
|
||||
|
||||
|
||||
def _crop_uv_to_scene_pixels(uv_crop, crop_bbox, scene_image_size):
|
||||
crop_x1, crop_y1, crop_x2, crop_y2 = crop_bbox
|
||||
crop_w = max(1, crop_x2 - crop_x1)
|
||||
crop_h = max(1, crop_y2 - crop_y1)
|
||||
px = uv_crop[:, 0] * crop_w + crop_x1
|
||||
py = uv_crop[:, 1] * crop_h + crop_y1
|
||||
W, H = scene_image_size
|
||||
return torch.stack([px.clamp(0, W - 1), py.clamp(0, H - 1)], dim=-1)
|
||||
|
||||
|
||||
class Pixal3DAlignObject(IO.ComfyNode):
|
||||
"""Pixal3D paper §3.3 Global Alignment for a single object.
|
||||
|
||||
Solves (scale, translation) aligning the mesh to MoGe's per-pixel point map. Requires
|
||||
MoGe to have been computed on the same resized scene image as Pixal3DConditioning."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="Pixal3DAlignObject",
|
||||
category="latent/3d",
|
||||
inputs=[
|
||||
IO.Mesh.Input("mesh"),
|
||||
Pixal3DProjPack.Input("proj_feat_pack", tooltip="The proj pack produced by Pixal3DConditioning for this object."),
|
||||
io.Custom("MOGE_GEOMETRY").Input("moge_geometry", tooltip="MoGe geometry computed on the original scene image."),
|
||||
IO.Mask.Input(
|
||||
"object_mask",
|
||||
optional=True,
|
||||
tooltip="Optional per-object scene-space mask. If connected, only vertices whose projected pixel falls inside the mask contribute to the alignment solve.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"batch_index",
|
||||
default=0, min=0, max=1024,
|
||||
tooltip="Which batch slot of the proj_feat_pack/MoGe geometry corresponds to this object.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Mesh.Output("aligned_mesh"),
|
||||
IO.Float.Output(display_name="scale"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, mesh, proj_feat_pack, moge_geometry, object_mask=None, batch_index=0) -> IO.NodeOutput:
|
||||
vertices = mesh.vertices
|
||||
faces = mesh.faces
|
||||
if vertices.ndim == 3:
|
||||
vertices_one = vertices[0]
|
||||
faces_one = faces[0]
|
||||
else:
|
||||
vertices_one = vertices
|
||||
faces_one = faces
|
||||
|
||||
T = proj_feat_pack["transform_matrix"][batch_index:batch_index + 1]
|
||||
cam_angle = proj_feat_pack["camera_angle_x"][batch_index:batch_index + 1]
|
||||
mesh_scale = proj_feat_pack["mesh_scale"][batch_index]
|
||||
image_resolution = int(proj_feat_pack.get("image_resolution", 1024))
|
||||
crop_bbox = proj_feat_pack["crop_bboxes"][batch_index]
|
||||
pack_scene_size = proj_feat_pack.get("scene_sizes", [None] * (batch_index + 1))[batch_index]
|
||||
moge_points = moge_geometry["points"]
|
||||
moge_mask = moge_geometry["mask"]
|
||||
if moge_points.ndim != 4:
|
||||
raise ValueError(f"MoGe points expected [B, H, W, 3]; got {tuple(moge_points.shape)}")
|
||||
scene_H, scene_W = moge_points.shape[1], moge_points.shape[2]
|
||||
if pack_scene_size is not None and pack_scene_size != (scene_W, scene_H):
|
||||
raise ValueError(
|
||||
f"Pixal3DAlignObject: MoGe geometry was computed on a {scene_W}x{scene_H} image, "
|
||||
f"but the proj_feat_pack's bbox lives in a {pack_scene_size[0]}x{pack_scene_size[1]} "
|
||||
"image. Run MoGe on the same resized scene image Pixal3DConditioning used."
|
||||
)
|
||||
|
||||
# Compose VaeDecodeShapeTrellis's R_y(180°) inverse with R_proj to map user mesh
|
||||
# space to ProjGrid world: (X, Y, Z) -> (-X, Z, Y).
|
||||
v = vertices_one.float()
|
||||
verts_world = torch.stack([-v[..., 0], v[..., 2], v[..., 1]], dim=-1)
|
||||
verts_world = verts_world / float(mesh_scale.item())
|
||||
uv_crop, _depth, valid = _project_vertices_to_image_uv(
|
||||
verts_world, T[0], cam_angle[0], image_resolution)
|
||||
scene_pixels = _crop_uv_to_scene_pixels(uv_crop, crop_bbox, (scene_W, scene_H))
|
||||
in_scene = ((scene_pixels[:, 0] >= 0) & (scene_pixels[:, 0] < scene_W) &
|
||||
(scene_pixels[:, 1] >= 0) & (scene_pixels[:, 1] < scene_H))
|
||||
sx = scene_pixels[:, 0].long().clamp(0, scene_W - 1)
|
||||
sy = scene_pixels[:, 1].long().clamp(0, scene_H - 1)
|
||||
moge_per_vertex = moge_points[batch_index, sy, sx]
|
||||
moge_mask_per_vertex = moge_mask[batch_index, sy, sx]
|
||||
keep = valid & in_scene & moge_mask_per_vertex
|
||||
if object_mask is not None:
|
||||
om = object_mask if object_mask.ndim == 2 else object_mask[batch_index]
|
||||
keep = keep & (om[sy, sx] > 0.5)
|
||||
|
||||
finite = torch.isfinite(moge_per_vertex).all(dim=-1)
|
||||
keep = keep & finite
|
||||
|
||||
kept = int(keep.sum().item())
|
||||
if kept < 8:
|
||||
scale = 1.0
|
||||
aligned = vertices_one
|
||||
else:
|
||||
P = vertices_one[keep].float()
|
||||
Q = moge_per_vertex[keep].float()
|
||||
p_mean = P.mean(dim=0, keepdim=True)
|
||||
q_mean = Q.mean(dim=0, keepdim=True)
|
||||
P_c = P - p_mean
|
||||
Q_c = Q - q_mean
|
||||
num = (P_c * Q_c).sum()
|
||||
den = (P_c * P_c).sum().clamp(min=1e-8)
|
||||
scale = float((num / den).item())
|
||||
if not (scale > 0):
|
||||
# Negative scale would mirror the mesh; treat as a camera-convention mismatch.
|
||||
logging.warning(
|
||||
f"Pixal3DAlignObject: computed scale={scale:.4f} <= 0; "
|
||||
"refusing to apply mirroring. Check camera convention alignment.")
|
||||
scale = 1.0
|
||||
aligned = vertices_one
|
||||
else:
|
||||
t = q_mean - scale * p_mean
|
||||
aligned = scale * vertices_one + t
|
||||
|
||||
if vertices.ndim == 3:
|
||||
aligned = aligned.unsqueeze(0)
|
||||
out_mesh = Types.MESH(vertices=aligned, faces=faces)
|
||||
else:
|
||||
out_mesh = Types.MESH(vertices=aligned, faces=faces_one)
|
||||
return IO.NodeOutput(out_mesh, float(scale))
|
||||
|
||||
|
||||
class LoadNAFModel(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="LoadNAFModel",
|
||||
display_name="Load NAF Model",
|
||||
category="loaders",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"naf_name",
|
||||
options=folder_paths.get_filename_list("upscale_models"),
|
||||
tooltip="NAF safetensors checkpoint (e.g. naf_release.safetensors).",
|
||||
),
|
||||
],
|
||||
outputs=[NAFModel.Output(display_name="naf_model")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, naf_name) -> IO.NodeOutput:
|
||||
path = folder_paths.get_full_path_or_raise("upscale_models", naf_name)
|
||||
sd = comfy.utils.load_torch_file(path, safe_load=True)
|
||||
model = build_naf_from_state_dict(sd)
|
||||
device = comfy.model_management.get_torch_device()
|
||||
model = model.to(device).eval()
|
||||
return IO.NodeOutput(model)
|
||||
|
||||
|
||||
class Trellis2Extension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
Trellis2Conditioning,
|
||||
Pixal3DConditioning,
|
||||
Pixal3DAlignObject,
|
||||
LoadNAFModel,
|
||||
EmptyTrellis2ShapeLatent,
|
||||
EmptyTrellis2LatentStructure,
|
||||
EmptyTrellis2LatentTexture,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user