From 6ea2e5b288d14eac984ad38499fd76aa1f9295c7 Mon Sep 17 00:00:00 2001 From: "Yousef R. Gamaleldin" Date: Fri, 30 Jan 2026 23:34:48 +0200 Subject: [PATCH 01/59] init --- comfy/image_encoders/dino3.py | 240 +++++ comfy/image_encoders/dino3_large.json | 24 + comfy/ldm/trellis2/attention.py | 194 ++++ comfy/ldm/trellis2/cumesh.py | 149 ++++ comfy/ldm/trellis2/model.py | 499 +++++++++++ comfy/ldm/trellis2/vae.py | 1185 +++++++++++++++++++++++++ comfy_extras/trellis2.py | 240 +++++ 7 files changed, 2531 insertions(+) create mode 100644 comfy/image_encoders/dino3.py create mode 100644 comfy/image_encoders/dino3_large.json create mode 100644 comfy/ldm/trellis2/attention.py create mode 100644 comfy/ldm/trellis2/cumesh.py create mode 100644 comfy/ldm/trellis2/model.py create mode 100644 comfy/ldm/trellis2/vae.py create mode 100644 comfy_extras/trellis2.py diff --git a/comfy/image_encoders/dino3.py b/comfy/image_encoders/dino3.py new file mode 100644 index 000000000..d07c2c5b8 --- /dev/null +++ b/comfy/image_encoders/dino3.py @@ -0,0 +1,240 @@ +import math +import torch +import torch.nn as nn + +from comfy.ldm.modules.attention import optimized_attention_for_device +from comfy.ldm.flux.math import apply_rope +from dino2 import Dinov2MLP as DINOv3ViTMLP, LayerScale as DINOv3ViTLayerScale + +class DINOv3ViTAttention(nn.Module): + def __init__(self, hidden_size, num_attention_heads, device, dtype, operations): + super().__init__() + self.embed_dim = hidden_size + self.num_heads = num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + self.is_causal = False + + self.scaling = self.head_dim**-0.5 + self.is_causal = False + + self.k_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=False, device=device, dtype=dtype) # key_bias = False + self.v_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype) + + self.q_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype) + 
self.o_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + + batch_size, patches, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + position_embeddings = torch.stack([cos, sin], dim = -1) + query_states, key_states = apply_rope(query_states, key_states, position_embeddings) + + attn_output, attn_weights = optimized_attention_for_device( + query_states, key_states, value_states, attention_mask, skip_reshape=True, skip_output_reshape=True + ) + + attn_output = attn_output.reshape(batch_size, patches, -1).contiguous() + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights + +class DINOv3ViTGatedMLP(nn.Module): + def __init__(self, hidden_size, intermediate_size, mlp_bias, device, dtype, operations): + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.gate_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype) + self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype) + self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=mlp_bias, device=device, dtype=dtype) + self.act_fn = torch.nn.GELU() + + def forward(self, x): + down_proj = 
self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + +def get_patches_center_coordinates( + num_patches_h: int, num_patches_w: int, dtype: torch.dtype, device: torch.device +) -> torch.Tensor: + + coords_h = torch.arange(0.5, num_patches_h, dtype=dtype, device=device) + coords_w = torch.arange(0.5, num_patches_w, dtype=dtype, device=device) + coords_h = coords_h / num_patches_h + coords_w = coords_w / num_patches_w + coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1) + coords = coords.flatten(0, 1) + coords = 2.0 * coords - 1.0 + return coords + +class DINOv3ViTRopePositionEmbedding(nn.Module): + inv_freq: torch.Tensor + + def __init__(self, rope_theta, hidden_size, num_attention_heads, image_size, patch_size, device, dtype): + super().__init__() + self.base = rope_theta + self.head_dim = hidden_size // num_attention_heads + self.num_patches_h = image_size // patch_size + self.num_patches_w = image_size // patch_size + + inv_freq = 1 / self.base ** torch.arange(0, 1, 4 / self.head_dim, dtype=torch.float32, device=device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, pixel_values: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + _, _, height, width = pixel_values.shape + num_patches_h = height // self.patch_size + num_patches_w = width // self.patch_size + + device = pixel_values.device + device_type = device.type if isinstance(device.type, str) and device.type != "mps" else "cpu" + with torch.amp.autocast(device_type = device_type, enabled=False): + patch_coords = get_patches_center_coordinates( + num_patches_h, num_patches_w, dtype=torch.float32, device=device + ) + + angles = 2 * math.pi * patch_coords[:, :, None] * self.inv_freq[None, None, :] + angles = angles.flatten(1, 2) + angles = angles.tile(2) + + cos = torch.cos(angles) + sin = torch.sin(angles) + + dtype = pixel_values.dtype + return cos.to(dtype=dtype), sin.to(dtype=dtype) + + +class 
DINOv3ViTEmbeddings(nn.Module): + def __init__(self, hidden_size, num_register_tokens, num_channels, patch_size, dtype, device, operations): + super().__init__() + self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_size, device=device, dtype=dtype)) + self.mask_token = nn.Parameter(torch.zeros(1, 1, hidden_size, device=device, dtype=dtype)) + self.register_tokens = nn.Parameter(torch.empty(1, num_register_tokens, hidden_size, device=device, dtype=dtype)) + self.patch_embeddings = operations.Conv2d( + num_channels, hidden_size, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype + ) + + def forward(self, pixel_values: torch.Tensor, bool_masked_pos: torch.Tensor | None = None): + batch_size = pixel_values.shape[0] + target_dtype = self.patch_embeddings.weight.dtype + + patch_embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype)) + patch_embeddings = patch_embeddings.flatten(2).transpose(1, 2) + + if bool_masked_pos is not None: + mask_token = self.mask_token.to(patch_embeddings.dtype) + patch_embeddings = torch.where(bool_masked_pos.unsqueeze(-1), mask_token, patch_embeddings) + + cls_token = self.cls_token.expand(batch_size, -1, -1) + register_tokens = self.register_tokens.expand(batch_size, -1, -1) + embeddings = torch.cat([cls_token, register_tokens, patch_embeddings], dim=1) + + return embeddings + +class DINOv3ViTLayer(nn.Module): + + def __init__(self, hidden_size, layer_norm_eps, use_gated_mlp, layerscale_value, mlp_bias, intermediate_size, num_attention_heads, + device, dtype, operations): + super().__init__() + + self.norm1 = operations.LayerNorm(hidden_size, eps=layer_norm_eps) + self.attention = DINOv3ViTAttention(hidden_size, num_attention_heads, device=device, dtype=dtype, operations=operations) + self.layer_scale1 = DINOv3ViTLayerScale(hidden_size, layerscale_value, device=device, dtype=dtype) + + self.norm2 = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype) + + if use_gated_mlp: + 
self.mlp = DINOv3ViTGatedMLP(hidden_size, intermediate_size, mlp_bias, device=device, dtype=dtype, operations=operations) + else: + self.mlp = DINOv3ViTMLP(hidden_size, device=device, dtype=dtype, operations=operations) + self.layer_scale2 = DINOv3ViTLayerScale(hidden_size, layerscale_value, device=device, dtype=dtype) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.norm1(hidden_states) + hidden_states, _ = self.attention( + hidden_states, + attention_mask=attention_mask, + position_embeddings=position_embeddings, + ) + hidden_states = self.layer_scale1(hidden_states) + hidden_states = hidden_states + residual + + residual = hidden_states + hidden_states = self.norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = self.layer_scale2(hidden_states) + hidden_states = hidden_states + residual + + return hidden_states + + +class DINOv3ViTModel(nn.Module): + def __init__(self, config, device, dtype, operations): + super().__init__() + num_hidden_layers = config["num_hidden_layers"] + hidden_size = config["hidden_size"] + num_attention_heads = config["num_attention_heads"] + num_register_tokens = config["num_register_tokens"] + intermediate_size = config["intermediate_size"] + layer_norm_eps = config["layer_norm_eps"] + layerscale_value = config["layerscale_value"] + num_channels = config["num_channels"] + patch_size = config["patch_size"] + rope_theta = config["rope_theta"] + + self.embeddings = DINOv3ViTEmbeddings( + hidden_size, num_register_tokens, num_channels=num_channels, patch_size=patch_size, dtype=dtype, device=device, operations=operations + ) + self.rope_embeddings = DINOv3ViTRopePositionEmbedding( + rope_theta, hidden_size, num_attention_heads, image_size=512, patch_size=patch_size, dtype=dtype, device=device + ) + self.layer = 
nn.ModuleList( + [DINOv3ViTLayer(hidden_size, layer_norm_eps, use_gated_mlp=False, layerscale_value=layerscale_value, mlp_bias=True, + intermediate_size=intermediate_size,num_attention_heads = num_attention_heads, + dtype=dtype, device=device, operations=operations) + for _ in range(num_hidden_layers)]) + self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps, dtype=dtype, device=device) + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + def forward( + self, + pixel_values: torch.Tensor, + bool_masked_pos: torch.Tensor | None = None, + **kwargs, + ): + + pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype) + hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos) + position_embeddings = self.rope_embeddings(pixel_values) + + for i, layer_module in enumerate(self.layer): + hidden_states = layer_module( + hidden_states, + position_embeddings=position_embeddings, + ) + + sequence_output = self.norm(hidden_states) + pooled_output = sequence_output[:, 0, :] + + return sequence_output, None, pooled_output, None diff --git a/comfy/image_encoders/dino3_large.json b/comfy/image_encoders/dino3_large.json new file mode 100644 index 000000000..96263f0d6 --- /dev/null +++ b/comfy/image_encoders/dino3_large.json @@ -0,0 +1,24 @@ +{ + + "hidden_size": 384, + "image_size": 224, + "initializer_range": 0.02, + "intermediate_size": 1536, + "key_bias": false, + "layer_norm_eps": 1e-05, + "layerscale_value": 1.0, + "mlp_bias": true, + "num_attention_heads": 6, + "num_channels": 3, + "num_hidden_layers": 12, + "num_register_tokens": 4, + "patch_size": 16, + "pos_embed_rescale": 2.0, + "proj_bias": true, + "query_bias": true, + "rope_theta": 100.0, + "use_gated_mlp": false, + "value_bias": true, + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225] +} diff --git a/comfy/ldm/trellis2/attention.py b/comfy/ldm/trellis2/attention.py new file mode 100644 index 000000000..9cd7d4995 --- /dev/null +++ 
b/comfy/ldm/trellis2/attention.py @@ -0,0 +1,194 @@ +import torch +import math +from comfy.ldm.modules.attention import optimized_attention +from typing import Tuple, Union, List +from vae import VarLenTensor + +def sparse_windowed_scaled_dot_product_self_attention( + qkv, + window_size: int, + shift_window: Tuple[int, int, int] = (0, 0, 0) +): + + serialization_spatial_cache_name = f'windowed_attention_{window_size}_{shift_window}' + serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name) + if serialization_spatial_cache is None: + fwd_indices, bwd_indices, seq_lens, attn_func_args = calc_window_partition(qkv, window_size, shift_window) + qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, attn_func_args)) + else: + fwd_indices, bwd_indices, seq_lens, attn_func_args = serialization_spatial_cache + + qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C] + + if optimized_attention.__name__ == 'attention_xformers': + if 'xops' not in globals(): + import xformers.ops as xops + q, k, v = qkv_feats.unbind(dim=1) + q = q.unsqueeze(0) # [1, M, H, C] + k = k.unsqueeze(0) # [1, M, H, C] + v = v.unsqueeze(0) # [1, M, H, C] + out = xops.memory_efficient_attention(q, k, v, **attn_func_args)[0] # [M, H, C] + elif optimized_attention.__name__ == 'attention_flash': + if 'flash_attn' not in globals(): + import flash_attn + out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, **attn_func_args) # [M, H, C] + + out = out[bwd_indices] # [T, H, C] + + return qkv.replace(out) + +def calc_window_partition( + tensor, + window_size: Union[int, Tuple[int, ...]], + shift_window: Union[int, Tuple[int, ...]] = 0, +) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]: + + DIM = tensor.coords.shape[1] - 1 + shift_window = (shift_window,) * DIM if isinstance(shift_window, int) else shift_window + window_size = (window_size,) * DIM if isinstance(window_size, int) else window_size + shifted_coords = 
tensor.coords.clone().detach() + shifted_coords[:, 1:] += torch.tensor(shift_window, device=tensor.device, dtype=torch.int32).unsqueeze(0) + + MAX_COORDS = [i + j for i, j in zip(tensor.spatial_shape, shift_window)] + NUM_WINDOWS = [math.ceil((mc + 1) / ws) for mc, ws in zip(MAX_COORDS, window_size)] + OFFSET = torch.cumprod(torch.tensor([1] + NUM_WINDOWS[::-1]), dim=0).tolist()[::-1] + + shifted_coords[:, 1:] //= torch.tensor(window_size, device=tensor.device, dtype=torch.int32).unsqueeze(0) + shifted_indices = (shifted_coords * torch.tensor(OFFSET, device=tensor.device, dtype=torch.int32).unsqueeze(0)).sum(dim=1) + fwd_indices = torch.argsort(shifted_indices) + bwd_indices = torch.empty_like(fwd_indices) + bwd_indices[fwd_indices] = torch.arange(fwd_indices.shape[0], device=tensor.device) + seq_lens = torch.bincount(shifted_indices) + mask = seq_lens != 0 + seq_lens = seq_lens[mask] + + if optimized_attention.__name__ == 'attention_xformers': + if 'xops' not in globals(): + import xformers.ops as xops + attn_func_args = { + 'attn_bias': xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens) + } + elif optimized_attention.__name__ == 'attention_flash': + attn_func_args = { + 'cu_seqlens': torch.cat([torch.tensor([0], device=tensor.device), torch.cumsum(seq_lens, dim=0)], dim=0).int(), + 'max_seqlen': torch.max(seq_lens) + } + + return fwd_indices, bwd_indices, seq_lens, attn_func_args + + +def sparse_scaled_dot_product_attention(*args, **kwargs): + arg_names_dict = { + 1: ['qkv'], + 2: ['q', 'kv'], + 3: ['q', 'k', 'v'] + } + num_all_args = len(args) + len(kwargs) + for key in arg_names_dict[num_all_args][len(args):]: + assert key in kwargs, f"Missing argument {key}" + + if num_all_args == 1: + qkv = args[0] if len(args) > 0 else kwargs['qkv'] + device = qkv.device + + s = qkv + q_seqlen = [qkv.layout[i].stop - qkv.layout[i].start for i in range(qkv.shape[0])] + kv_seqlen = q_seqlen + qkv = qkv.feats # [T, 3, H, C] + + elif num_all_args == 2: + q = args[0] if len(args) 
> 0 else kwargs['q'] + kv = args[1] if len(args) > 1 else kwargs['kv'] + device = q.device + + if isinstance(q, VarLenTensor): + s = q + q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])] + q = q.feats # [T_Q, H, C] + else: + s = None + N, L, H, C = q.shape + q_seqlen = [L] * N + q = q.reshape(N * L, H, C) # [T_Q, H, C] + + if isinstance(kv, VarLenTensor): + kv_seqlen = [kv.layout[i].stop - kv.layout[i].start for i in range(kv.shape[0])] + kv = kv.feats # [T_KV, 2, H, C] + else: + N, L, _, H, C = kv.shape + kv_seqlen = [L] * N + kv = kv.reshape(N * L, 2, H, C) # [T_KV, 2, H, C] + + elif num_all_args == 3: + q = args[0] if len(args) > 0 else kwargs['q'] + k = args[1] if len(args) > 1 else kwargs['k'] + v = args[2] if len(args) > 2 else kwargs['v'] + device = q.device + + if isinstance(q, VarLenTensor): + s = q + q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])] + q = q.feats # [T_Q, H, Ci] + else: + s = None + N, L, H, CI = q.shape + q_seqlen = [L] * N + q = q.reshape(N * L, H, CI) # [T_Q, H, Ci] + + if isinstance(k, VarLenTensor): + kv_seqlen = [k.layout[i].stop - k.layout[i].start for i in range(k.shape[0])] + k = k.feats # [T_KV, H, Ci] + v = v.feats # [T_KV, H, Co] + else: + N, L, H, CI, CO = *k.shape, v.shape[-1] + kv_seqlen = [L] * N + k = k.reshape(N * L, H, CI) # [T_KV, H, Ci] + v = v.reshape(N * L, H, CO) # [T_KV, H, Co] + + if optimized_attention.__name__ == 'attention_xformers': + if 'xops' not in globals(): + import xformers.ops as xops + if num_all_args == 1: + q, k, v = qkv.unbind(dim=1) + elif num_all_args == 2: + k, v = kv.unbind(dim=1) + q = q.unsqueeze(0) + k = k.unsqueeze(0) + v = v.unsqueeze(0) + mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen) + out = xops.memory_efficient_attention(q, k, v, mask)[0] + elif optimized_attention.__name__ == 'attention_flash': + if 'flash_attn' not in globals(): + import flash_attn + cu_seqlens_q = torch.cat([torch.tensor([0]), 
torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device) + if num_all_args in [2, 3]: + cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device) + if num_all_args == 1: + out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens_q, max(q_seqlen)) + elif num_all_args == 2: + out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) + elif num_all_args == 3: + out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) + elif optimized_attention.__name__ == 'flash_attn_3': # TODO + if 'flash_attn_3' not in globals(): + import flash_attn_interface as flash_attn_3 + cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device) + if num_all_args == 1: + q, k, v = qkv.unbind(dim=1) + cu_seqlens_kv = cu_seqlens_q.clone() + max_q_seqlen = max_kv_seqlen = max(q_seqlen) + elif num_all_args == 2: + k, v = kv.unbind(dim=1) + cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device) + max_q_seqlen = max(q_seqlen) + max_kv_seqlen = max(kv_seqlen) + elif num_all_args == 3: + cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device) + max_q_seqlen = max(q_seqlen) + max_kv_seqlen = max(kv_seqlen) + out = flash_attn_3.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_q_seqlen, max_kv_seqlen) + + if s is not None: + return s.replace(out) + else: + return out.reshape(N, L, H, -1) diff --git a/comfy/ldm/trellis2/cumesh.py b/comfy/ldm/trellis2/cumesh.py new file mode 100644 index 000000000..be8200341 --- /dev/null +++ b/comfy/ldm/trellis2/cumesh.py @@ -0,0 +1,149 @@ +# will contain every cuda -> pytorch operation + +import torch +from typing import Dict + + +class TorchHashMap: + def __init__(self, keys: torch.Tensor, values: torch.Tensor, default_value: 
int): + device = keys.device + # use long for searchsorted + self.sorted_keys, order = torch.sort(keys.long()) + self.sorted_vals = values.long()[order] + self.default_value = torch.tensor(default_value, dtype=torch.long, device=device) + self._n = self.sorted_keys.numel() + + def lookup_flat(self, flat_keys: torch.Tensor) -> torch.Tensor: + flat = flat_keys.long() + idx = torch.searchsorted(self.sorted_keys, flat) + found = (idx < self._n) & (self.sorted_keys[idx] == flat) + out = torch.full((flat.shape[0],), self.default_value, device=flat.device, dtype=self.sorted_vals.dtype) + if found.any(): + out[found] = self.sorted_vals[idx[found]] + return out + +class Voxel: + def __init__( + self, + origin: list, + voxel_size: float, + coords: torch.Tensor = None, + attrs: torch.Tensor = None, + layout: Dict = {}, + device: torch.device = 'cuda' + ): + self.origin = torch.tensor(origin, dtype=torch.float32, device=device) + self.voxel_size = voxel_size + self.coords = coords + self.attrs = attrs + self.layout = layout + self.device = device + + @property + def position(self): + return (self.coords + 0.5) * self.voxel_size + self.origin[None, :] + + def split_attrs(self): + return { + k: self.attrs[:, self.layout[k]] + for k in self.layout + } + +class Mesh: + def __init__(self, + vertices, + faces, + vertex_attrs=None + ): + self.vertices = vertices.float() + self.faces = faces.int() + self.vertex_attrs = vertex_attrs + + @property + def device(self): + return self.vertices.device + + def to(self, device, non_blocking=False): + return Mesh( + self.vertices.to(device, non_blocking=non_blocking), + self.faces.to(device, non_blocking=non_blocking), + self.vertex_attrs.to(device, non_blocking=non_blocking) if self.vertex_attrs is not None else None, + ) + + def cuda(self, non_blocking=False): + return self.to('cuda', non_blocking=non_blocking) + + def cpu(self): + return self.to('cpu') + + # TODO could be an option + def fill_holes(self, max_hole_perimeter=3e-2): + import 
cumesh + vertices = self.vertices.cuda() + faces = self.faces.cuda() + + mesh = cumesh.CuMesh() + mesh.init(vertices, faces) + mesh.get_edges() + mesh.get_boundary_info() + if mesh.num_boundaries == 0: + return + mesh.get_vertex_edge_adjacency() + mesh.get_vertex_boundary_adjacency() + mesh.get_manifold_boundary_adjacency() + mesh.read_manifold_boundary_adjacency() + mesh.get_boundary_connected_components() + mesh.get_boundary_loops() + if mesh.num_boundary_loops == 0: + return + mesh.fill_holes(max_hole_perimeter=max_hole_perimeter) + new_vertices, new_faces = mesh.read() + + self.vertices = new_vertices.to(self.device) + self.faces = new_faces.to(self.device) + + # TODO could be an option + def simplify(self, target=1000000, verbose: bool=False, options: dict={}): + import cumesh + vertices = self.vertices.cuda() + faces = self.faces.cuda() + + mesh = cumesh.CuMesh() + mesh.init(vertices, faces) + mesh.simplify(target, verbose=verbose, options=options) + new_vertices, new_faces = mesh.read() + + self.vertices = new_vertices.to(self.device) + self.faces = new_faces.to(self.device) + +class MeshWithVoxel(Mesh, Voxel): + def __init__(self, + vertices: torch.Tensor, + faces: torch.Tensor, + origin: list, + voxel_size: float, + coords: torch.Tensor, + attrs: torch.Tensor, + voxel_shape: torch.Size, + layout: Dict = {}, + ): + self.vertices = vertices.float() + self.faces = faces.int() + self.origin = torch.tensor(origin, dtype=torch.float32, device=self.device) + self.voxel_size = voxel_size + self.coords = coords + self.attrs = attrs + self.voxel_shape = voxel_shape + self.layout = layout + + def to(self, device, non_blocking=False): + return MeshWithVoxel( + self.vertices.to(device, non_blocking=non_blocking), + self.faces.to(device, non_blocking=non_blocking), + self.origin.tolist(), + self.voxel_size, + self.coords.to(device, non_blocking=non_blocking), + self.attrs.to(device, non_blocking=non_blocking), + self.voxel_shape, + self.layout, + ) diff --git 
a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py new file mode 100644 index 000000000..a0889c4dd --- /dev/null +++ b/comfy/ldm/trellis2/model.py @@ -0,0 +1,499 @@ +import torch +import torch.nn.functional as F +import torch.nn as nn +from vae import SparseTensor, SparseLinear, sparse_cat, VarLenTensor +from typing import Optional, Tuple, Literal, Union, List +from attention import sparse_windowed_scaled_dot_product_self_attention, sparse_scaled_dot_product_attention +from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder + +class SparseGELU(nn.GELU): + def forward(self, input: VarLenTensor) -> VarLenTensor: + return input.replace(super().forward(input.feats)) + +class SparseFeedForwardNet(nn.Module): + def __init__(self, channels: int, mlp_ratio: float = 4.0): + super().__init__() + self.mlp = nn.Sequential( + SparseLinear(channels, int(channels * mlp_ratio)), + SparseGELU(approximate="tanh"), + SparseLinear(int(channels * mlp_ratio), channels), + ) + + def forward(self, x: VarLenTensor) -> VarLenTensor: + return self.mlp(x) + +def manual_cast(tensor, dtype): + if not torch.is_autocast_enabled(): + return tensor.type(dtype) + return tensor +class LayerNorm32(nn.LayerNorm): + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_dtype = x.dtype + x = manual_cast(x, torch.float32) + o = super().forward(x) + return manual_cast(o, x_dtype) + + +class SparseMultiHeadRMSNorm(nn.Module): + def __init__(self, dim: int, heads: int): + super().__init__() + self.scale = dim ** 0.5 + self.gamma = nn.Parameter(torch.ones(heads, dim)) + + def forward(self, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]: + x_type = x.dtype + x = x.float() + if isinstance(x, VarLenTensor): + x = x.replace(F.normalize(x.feats, dim=-1) * self.gamma * self.scale) + else: + x = F.normalize(x, dim=-1) * self.gamma * self.scale + return x.to(x_type) + +# TODO: replace with apply_rope1 +class SparseRotaryPositionEmbedder(nn.Module): + def __init__( + 
self, + head_dim: int, + dim: int = 3, + rope_freq: Tuple[float, float] = (1.0, 10000.0) + ): + super().__init__() + assert head_dim % 2 == 0, "Head dim must be divisible by 2" + self.head_dim = head_dim + self.dim = dim + self.rope_freq = rope_freq + self.freq_dim = head_dim // 2 // dim + self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim + self.freqs = rope_freq[0] / (rope_freq[1] ** (self.freqs)) + + def _get_phases(self, indices: torch.Tensor) -> torch.Tensor: + self.freqs = self.freqs.to(indices.device) + phases = torch.outer(indices, self.freqs) + phases = torch.polar(torch.ones_like(phases), phases) + return phases + + def _rotary_embedding(self, x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor: + x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + x_rotated = x_complex * phases.unsqueeze(-2) + x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype) + return x_embed + + def forward(self, q: SparseTensor, k: Optional[SparseTensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + q (SparseTensor): [..., N, H, D] tensor of queries + k (SparseTensor): [..., N, H, D] tensor of keys + """ + assert q.coords.shape[-1] == self.dim + 1, "Last dimension of coords must be equal to dim+1" + phases_cache_name = f'rope_phase_{self.dim}d_freq{self.rope_freq[0]}-{self.rope_freq[1]}_hd{self.head_dim}' + phases = q.get_spatial_cache(phases_cache_name) + if phases is None: + coords = q.coords[..., 1:] + phases = self._get_phases(coords.reshape(-1)).reshape(*coords.shape[:-1], -1) + if phases.shape[-1] < self.head_dim // 2: + padn = self.head_dim // 2 - phases.shape[-1] + phases = torch.cat([phases, torch.polar( + torch.ones(*phases.shape[:-1], padn, device=phases.device), + torch.zeros(*phases.shape[:-1], padn, device=phases.device) + )], dim=-1) + q.register_spatial_cache(phases_cache_name, phases) + q_embed = q.replace(self._rotary_embedding(q.feats, phases)) + if k is None: 
+ return q_embed + k_embed = k.replace(self._rotary_embedding(k.feats, phases)) + return q_embed, k_embed + +class SparseMultiHeadAttention(nn.Module): + def __init__( + self, + channels: int, + num_heads: int, + ctx_channels: Optional[int] = None, + type: Literal["self", "cross"] = "self", + attn_mode: Literal["full", "windowed", "double_windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + qkv_bias: bool = True, + use_rope: bool = False, + rope_freq: Tuple[int, int] = (1.0, 10000.0), + qk_rms_norm: bool = False, + ): + super().__init__() + + self.channels = channels + self.head_dim = channels // num_heads + self.ctx_channels = ctx_channels if ctx_channels is not None else channels + self.num_heads = num_heads + self._type = type + self.attn_mode = attn_mode + self.window_size = window_size + self.shift_window = shift_window + self.use_rope = use_rope + self.qk_rms_norm = qk_rms_norm + + if self._type == "self": + self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias) + else: + self.to_q = nn.Linear(channels, channels, bias=qkv_bias) + self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias) + + if self.qk_rms_norm: + self.q_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads) + self.k_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads) + + self.to_out = nn.Linear(channels, channels) + + if use_rope: + self.rope = SparseRotaryPositionEmbedder(self.head_dim, rope_freq=rope_freq) + + @staticmethod + def _linear(module: nn.Linear, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]: + if isinstance(x, VarLenTensor): + return x.replace(module(x.feats)) + else: + return module(x) + + @staticmethod + def _reshape_chs(x: Union[VarLenTensor, torch.Tensor], shape: Tuple[int, ...]) -> Union[VarLenTensor, torch.Tensor]: + if isinstance(x, VarLenTensor): + return x.reshape(*shape) + else: + return x.reshape(*x.shape[:2], *shape) + + def _fused_pre(self, x: 
Union[VarLenTensor, torch.Tensor], num_fused: int) -> Union[VarLenTensor, torch.Tensor]: + if isinstance(x, VarLenTensor): + x_feats = x.feats.unsqueeze(0) + else: + x_feats = x + x_feats = x_feats.reshape(*x_feats.shape[:2], num_fused, self.num_heads, -1) + return x.replace(x_feats.squeeze(0)) if isinstance(x, VarLenTensor) else x_feats + + def forward(self, x: SparseTensor, context: Optional[Union[VarLenTensor, torch.Tensor]] = None) -> SparseTensor: + if self._type == "self": + qkv = self._linear(self.to_qkv, x) + qkv = self._fused_pre(qkv, num_fused=3) + if self.qk_rms_norm or self.use_rope: + q, k, v = qkv.unbind(dim=-3) + if self.qk_rms_norm: + q = self.q_rms_norm(q) + k = self.k_rms_norm(k) + if self.use_rope: + q, k = self.rope(q, k) + qkv = qkv.replace(torch.stack([q.feats, k.feats, v.feats], dim=1)) + if self.attn_mode == "full": + h = sparse_scaled_dot_product_attention(qkv) + elif self.attn_mode == "windowed": + h = sparse_windowed_scaled_dot_product_self_attention( + qkv, self.window_size, shift_window=self.shift_window + ) + elif self.attn_mode == "double_windowed": + qkv0 = qkv.replace(qkv.feats[:, :, self.num_heads//2:]) + qkv1 = qkv.replace(qkv.feats[:, :, :self.num_heads//2]) + h0 = sparse_windowed_scaled_dot_product_self_attention( + qkv0, self.window_size, shift_window=(0, 0, 0) + ) + h1 = sparse_windowed_scaled_dot_product_self_attention( + qkv1, self.window_size, shift_window=tuple([self.window_size//2] * 3) + ) + h = qkv.replace(torch.cat([h0.feats, h1.feats], dim=1)) + else: + q = self._linear(self.to_q, x) + q = self._reshape_chs(q, (self.num_heads, -1)) + kv = self._linear(self.to_kv, context) + kv = self._fused_pre(kv, num_fused=2) + if self.qk_rms_norm: + q = self.q_rms_norm(q) + k, v = kv.unbind(dim=-3) + k = self.k_rms_norm(k) + h = sparse_scaled_dot_product_attention(q, k, v) + else: + h = sparse_scaled_dot_product_attention(q, kv) + h = self._reshape_chs(h, (-1,)) + h = self._linear(self.to_out, h) + return h + +class 
ModulatedSparseTransformerBlock(nn.Module): + """ + Sparse Transformer block (MSA + FFN) with adaptive layer norm conditioning. + """ + def __init__( + self, + channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "swin"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + rope_freq: Tuple[float, float] = (1.0, 10000.0), + qk_rms_norm: bool = False, + qkv_bias: bool = True, + share_mod: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.attn = SparseMultiHeadAttention( + channels, + num_heads=num_heads, + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + rope_freq=rope_freq, + qk_rms_norm=qk_rms_norm, + ) + self.mlp = SparseFeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(channels, 6 * channels, bias=True) + ) + else: + self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5) + + def _forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor: + if self.share_mod: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) + h = x.replace(self.norm1(x.feats)) + h = h * (1 + scale_msa) + shift_msa + h = self.attn(h) + h = h * gate_msa + x = x + h + h = x.replace(self.norm2(x.feats)) + h = h * (1 + scale_mlp) + shift_mlp + h = self.mlp(h) + h = h * gate_mlp + x = x + h + return x + + def forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor: + 
class ModulatedSparseTransformerCrossBlock(nn.Module):
    """
    Sparse Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning.

    Only the self-attention and MLP sub-layers are modulated and gated; the
    cross-attention sub-layer uses an affine LayerNorm and a plain residual.
    ``use_checkpoint`` is stored on the instance but ``forward`` does not
    consult it.
    """
    def __init__(
        self,
        channels: int,
        ctx_channels: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "swin"] = "full",
        window_size: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        use_checkpoint: bool = False,
        use_rope: bool = False,
        rope_freq: Tuple[float, float] = (1.0, 10000.0),
        qk_rms_norm: bool = False,
        qk_rms_norm_cross: bool = False,
        qkv_bias: bool = True,
        share_mod: bool = False,

    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.share_mod = share_mod

        # norm2 (before cross-attention) is the only norm with learned affine.
        self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
        self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)

        self_attn_kwargs = dict(
            num_heads=num_heads,
            type="self",
            attn_mode=attn_mode,
            window_size=window_size,
            shift_window=shift_window,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            rope_freq=rope_freq,
            qk_rms_norm=qk_rms_norm,
        )
        self.self_attn = SparseMultiHeadAttention(channels, **self_attn_kwargs)

        cross_attn_kwargs = dict(
            ctx_channels=ctx_channels,
            num_heads=num_heads,
            type="cross",
            attn_mode="full",
            qkv_bias=qkv_bias,
            qk_rms_norm=qk_rms_norm_cross,
        )
        self.cross_attn = SparseMultiHeadAttention(channels, **cross_attn_kwargs)

        self.mlp = SparseFeedForwardNet(channels, mlp_ratio=mlp_ratio)

        if share_mod:
            self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5)
        else:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(channels, 6 * channels, bias=True)
            )

    def _forward(self, x: SparseTensor, mod: torch.Tensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor:
        """Run self-attention, cross-attention on ``context``, then the MLP."""
        if self.share_mod:
            params = (self.modulation + mod).type(mod.dtype)
        else:
            params = self.adaLN_modulation(mod)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = params.chunk(6, dim=1)

        # Modulated, gated self-attention.
        sa_in = x.replace(self.norm1(x.feats)) * (1 + scale_msa) + shift_msa
        x = x + self.self_attn(sa_in) * gate_msa

        # Cross-attention: no modulation, no gate.
        ca_in = x.replace(self.norm2(x.feats))
        x = x + self.cross_attn(ca_in, context)

        # Modulated, gated feed-forward.
        mlp_in = x.replace(self.norm3(x.feats)) * (1 + scale_mlp) + shift_mlp
        x = x + self.mlp(mlp_in) * gate_mlp
        return x

    def forward(self, x: SparseTensor, mod: torch.Tensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor:
        """Apply the block to ``x`` conditioned on ``mod`` and ``context``."""
        return self._forward(x, mod, context)
class SLatFlowModel(nn.Module):
    """
    Flow model over sparse latents: a stack of modulated sparse cross-attention
    transformer blocks between a linear input and a linear output layer.

    ``dtype`` is stored and used as the compute dtype passed to ``manual_cast``
    in ``forward``. ``device`` and ``operations`` are accepted but not used by
    this class.
    """
    def __init__(
        self,
        resolution: int,
        in_channels: int,
        model_channels: int,
        cond_channels: int,
        out_channels: int,
        num_blocks: int,
        num_heads: Optional[int] = None,
        num_head_channels: Optional[int] = 64,
        mlp_ratio: float = 4,
        pe_mode: Literal["ape", "rope"] = "rope",
        rope_freq: Tuple[float, float] = (1.0, 10000.0),
        use_checkpoint: bool = False,
        share_mod: bool = False,
        initialization: str = 'vanilla',
        qk_rms_norm: bool = False,
        qk_rms_norm_cross: bool = False,
        dtype = None,
        device = None,
        operations = None,
    ):
        super().__init__()
        self.resolution = resolution
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.cond_channels = cond_channels
        self.out_channels = out_channels
        self.num_blocks = num_blocks
        # Explicit head count wins; otherwise derive it from the per-head width.
        self.num_heads = num_heads or model_channels // num_head_channels
        self.mlp_ratio = mlp_ratio
        self.pe_mode = pe_mode
        self.use_checkpoint = use_checkpoint
        self.share_mod = share_mod
        self.initialization = initialization
        self.qk_rms_norm = qk_rms_norm
        self.qk_rms_norm_cross = qk_rms_norm_cross
        self.dtype = dtype

        self.t_embedder = TimestepEmbedder(model_channels)
        if share_mod:
            # One shared modulation projection; the blocks then only add
            # their own learned offset.
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(model_channels, 6 * model_channels, bias=True)
            )

        self.input_layer = SparseLinear(in_channels, model_channels)

        self.blocks = nn.ModuleList([
            ModulatedSparseTransformerCrossBlock(
                model_channels,
                cond_channels,
                num_heads=self.num_heads,
                mlp_ratio=self.mlp_ratio,
                attn_mode='full',
                use_checkpoint=self.use_checkpoint,
                use_rope=(pe_mode == "rope"),
                rope_freq=rope_freq,
                share_mod=self.share_mod,
                qk_rms_norm=self.qk_rms_norm,
                qk_rms_norm_cross=self.qk_rms_norm_cross,
            )
            for _ in range(num_blocks)
        ])

        self.out_layer = SparseLinear(model_channels, out_channels)

    @property
    def device(self) -> torch.device:
        """Device of the first parameter of the model."""
        return next(self.parameters()).device

    def forward(
        self,
        x: SparseTensor,
        t: torch.Tensor,
        cond: Union[torch.Tensor, List[torch.Tensor]],
        concat_cond: Optional[SparseTensor] = None,
        **kwargs
    ) -> SparseTensor:
        """
        Run the flow model.

        Args:
            x: Input sparse tensor.
            t: Timesteps, fed to ``self.t_embedder``.
            cond: Conditioning tensor, or a list of tensors that is packed
                into a ``VarLenTensor``.
            concat_cond: Optional sparse tensor concatenated to ``x`` along
                the last dimension before the input layer.

        Returns:
            The output sparse tensor, cast back to ``x.dtype`` before the
            final layer norm and output projection.

        Raises:
            RuntimeError: If ``pe_mode == "ape"`` but no ``pos_embedder``
                attribute is present on the module.
        """
        if concat_cond is not None:
            x = sparse_cat([x, concat_cond], dim=-1)
        if isinstance(cond, list):
            cond = VarLenTensor.from_tensor_list(cond)

        h = self.input_layer(x)
        h = manual_cast(h, self.dtype)
        t_emb = self.t_embedder(t)
        if self.share_mod:
            t_emb = self.adaLN_modulation(t_emb)
        t_emb = manual_cast(t_emb, self.dtype)
        cond = manual_cast(cond, self.dtype)

        if self.pe_mode == "ape":
            # __init__ never creates `pos_embedder`; fail with a clear message
            # instead of an opaque AttributeError.
            pos_embedder = getattr(self, "pos_embedder", None)
            if pos_embedder is None:
                raise RuntimeError(
                    "pe_mode='ape' requires a `pos_embedder` module, but none is defined on this model"
                )
            pe = pos_embedder(h.coords[:, 1:])
            h = h + manual_cast(pe, self.dtype)
        for block in self.blocks:
            h = block(h, t_emb, cond)

        h = manual_cast(h, x.dtype)
        h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
        h = self.out_layer(h)
        return h
class Trellis2(nn.Module):
    """
    Container for the two Trellis2 flow models.

    Builds two ``SLatFlowModel`` instances that share every hyper-parameter
    except the input width: ``img2shape`` takes ``in_channels`` and
    ``shape2txt`` takes ``in_channels * 2``.
    """
    def __init__(self, resolution,
                 in_channels = 32,
                 out_channels = 32,
                 model_channels = 1536,
                 cond_channels = 1024,
                 num_blocks = 30,
                 num_heads = 12,
                 mlp_ratio = 5.3334,
                 share_mod = True,
                 qk_rms_norm = True,
                 qk_rms_norm_cross = True,
                 dtype=None, device=None, operations=None):
        # nn.Module must be initialised before submodules can be assigned.
        super().__init__()
        args = {
            "out_channels":out_channels, "num_blocks":num_blocks, "cond_channels" :cond_channels,
            "model_channels":model_channels, "num_heads":num_heads, "mlp_ratio": mlp_ratio, "share_mod": share_mod,
            "qk_rms_norm": qk_rms_norm, "qk_rms_norm_cross": qk_rms_norm_cross, "device": device, "dtype": dtype, "operations": operations
        }
        # TODO: update the names/checkpoints
        # `**args` forwards the dict as keyword arguments; `*args` would pass
        # its keys (strings) positionally.
        self.img2shape = SLatFlowModel(resolution, in_channels=in_channels, **args)
        self.shape2txt = SLatFlowModel(resolution, in_channels=in_channels*2, **args)
        self.shape_generation = True

    def forward(self, x, timestep, context):
        # Not implemented yet: currently returns None.
        pass
class VarLenTensor:
    """
    A batch of variable-length sequences stored as one packed tensor.

    ``feats`` holds all sequences concatenated along dim 0; ``layout`` is a
    list of slices, one per sequence, selecting that sequence's rows in
    ``feats``. ``shape`` reports the batch size followed by the trailing
    feature dimensions.
    """

    def __init__(self, feats: torch.Tensor, layout: Optional[List[slice]] = None):
        self.feats = feats
        # Without an explicit layout, treat feats as a single sequence.
        self.layout = layout if layout is not None else [slice(0, feats.shape[0])]
        self._cache = {}

    @staticmethod
    def layout_from_seqlen(seqlen: list) -> List[slice]:
        """
        Create a layout from a list of sequence lengths.
        """
        layout = []
        start = 0
        for length in seqlen:
            layout.append(slice(start, start + length))
            start += length
        return layout

    @staticmethod
    def from_tensor_list(tensor_list: List[torch.Tensor]) -> 'VarLenTensor':
        """
        Create a VarLenTensor from a list of tensors.
        """
        feats = torch.cat(tensor_list, dim=0)
        layout = VarLenTensor.layout_from_seqlen([tensor.shape[0] for tensor in tensor_list])
        return VarLenTensor(feats, layout)

    def __len__(self) -> int:
        return len(self.layout)

    @property
    def shape(self) -> torch.Size:
        return torch.Size([len(self.layout), *self.feats.shape[1:]])

    def dim(self) -> int:
        return len(self.shape)

    @property
    def ndim(self) -> int:
        return self.dim()

    @property
    def dtype(self):
        return self.feats.dtype

    @property
    def device(self):
        return self.feats.device

    @property
    def seqlen(self) -> torch.LongTensor:
        """Length of every sequence, as a cached long tensor."""
        if 'seqlen' not in self._cache:
            self._cache['seqlen'] = torch.tensor([l.stop - l.start for l in self.layout], dtype=torch.long, device=self.device)
        return self._cache['seqlen']

    @property
    def cum_seqlen(self) -> torch.LongTensor:
        """Cumulative sequence lengths with a leading zero (cached)."""
        if 'cum_seqlen' not in self._cache:
            self._cache['cum_seqlen'] = torch.cat([
                torch.tensor([0], dtype=torch.long, device=self.device),
                self.seqlen.cumsum(dim=0)
            ], dim=0)
        return self._cache['cum_seqlen']

    @property
    def batch_boardcast_map(self) -> torch.LongTensor:
        """
        Get the broadcast map for the varlen tensor.

        Entry ``i`` is the batch index of row ``i`` of ``feats``.
        """
        if 'batch_boardcast_map' not in self._cache:
            self._cache['batch_boardcast_map'] = torch.repeat_interleave(
                torch.arange(len(self.layout), device=self.device),
                self.seqlen,
            )
        return self._cache['batch_boardcast_map']

    @overload
    def to(self, dtype: torch.dtype, *, non_blocking: bool = False, copy: bool = False) -> 'VarLenTensor': ...

    @overload
    def to(self, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None, *, non_blocking: bool = False, copy: bool = False) -> 'VarLenTensor': ...

    def to(self, *args, **kwargs) -> 'VarLenTensor':
        """Move/cast ``feats``; accepts ``(dtype)``, ``(device)`` or ``(device, dtype)``."""
        device = None
        dtype = None
        if len(args) == 2:
            device, dtype = args
        elif len(args) == 1:
            if isinstance(args[0], torch.dtype):
                dtype = args[0]
            else:
                device = args[0]
        if 'dtype' in kwargs:
            assert dtype is None, "to() received multiple values for argument 'dtype'"
            dtype = kwargs['dtype']
        if 'device' in kwargs:
            assert device is None, "to() received multiple values for argument 'device'"
            device = kwargs['device']
        non_blocking = kwargs.get('non_blocking', False)
        copy = kwargs.get('copy', False)

        new_feats = self.feats.to(device=device, dtype=dtype, non_blocking=non_blocking, copy=copy)
        return self.replace(new_feats)

    def type(self, dtype):
        new_feats = self.feats.type(dtype)
        return self.replace(new_feats)

    def cpu(self) -> 'VarLenTensor':
        new_feats = self.feats.cpu()
        return self.replace(new_feats)

    def cuda(self) -> 'VarLenTensor':
        new_feats = self.feats.cuda()
        return self.replace(new_feats)

    def half(self) -> 'VarLenTensor':
        new_feats = self.feats.half()
        return self.replace(new_feats)

    def float(self) -> 'VarLenTensor':
        new_feats = self.feats.float()
        return self.replace(new_feats)

    def detach(self) -> 'VarLenTensor':
        new_feats = self.feats.detach()
        return self.replace(new_feats)

    def reshape(self, *shape) -> 'VarLenTensor':
        """Reshape the trailing dimensions of ``feats``; dim 0 is preserved."""
        new_feats = self.feats.reshape(self.feats.shape[0], *shape)
        return self.replace(new_feats)

    def unbind(self, dim: int) -> List['VarLenTensor']:
        return varlen_unbind(self, dim)

    def replace(self, feats: torch.Tensor) -> 'VarLenTensor':
        """Return a new tensor with ``feats`` swapped in, sharing layout and cache."""
        new_tensor = VarLenTensor(
            feats=feats,
            layout=self.layout,
        )
        new_tensor._cache = self._cache
        return new_tensor

    def to_dense(self, max_length=None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Pad the sequences into a dense batch.

        Returns:
            ``(dense, mask)`` where ``dense`` has shape ``(N, L, *feature_dims)``
            and ``mask`` is an ``(N, L)`` boolean tensor that is True on
            positions holding real data. ``L`` is ``max_length`` if given
            (and truthy), otherwise the longest sequence length.
        """
        N = len(self)
        L = max_length or self.seqlen.max().item()
        spatial = self.feats.shape[1:]
        idx = torch.arange(L, device=self.device).unsqueeze(0).expand(N, L)
        mask = (idx < self.seqlen.unsqueeze(1))
        # Running count of valid positions gives each valid slot its row in
        # feats; padded slots repeat the most recent valid row.
        mapping = mask.reshape(-1).cumsum(dim=0) - 1
        dense = self.feats[mapping]
        dense = dense.reshape(N, L, *spatial)
        return dense, mask

    def __neg__(self) -> 'VarLenTensor':
        return self.replace(-self.feats)

    def __elemwise__(self, other: Union[torch.Tensor, 'VarLenTensor'], op: callable) -> 'VarLenTensor':
        """
        Apply a binary ``op`` between ``feats`` and ``other``.

        A plain tensor that broadcasts to ``self.shape`` is treated as
        per-batch data and expanded to one row per packed element; anything
        else is handed to ``op`` unchanged.
        """
        if isinstance(other, torch.Tensor):
            try:
                other = torch.broadcast_to(other, self.shape)
                other = other[self.batch_boardcast_map]
            except (RuntimeError, IndexError):
                # Not per-batch shaped: fall back to plain element-wise
                # broadcasting against feats.
                pass
        if isinstance(other, VarLenTensor):
            other = other.feats
        new_feats = op(self.feats, other)
        new_tensor = self.replace(new_feats)
        return new_tensor

    def __add__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor':
        return self.__elemwise__(other, torch.add)

    def __radd__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor':
        return self.__elemwise__(other, torch.add)

    def __sub__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor':
        return self.__elemwise__(other, torch.sub)

    def __rsub__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor':
        return self.__elemwise__(other, lambda x, y: torch.sub(y, x))

    def __mul__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor':
        return self.__elemwise__(other, torch.mul)

    def __rmul__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor':
        return self.__elemwise__(other, torch.mul)

    def __truediv__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor':
        return self.__elemwise__(other, torch.div)

    def __rtruediv__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor':
        return self.__elemwise__(other, lambda x, y: torch.div(y, x))

    def __getitem__(self, idx):
        """
        Select sequences by batch index.

        Accepts an int, a slice, a list of ints, a 1-D integer tensor or a
        boolean mask over the batch. Returns a new VarLenTensor; an empty
        selection yields an empty one.
        """
        if isinstance(idx, int):
            idx = [idx]
        elif isinstance(idx, slice):
            idx = range(*idx.indices(self.shape[0]))
        elif isinstance(idx, list):
            assert all(isinstance(i, int) for i in idx), f"Only integer indices are supported: {idx}"
        elif isinstance(idx, torch.Tensor):
            if idx.dtype == torch.bool:
                assert idx.shape == (self.shape[0],), f"Invalid index shape: {idx.shape}"
                idx = idx.nonzero().squeeze(1)
            elif idx.dtype in [torch.int32, torch.int64]:
                assert len(idx.shape) == 1, f"Invalid index shape: {idx.shape}"
            else:
                raise ValueError(f"Unknown index type: {idx.dtype}")
        else:
            raise ValueError(f"Unknown index type: {type(idx)}")

        new_feats = []
        new_layout = []
        start = 0
        for old_idx in idx:
            new_feats.append(self.feats[self.layout[old_idx]])
            new_layout.append(slice(start, start + len(new_feats[-1])))
            start += len(new_feats[-1])
        if not new_feats:
            # torch.cat rejects an empty list; return an empty batch instead.
            return VarLenTensor(feats=self.feats[:0], layout=[])
        new_feats = torch.cat(new_feats, dim=0).contiguous()
        new_tensor = VarLenTensor(feats=new_feats, layout=new_layout)
        return new_tensor

    def reduce(self, op: str, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor:
        """
        Reduce ``feats`` with ``op`` (``'mean'``, ``'sum'`` or ``'prod'``).

        If ``dim`` is None or contains 0, the reduction over ``feats`` is
        returned directly. Otherwise the result is further reduced per
        sequence with ``torch.segment_reduce``.

        Raises:
            ValueError: If ``op`` is not supported.
        """
        if isinstance(dim, int):
            dim = (dim,)

        if op =='mean':
            red = self.feats.mean(dim=dim, keepdim=keepdim)
        elif op =='sum':
            red = self.feats.sum(dim=dim, keepdim=keepdim)
        elif op == 'prod':
            red = self.feats.prod(dim=dim, keepdim=keepdim)
        else:
            raise ValueError(f"Unsupported reduce operation: {op}")

        if dim is None or 0 in dim:
            return red

        red = torch.segment_reduce(red, reduce=op, lengths=self.seqlen)
        return red

    def mean(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor:
        return self.reduce(op='mean', dim=dim, keepdim=keepdim)

    def sum(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor:
        return self.reduce(op='sum', dim=dim, keepdim=keepdim)

    def prod(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor:
        return self.reduce(op='prod', dim=dim, keepdim=keepdim)

    def std(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor:
        """
        Standard deviation computed as ``sqrt(E[x^2] - E[x]^2)``.

        NOTE(review): both means are taken with ``keepdim=True``; the
        ``keepdim`` argument is currently not consulted.
        """
        mean = self.mean(dim=dim, keepdim=True)
        mean2 = self.replace(self.feats ** 2).mean(dim=dim, keepdim=True)
        std = (mean2 - mean ** 2).sqrt()
        return std

    def __repr__(self) -> str:
        return f"VarLenTensor(shape={self.shape}, dtype={self.dtype}, device={self.device})"
def varlen_unbind(input: VarLenTensor, dim: int) -> Union[List[VarLenTensor]]:
    """
    Split a VarLenTensor along ``dim``.

    For ``dim == 0`` the result holds one VarLenTensor per sequence in the
    batch. For any other ``dim`` the packed ``feats`` tensor is unbound along
    that dimension and each piece is wrapped with the input's layout.
    """
    if dim != 0:
        return [input.replace(piece) for piece in input.feats.unbind(dim)]
    batch_size = len(input)
    return [input[index] for index in range(batch_size)]
class SparseTensor(VarLenTensor):
    """
    Batched sparse tensor: per-point features plus integer coordinates.

    Column 0 of ``coords`` is the batch index and the remaining columns are
    spatial coordinates. Storage depends on ``config.CONV``: a torchsparse or
    spconv tensor, or otherwise a plain dict with ``'feats'`` and ``'coords'``.
    Derived values (layout, sequence lengths, spatial shape) are kept in a
    spatial cache keyed by the current scale.
    """

    SparseTensorData = None

    @overload
    def __init__(self, feats: torch.Tensor, coords: torch.Tensor, shape: Optional[torch.Size] = None, **kwargs): ...

    @overload
    def __init__(self, data, shape: Optional[torch.Size] = None, **kwargs): ...

    def __init__(self, *args, **kwargs):
        # Lazy import of sparse tensor backend
        if self.SparseTensorData is None:
            import importlib
            if config.CONV == 'torchsparse':
                self.SparseTensorData = importlib.import_module('torchsparse').SparseTensor
            elif config.CONV == 'spconv':
                self.SparseTensorData = importlib.import_module('spconv.pytorch').SparseConvTensor

        # Overload 0: (feats, coords[, shape]); overload 1: (data[, shape]).
        method_id = 0
        if len(args) != 0:
            method_id = 0 if isinstance(args[0], torch.Tensor) else 1
        else:
            method_id = 1 if 'data' in kwargs else 0

        if method_id == 0:
            feats, coords, shape = args + (None,) * (3 - len(args))
            # Keyword values override positional ones and are removed so that
            # they are not forwarded to the backend constructor.
            feats = kwargs.pop('feats', feats)
            coords = kwargs.pop('coords', coords)
            shape = kwargs.pop('shape', shape)

            if config.CONV == 'torchsparse':
                self.data = self.SparseTensorData(feats, coords, **kwargs)
            elif config.CONV == 'spconv':
                spatial_shape = list(coords.max(0)[0] + 1)
                self.data = self.SparseTensorData(feats.reshape(feats.shape[0], -1), coords, spatial_shape[1:], spatial_shape[0], **kwargs)
                self.data._features = feats
            else:
                self.data = {
                    'feats': feats,
                    'coords': coords,
                }
        elif method_id == 1:
            data, shape = args + (None,) * (2 - len(args))
            data = kwargs.pop('data', data)
            shape = kwargs.pop('shape', shape)

            self.data = data

        self._shape = shape
        self._scale = kwargs.get('scale', (Fraction(1, 1), Fraction(1, 1), Fraction(1, 1)))
        self._spatial_cache = kwargs.get('spatial_cache', {})

    @staticmethod
    def from_tensor_list(feats_list: List[torch.Tensor], coords_list: List[torch.Tensor]) -> 'SparseTensor':
        """
        Create a SparseTensor from a list of tensors.

        The batch column of each coords tensor is overwritten with the
        tensor's position in the list.
        """
        feats = torch.cat(feats_list, dim=0)
        coords = []
        for i, coord in enumerate(coords_list):
            coord = torch.cat([torch.full_like(coord[:, :1], i), coord[:, 1:]], dim=1)
            coords.append(coord)
        coords = torch.cat(coords, dim=0)
        return SparseTensor(feats, coords)

    def to_tensor_list(self) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """
        Convert a SparseTensor to list of tensors.
        """
        feats_list = []
        coords_list = []
        for s in self.layout:
            feats_list.append(self.feats[s])
            coords_list.append(self.coords[s])
        return feats_list, coords_list

    def __len__(self) -> int:
        return len(self.layout)

    def __cal_shape(self, feats, coords):
        # Batch size is inferred from the largest batch index.
        shape = []
        shape.append(coords[:, 0].max().item() + 1)
        shape.extend([*feats.shape[1:]])
        return torch.Size(shape)

    def __cal_layout(self, coords, batch_size):
        # NOTE(review): this derives contiguous slices from per-batch counts,
        # which presumes rows are grouped by ascending batch index - confirm.
        seq_len = torch.bincount(coords[:, 0], minlength=batch_size)
        offset = torch.cumsum(seq_len, dim=0)
        layout = [slice((offset[i] - seq_len[i]).item(), offset[i].item()) for i in range(batch_size)]
        return layout

    def __cal_spatial_shape(self, coords):
        return torch.Size((coords[:, 1:].max(0)[0] + 1).tolist())

    @property
    def shape(self) -> torch.Size:
        if self._shape is None:
            self._shape = self.__cal_shape(self.feats, self.coords)
        return self._shape

    @property
    def layout(self) -> List[slice]:
        layout = self.get_spatial_cache('layout')
        if layout is None:
            layout = self.__cal_layout(self.coords, self.shape[0])
            self.register_spatial_cache('layout', layout)
        return layout

    @property
    def spatial_shape(self) -> torch.Size:
        spatial_shape = self.get_spatial_cache('shape')
        if spatial_shape is None:
            spatial_shape = self.__cal_spatial_shape(self.coords)
            self.register_spatial_cache('shape', spatial_shape)
        return spatial_shape

    @property
    def feats(self) -> torch.Tensor:
        if config.CONV == 'torchsparse':
            return self.data.F
        elif config.CONV == 'spconv':
            return self.data.features
        else:
            return self.data['feats']

    @feats.setter
    def feats(self, value: torch.Tensor):
        if config.CONV == 'torchsparse':
            self.data.F = value
        elif config.CONV == 'spconv':
            self.data.features = value
        else:
            self.data['feats'] = value

    @property
    def coords(self) -> torch.Tensor:
        if config.CONV == 'torchsparse':
            return self.data.C
        elif config.CONV == 'spconv':
            return self.data.indices
        else:
            return self.data['coords']

    @coords.setter
    def coords(self, value: torch.Tensor):
        if config.CONV == 'torchsparse':
            self.data.C = value
        elif config.CONV == 'spconv':
            self.data.indices = value
        else:
            self.data['coords'] = value

    @property
    def dtype(self):
        return self.feats.dtype

    @property
    def device(self):
        return self.feats.device

    @property
    def seqlen(self) -> torch.LongTensor:
        seqlen = self.get_spatial_cache('seqlen')
        if seqlen is None:
            seqlen = torch.tensor([l.stop - l.start for l in self.layout], dtype=torch.long, device=self.device)
            self.register_spatial_cache('seqlen', seqlen)
        return seqlen

    @property
    def cum_seqlen(self) -> torch.LongTensor:
        cum_seqlen = self.get_spatial_cache('cum_seqlen')
        if cum_seqlen is None:
            cum_seqlen = torch.cat([
                torch.tensor([0], dtype=torch.long, device=self.device),
                self.seqlen.cumsum(dim=0)
            ], dim=0)
            self.register_spatial_cache('cum_seqlen', cum_seqlen)
        return cum_seqlen

    @property
    def batch_boardcast_map(self) -> torch.LongTensor:
        """
        Get the broadcast map for the varlen tensor.
        """
        batch_boardcast_map = self.get_spatial_cache('batch_boardcast_map')
        if batch_boardcast_map is None:
            batch_boardcast_map = torch.repeat_interleave(
                torch.arange(len(self.layout), device=self.device),
                self.seqlen,
            )
            self.register_spatial_cache('batch_boardcast_map', batch_boardcast_map)
        return batch_boardcast_map

    @overload
    def to(self, dtype: torch.dtype, *, non_blocking: bool = False, copy: bool = False) -> 'SparseTensor': ...

    @overload
    def to(self, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None, *, non_blocking: bool = False, copy: bool = False) -> 'SparseTensor': ...

    def to(self, *args, **kwargs) -> 'SparseTensor':
        """Move/cast the tensor; ``dtype`` applies to feats only, ``device`` to feats and coords."""
        device = None
        dtype = None
        if len(args) == 2:
            device, dtype = args
        elif len(args) == 1:
            if isinstance(args[0], torch.dtype):
                dtype = args[0]
            else:
                device = args[0]
        if 'dtype' in kwargs:
            assert dtype is None, "to() received multiple values for argument 'dtype'"
            dtype = kwargs['dtype']
        if 'device' in kwargs:
            assert device is None, "to() received multiple values for argument 'device'"
            device = kwargs['device']
        non_blocking = kwargs.get('non_blocking', False)
        copy = kwargs.get('copy', False)

        new_feats = self.feats.to(device=device, dtype=dtype, non_blocking=non_blocking, copy=copy)
        new_coords = self.coords.to(device=device, non_blocking=non_blocking, copy=copy)
        return self.replace(new_feats, new_coords)

    def type(self, dtype):
        new_feats = self.feats.type(dtype)
        return self.replace(new_feats)

    def cpu(self) -> 'SparseTensor':
        new_feats = self.feats.cpu()
        new_coords = self.coords.cpu()
        return self.replace(new_feats, new_coords)

    def cuda(self) -> 'SparseTensor':
        new_feats = self.feats.cuda()
        new_coords = self.coords.cuda()
        return self.replace(new_feats, new_coords)

    def half(self) -> 'SparseTensor':
        new_feats = self.feats.half()
        return self.replace(new_feats)

    def float(self) -> 'SparseTensor':
        new_feats = self.feats.float()
        return self.replace(new_feats)

    def detach(self) -> 'SparseTensor':
        new_coords = self.coords.detach()
        new_feats = self.feats.detach()
        return self.replace(new_feats, new_coords)

    def reshape(self, *shape) -> 'SparseTensor':
        new_feats = self.feats.reshape(self.feats.shape[0], *shape)
        return self.replace(new_feats)

    def unbind(self, dim: int) -> List['SparseTensor']:
        return sparse_unbind(self, dim)

    def replace(self, feats: torch.Tensor, coords: Optional[torch.Tensor] = None) -> 'SparseTensor':
        """
        Return a new SparseTensor with ``feats`` (and optionally ``coords``)
        swapped in. Scale and spatial cache are shared with ``self``.
        """
        if config.CONV == 'torchsparse':
            new_data = self.SparseTensorData(
                feats=feats,
                coords=self.data.coords if coords is None else coords,
                stride=self.data.stride,
                spatial_range=self.data.spatial_range,
            )
            new_data._caches = self.data._caches
        elif config.CONV == 'spconv':
            new_data = self.SparseTensorData(
                self.data.features.reshape(self.data.features.shape[0], -1),
                self.data.indices,
                self.data.spatial_shape,
                self.data.batch_size,
                self.data.grid,
                self.data.voxel_num,
                self.data.indice_dict
            )
            new_data._features = feats
            new_data.benchmark = self.data.benchmark
            new_data.benchmark_record = self.data.benchmark_record
            new_data.thrust_allocator = self.data.thrust_allocator
            new_data._timer = self.data._timer
            new_data.force_algo = self.data.force_algo
            new_data.int8_scale = self.data.int8_scale
            if coords is not None:
                new_data.indices = coords
        else:
            new_data = {
                'feats': feats,
                'coords': self.data['coords'] if coords is None else coords,
            }
        new_tensor = SparseTensor(
            new_data,
            shape=torch.Size([self._shape[0]] + list(feats.shape[1:])) if self._shape is not None else None,
            scale=self._scale,
            spatial_cache=self._spatial_cache
        )
        return new_tensor

    def to_dense(self) -> torch.Tensor:
        """
        Scatter the features into a dense tensor.

        With the dict backend the result has shape
        ``(*self.shape, *self.spatial_shape)`` and is zero wherever no point
        exists. Other backends delegate to their own ``dense()``.
        """
        if config.CONV == 'torchsparse':
            return self.data.dense()
        elif config.CONV == 'spconv':
            return self.data.dense()
        else:
            spatial_shape = self.spatial_shape
            ret = torch.zeros(*self.shape, *spatial_shape, dtype=self.dtype, device=self.device)
            # unbind() returns a tuple, so convert before concatenating with
            # the list prefix (list + tuple raises TypeError).
            idx = [self.coords[:, 0], slice(None)] + list(self.coords[:, 1:].unbind(1))
            ret[tuple(idx)] = self.feats
            return ret

    @staticmethod
    def full(aabb, dim, value, dtype=torch.float32, device=None) -> 'SparseTensor':
        """
        Build a SparseTensor that fills an axis-aligned box with ``value``.

        Args:
            aabb: Six integers ``(x0, y0, z0, x1, y1, z1)``; both ends are
                inclusive.
            dim: ``(N, C)`` - batch size and channel count.
            value: Fill value for every feature.
        """
        N, C = dim
        x = torch.arange(aabb[0], aabb[3] + 1)
        y = torch.arange(aabb[1], aabb[4] + 1)
        z = torch.arange(aabb[2], aabb[5] + 1)
        coords = torch.stack(torch.meshgrid(x, y, z, indexing='ij'), dim=-1).reshape(-1, 3)
        coords = torch.cat([
            torch.arange(N).view(-1, 1).repeat(1, coords.shape[0]).view(-1, 1),
            coords.repeat(N, 1),
        ], dim=1).to(dtype=torch.int32, device=device)
        feats = torch.full((coords.shape[0], C), value, dtype=dtype, device=device)
        return SparseTensor(feats=feats, coords=coords)

    def __merge_sparse_cache(self, other: 'SparseTensor') -> dict:
        """
        Merge the spatial caches of ``self`` and ``other`` into a new dict.

        Per-scale entries are copied, so neither operand's cache is modified;
        on key collisions the entry from ``other`` wins.
        """
        new_cache = {}
        for k in set(list(self._spatial_cache.keys()) + list(other._spatial_cache.keys())):
            if k in self._spatial_cache:
                new_cache[k] = dict(self._spatial_cache[k])
            if k in other._spatial_cache:
                if k not in new_cache:
                    new_cache[k] = dict(other._spatial_cache[k])
                else:
                    new_cache[k].update(other._spatial_cache[k])
        return new_cache

    def __elemwise__(self, other: Union[torch.Tensor, VarLenTensor], op: callable) -> 'SparseTensor':
        """
        Apply a binary ``op`` between ``feats`` and ``other``.

        A plain tensor that broadcasts to ``self.shape`` is expanded to one
        row per point. When ``other`` is a SparseTensor, the result carries
        the merged spatial caches of both operands.
        """
        # Remember the operand before it is unwrapped to its feats below;
        # checking isinstance afterwards could never succeed.
        sparse_other = other if isinstance(other, SparseTensor) else None
        if isinstance(other, torch.Tensor):
            try:
                other = torch.broadcast_to(other, self.shape)
                other = other[self.batch_boardcast_map]
            except (RuntimeError, IndexError):
                # Not per-batch shaped: fall back to plain broadcasting.
                pass
        if isinstance(other, VarLenTensor):
            other = other.feats
        new_feats = op(self.feats, other)
        new_tensor = self.replace(new_feats)
        if sparse_other is not None:
            new_tensor._spatial_cache = self.__merge_sparse_cache(sparse_other)
        return new_tensor

    def __getitem__(self, idx):
        """
        Select batch entries by int, slice, list of ints, 1-D integer tensor
        or boolean mask. Batch indices in the result are renumbered from 0.
        """
        if isinstance(idx, int):
            idx = [idx]
        elif isinstance(idx, slice):
            idx = range(*idx.indices(self.shape[0]))
        elif isinstance(idx, list):
            assert all(isinstance(i, int) for i in idx), f"Only integer indices are supported: {idx}"
        elif isinstance(idx, torch.Tensor):
            if idx.dtype == torch.bool:
                assert idx.shape == (self.shape[0],), f"Invalid index shape: {idx.shape}"
                idx = idx.nonzero().squeeze(1)
            elif idx.dtype in [torch.int32, torch.int64]:
                assert len(idx.shape) == 1, f"Invalid index shape: {idx.shape}"
            else:
                raise ValueError(f"Unknown index type: {idx.dtype}")
        else:
            raise ValueError(f"Unknown index type: {type(idx)}")

        new_coords = []
        new_feats = []
        new_layout = []
        new_shape = torch.Size([len(idx)] + list(self.shape[1:]))
        start = 0
        for new_idx, old_idx in enumerate(idx):
            new_coords.append(self.coords[self.layout[old_idx]].clone())
            new_coords[-1][:, 0] = new_idx
            new_feats.append(self.feats[self.layout[old_idx]])
            new_layout.append(slice(start, start + len(new_coords[-1])))
            start += len(new_coords[-1])
        new_coords = torch.cat(new_coords, dim=0).contiguous()
        new_feats = torch.cat(new_feats, dim=0).contiguous()
        new_tensor = SparseTensor(feats=new_feats, coords=new_coords, shape=new_shape)
        new_tensor.register_spatial_cache('layout', new_layout)
        return new_tensor

    def clear_spatial_cache(self) -> None:
        """
        Clear all spatial caches.
        """
        self._spatial_cache = {}

    def register_spatial_cache(self, key, value) -> None:
        """
        Register a spatial cache.
        The spatial cache can be anything you want to cache.
        Registration and retrieval of the cache are based on the current scale.
        """
        scale_key = str(self._scale)
        if scale_key not in self._spatial_cache:
            self._spatial_cache[scale_key] = {}
        self._spatial_cache[scale_key][key] = value

    def get_spatial_cache(self, key=None):
        """
        Get a spatial cache.

        Returns the whole cache dict of the current scale when ``key`` is
        None, otherwise the cached value or None if absent.
        """
        scale_key = str(self._scale)
        cur_scale_cache = self._spatial_cache.get(scale_key, {})
        if key is None:
            return cur_scale_cache
        return cur_scale_cache.get(key, None)

    def __repr__(self) -> str:
        return f"SparseTensor(shape={self.shape}, dtype={self.dtype}, device={self.device})"
def sparse_cat(inputs: List[SparseTensor], dim: int = 0) -> SparseTensor:
    """
    Concatenate sparse tensors.

    With ``dim == 0`` the tensors are stacked along the batch axis: batch
    indices of each input are offset by the batch sizes before it. For any
    other ``dim`` only the features are concatenated and the coordinates of
    the first input are reused.
    """
    if dim != 0:
        merged_feats = torch.cat([tensor.feats for tensor in inputs], dim=dim)
        return inputs[0].replace(merged_feats)

    batch_offset = 0
    shifted_coords = []
    for tensor in inputs:
        shifted = tensor.coords.clone()
        shifted[:, 0] += batch_offset
        shifted_coords.append(shifted)
        batch_offset += tensor.shape[0]
    coords = torch.cat(shifted_coords, dim=0)
    feats = torch.cat([tensor.feats for tensor in inputs], dim=0)
    return SparseTensor(
        coords=coords,
        feats=feats,
    )


def sparse_unbind(input: SparseTensor, dim: int) -> List[SparseTensor]:
    """
    Split a SparseTensor along ``dim``: per batch entry for ``dim == 0``,
    otherwise by unbinding the features along that dimension.
    """
    if dim != 0:
        return [input.replace(piece) for piece in input.feats.unbind(dim)]
    batch_size = input.shape[0]
    return [input[index] for index in range(batch_size)]

class SparseLinear(nn.Linear):
    """Linear layer applied to the ``feats`` of a VarLenTensor."""
    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias)

    def forward(self, input: VarLenTensor) -> VarLenTensor:
        projected = super().forward(input.feats)
        return input.replace(projected)
class SparseUnetVaeEncoder(nn.Module):
    """
    Sparse Swin Transformer Unet VAE model.

    Encoder half: a linear input layer, one stage of blocks per entry of
    ``num_blocks`` (every stage but the last ends with a down block), and a
    linear head producing mean and log-variance of the latent.
    """
    def __init__(
        self,
        in_channels: int,
        model_channels: List[int],
        latent_channels: int,
        num_blocks: List[int],
        block_type: List[str],
        down_block_type: List[str],
        block_args: List[Dict[str, Any]],
        use_fp16: bool = False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.num_blocks = num_blocks
        self.dtype = torch.float16 if use_fp16 else torch.float32

        self.input_layer = SparseLinear(in_channels, model_channels[0])
        # Twice the latent width: one half for the mean, one for the logvar.
        self.to_latent = SparseLinear(model_channels[-1], 2 * latent_channels)

        self.blocks = nn.ModuleList([])
        last_stage = len(num_blocks) - 1
        for i, depth in enumerate(num_blocks):
            stage = nn.ModuleList([])
            self.blocks.append(stage)
            # Block classes are looked up by name in this module's globals.
            block_cls = globals()[block_type[i]] if depth > 0 else None
            for _ in range(depth):
                stage.append(block_cls(model_channels[i], **block_args[i]))
            if i < last_stage:
                down_cls = globals()[down_block_type[i]]
                stage.append(
                    down_cls(
                        model_channels[i],
                        model_channels[i+1],
                        **block_args[i],
                    )
                )

    @property
    def device(self) -> torch.device:
        """Device of the first parameter of the model."""
        return next(self.parameters()).device

    def forward(self, x: SparseTensor, sample_posterior=False, return_raw=False):
        """
        Encode ``x`` into a latent sparse tensor.

        Args:
            x: Input sparse tensor.
            sample_posterior: If True, sample ``mean + std * noise``;
                otherwise use the mean.
            return_raw: If True, return ``(z, mean, logvar)`` instead of ``z``.
        """
        hidden = self.input_layer(x).type(self.dtype)
        for stage in self.blocks:
            for layer in stage:
                hidden = layer(hidden)
        hidden = hidden.type(x.dtype)
        hidden = hidden.replace(F.layer_norm(hidden.feats, hidden.feats.shape[-1:]))
        hidden = self.to_latent(hidden)

        # Sample from the posterior distribution
        mean, logvar = hidden.feats.chunk(2, dim=-1)
        if sample_posterior:
            std = torch.exp(0.5 * logvar)
            latent = mean + std * torch.randn_like(std)
        else:
            latent = mean
        z = hidden.replace(latent)

        if return_raw:
            return z, mean, logvar
        return z
class SparseUnetVaeDecoder(nn.Module):
    """
    Decoder half of the sparse U-Net VAE.

    The latent is projected to ``model_channels[0]``, passed through the
    configured stages (every stage except the last ends with an up block),
    layer-normalised and projected to ``out_channels``.
    """
    def __init__(
        self,
        out_channels: int,
        model_channels: List[int],
        latent_channels: int,
        num_blocks: List[int],
        block_type: List[str],
        up_block_type: List[str],
        block_args: List[Dict[str, Any]],
        use_fp16: bool = False,
        pred_subdiv: bool = True,
    ):
        """
        Args:
            out_channels: feature width of the decoded sparse tensor.
            model_channels: feature width of every stage.
            latent_channels: feature width of the incoming latent.
            num_blocks: number of regular blocks per stage.
            block_type: per-stage class name, resolved through ``globals()``.
            up_block_type: per-stage class name of the block appended to
                every stage except the last.
            block_args: per-stage keyword arguments forwarded to the blocks.
            use_fp16: run the stages in float16 instead of float32.
            pred_subdiv: forwarded to the up blocks; when true ``forward``
                expects them to return ``(features, subdivision)``.
        """
        super().__init__()
        self.out_channels = out_channels
        self.model_channels = model_channels
        self.num_blocks = num_blocks
        self.use_fp16 = use_fp16
        self.pred_subdiv = pred_subdiv
        self.dtype = torch.float16 if use_fp16 else torch.float32
        self.low_vram = False

        self.output_layer = SparseLinear(model_channels[-1], out_channels)
        self.from_latent = SparseLinear(latent_channels, model_channels[0])

        # Same registration order as before (blocks.<stage>.<index>), so
        # state-dict keys are unchanged.
        self.blocks = nn.ModuleList([])
        last_stage = len(num_blocks) - 1
        for i, depth in enumerate(num_blocks):
            stage = nn.ModuleList([
                globals()[block_type[i]](
                    model_channels[i],
                    **block_args[i],
                )
                for _ in range(depth)
            ])
            if i < last_stage:
                stage.append(
                    globals()[up_block_type[i]](
                        model_channels[i],
                        model_channels[i + 1],
                        pred_subdiv=pred_subdiv,
                        **block_args[i],
                    )
                )
            self.blocks.append(stage)

    @property
    def device(self) -> torch.device:
        """Device of the first parameter of the module."""
        return next(self.parameters()).device

    def forward(self, x: SparseTensor, guide_subs: Optional[List[SparseTensor]] = None, return_subs: bool = False) -> SparseTensor:
        """
        Decode the latent ``x``.

        Args:
            x: latent sparse tensor.
            guide_subs: per-stage subdivisions handed to the up blocks when
                the decoder was built with ``pred_subdiv=False``.
            return_subs: also return the subdivisions collected from the up
                blocks (empty unless ``pred_subdiv`` is true).

        Returns:
            The decoded sparse tensor, or ``(decoded, subs)`` when
            ``return_subs`` is true.
        """
        h = self.from_latent(x)
        h = h.type(self.dtype)
        subs = []
        last_stage = len(self.blocks) - 1
        for i, stage in enumerate(self.blocks):
            for j, block in enumerate(stage):
                # The last module of every stage but the final one is the up block.
                is_up_block = i < last_stage and j == len(stage) - 1
                if not is_up_block:
                    h = block(h)
                elif self.pred_subdiv:
                    h, sub = block(h)
                    subs.append(sub)
                else:
                    h = block(h, subdiv=guide_subs[i] if guide_subs is not None else None)
        # Back to the caller's dtype before normalisation and the output head.
        h = h.type(x.dtype)
        h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
        h = self.output_layer(h)
        if return_subs:
            return h, subs
        return h

    def upsample(self, x: SparseTensor, upsample_times: int) -> torch.Tensor:
        """
        Return the coordinates obtained after running the first
        ``upsample_times`` stages (each ending with its up block).

        Raises:
            ValueError: if ``upsample_times`` does not address a stage
                (previously the method fell off the loop and returned
                ``None``), or if stages have to be run on a decoder built
                with ``pred_subdiv=False``, whose up blocks are called
                differently in ``forward``.
        """
        if not 0 <= upsample_times < len(self.blocks):
            raise ValueError(
                f"upsample_times must be in [0, {len(self.blocks) - 1}], got {upsample_times}"
            )
        if upsample_times > 0 and not self.pred_subdiv:
            raise ValueError("upsample() requires a decoder built with pred_subdiv=True")

        h = self.from_latent(x)
        h = h.type(self.dtype)
        # Every stage visited here precedes the final one, so it ends with an up block.
        for stage in self.blocks[:upsample_times]:
            *body, up_block = stage
            for block in body:
                h = block(h)
            h, _ = up_block(h)
        return h.coords
def _dual_grid_cached_constant(name, values, dtype, device):
    """Return lookup table ``name`` as a tensor on ``device``.

    The table is cached as an attribute of ``flexible_dual_grid_to_mesh`` and
    rebuilt whenever the cached copy lives on a different device.
    """
    cached = getattr(flexible_dual_grid_to_mesh, name, None)
    if cached is None or cached.device != device:
        cached = torch.tensor(values, dtype=dtype, device=device)
        setattr(flexible_dual_grid_to_mesh, name, cached)
    return cached


def _dual_grid_flat_keys(xyz, H, D):
    """Flatten integer (x, y, z) rows into row-major keys (batch index 0)."""
    xyz = xyz.long()
    return xyz[:, 0] * (H * D) + xyz[:, 1] * D + xyz[:, 2]


def _dual_grid_split_alignment(vertices, tris):
    """Absolute dot product of the two normals used to score a quad split.

    NOTE(review): the second normal is built from flattened columns 1..3 of
    ``tris`` (not 3..5), exactly as the code did before — confirm against the
    reference implementation.
    """
    n0 = torch.cross(
        vertices[tris[:, 1]] - vertices[tris[:, 0]],
        vertices[tris[:, 2]] - vertices[tris[:, 0]],
        dim=1,
    )
    n1 = torch.cross(
        vertices[tris[:, 2]] - vertices[tris[:, 1]],
        vertices[tris[:, 3]] - vertices[tris[:, 1]],
        dim=1,
    )
    return (n0 * n1).sum(dim=1, keepdim=True).abs()


def flexible_dual_grid_to_mesh(
    coords: torch.Tensor,
    dual_vertices: torch.Tensor,
    intersected_flag: torch.Tensor,
    split_weight: Union[torch.Tensor, None],
    aabb: Union[list, tuple, np.ndarray, torch.Tensor],
    voxel_size: Union[float, list, tuple, np.ndarray, torch.Tensor] = None,
    grid_size: Union[int, list, tuple, np.ndarray, torch.Tensor] = None,
    train: bool = False,
    hashmap_builder=None,  # optional callable for building/caching a TorchHashMap
):
    """
    Turn a sparse dual grid into a triangle mesh.

    Every voxel contributes one vertex. For every flagged (voxel, axis) pair
    the four voxels given by the edge-neighbour offset table are looked up;
    when all four exist they form a quad, which is split into two triangles.

    Args:
        coords: (N, 3) integer voxel coordinates.
        dual_vertices: (N, 3) offsets added to ``coords`` before scaling.
        intersected_flag: (N, 3) boolean mask, one flag per axis.
        split_weight: per-voxel weights choosing the quad diagonal, or
            ``None`` to choose it from the triangle normals.
        aabb: ``[min_corner, max_corner]`` of the grid.
        voxel_size: edge length(s) of a voxel; takes precedence over
            ``grid_size``.
        grid_size: number of voxels per axis, used when ``voxel_size`` is
            ``None``.
        train: accepted for interface compatibility; not read here.
        hashmap_builder: optional ``callable(coords, grid_size)`` returning an
            object with ``lookup_flat``; a ``TorchHashMap`` is built
            otherwise.

    Returns:
        ``(mesh_vertices, mesh_triangles)``: float vertices of shape (N, 3)
        and int32 vertex indices of shape (T, 3).

    Raises:
        ValueError: if neither ``voxel_size`` nor ``grid_size`` is given.
    """
    device = coords.device
    invalid_index = 0xffffffff  # value stored in the hash map for "no voxel here"

    edge_offsets = _dual_grid_cached_constant(
        "edge_neighbor_voxel_offset",
        [[
            [[0, 0, 0], [0, 0, 1], [0, 1, 1], [0, 1, 0]],  # x-axis
            [[0, 0, 0], [1, 0, 0], [1, 0, 1], [0, 0, 1]],  # y-axis
            [[0, 0, 0], [0, 1, 0], [1, 1, 0], [1, 0, 0]],  # z-axis
        ]],
        torch.int,
        device,
    )  # (1, 3, 4, 3)
    quad_split_1 = _dual_grid_cached_constant("quad_split_1", [0, 1, 2, 0, 2, 3], torch.long, device)
    quad_split_2 = _dual_grid_cached_constant("quad_split_2", [0, 1, 3, 3, 1, 2], torch.long, device)

    # AABB
    if isinstance(aabb, (list, tuple)):
        aabb = np.array(aabb)
    if isinstance(aabb, np.ndarray):
        aabb = torch.tensor(aabb, dtype=torch.float32, device=device)

    # Voxel size / grid size: derive whichever one was not supplied.
    if voxel_size is not None:
        if isinstance(voxel_size, (int, float)):
            voxel_size = [voxel_size, voxel_size, voxel_size]
        if isinstance(voxel_size, (list, tuple)):
            voxel_size = np.array(voxel_size)
        if isinstance(voxel_size, np.ndarray):
            voxel_size = torch.tensor(voxel_size, dtype=torch.float32, device=device)
        grid_size = ((aabb[1] - aabb[0]) / voxel_size).round().int()
    elif grid_size is not None:
        if isinstance(grid_size, int):
            grid_size = [grid_size, grid_size, grid_size]
        if isinstance(grid_size, (list, tuple)):
            grid_size = np.array(grid_size)
        if isinstance(grid_size, np.ndarray):
            grid_size = torch.tensor(grid_size, dtype=torch.int32, device=device)
        voxel_size = (aabb[1] - aabb[0]) / grid_size
    else:
        raise ValueError("either voxel_size or grid_size must be provided")

    N = dual_vertices.shape[0]
    W, H, D = (int(v) for v in grid_size.tolist())

    if hashmap_builder is None:
        # build local TorchHashMap
        torch_hashmap = TorchHashMap(
            _dual_grid_flat_keys(coords, H, D),
            torch.arange(N, dtype=torch.long, device=device),
            invalid_index,
        )
    else:
        torch_hashmap = hashmap_builder(coords, grid_size)

    # Find connected voxels
    edge_neighbor_voxel = coords.reshape(N, 1, 1, 3) + edge_offsets  # (N, 3, 4, 3)
    connected_voxel = edge_neighbor_voxel[intersected_flag].reshape(-1, 3).long()  # (M * 4, 3)
    M = connected_voxel.shape[0] // 4

    # A neighbour outside the grid must not be flattened: its key would alias
    # another cell (z == D maps onto (y + 1, z == 0)), fabricating a quad.
    upper = torch.tensor([W, H, D], dtype=torch.long, device=device)
    in_grid = ((connected_voxel >= 0) & (connected_voxel < upper)).all(dim=1)
    conn_indices = torch.full((M * 4,), invalid_index, dtype=torch.long, device=device)
    conn_indices[in_grid] = torch_hashmap.lookup_flat(
        _dual_grid_flat_keys(connected_voxel[in_grid], H, D)
    ).long()
    conn_indices = conn_indices.reshape(M, 4)

    # Validity is tested on the 64-bit values: 0xffffffff does not fit in
    # int32, so narrowing first and comparing afterwards is not reliable.
    connected_voxel_valid = ((conn_indices >= 0) & (conn_indices < N)).all(dim=1)
    quad_indices = conn_indices[connected_voxel_valid]  # (L, 4)

    mesh_vertices = (coords.float() + dual_vertices) * voxel_size + aabb[0].reshape(1, 3)

    triangles_1 = quad_indices[:, quad_split_1]  # (L, 6)
    triangles_2 = quad_indices[:, quad_split_2]  # (L, 6)
    if split_weight is None:
        use_split_1 = (
            _dual_grid_split_alignment(mesh_vertices, triangles_1)
            > _dual_grid_split_alignment(mesh_vertices, triangles_2)
        )
    else:
        split_weight_ws = split_weight[quad_indices]
        use_split_1 = (
            split_weight_ws[:, 0] * split_weight_ws[:, 2]
            > split_weight_ws[:, 1] * split_weight_ws[:, 3]
        )
    # One decision per quad, broadcast over its six triangle corners.
    mesh_triangles = torch.where(use_split_1.reshape(-1, 1), triangles_1, triangles_2).reshape(-1, 3)

    return mesh_vertices, mesh_triangles.int()
coords = v.coords[:, 1:], + attrs = v.feats, + voxel_shape = torch.Size([*v.shape, *v.spatial_shape]), + layout=self.pbr_attr_layout + ) + ) + return out_mesh diff --git a/comfy_extras/trellis2.py b/comfy_extras/trellis2.py new file mode 100644 index 000000000..c3ad56007 --- /dev/null +++ b/comfy_extras/trellis2.py @@ -0,0 +1,240 @@ +from typing_extensions import override +from comfy_api.latest import ComfyExtension, IO +import torch +from comfy.ldm.trellis2.model import SparseTensor +import comfy.model_management +from PIL import Image +import PIL +import numpy as np + +shape_slat_normalization = { + "mean": torch.tensor([ + 0.781296, 0.018091, -0.495192, -0.558457, 1.060530, 0.093252, 1.518149, -0.933218, + -0.732996, 2.604095, -0.118341, -2.143904, 0.495076, -2.179512, -2.130751, -0.996944, + 0.261421, -2.217463, 1.260067, -0.150213, 3.790713, 1.481266, -1.046058, -1.523667, + -0.059621, 2.220780, 1.621212, 0.877230, 0.567247, -3.175944, -3.186688, 1.578665 + ])[None], + "std": torch.tensor([ + 5.972266, 4.706852, 5.445010, 5.209927, 5.320220, 4.547237, 5.020802, 5.444004, + 5.226681, 5.683095, 4.831436, 5.286469, 5.652043, 5.367606, 5.525084, 4.730578, + 4.805265, 5.124013, 5.530808, 5.619001, 5.103930, 5.417670, 5.269677, 5.547194, + 5.634698, 5.235274, 6.110351, 5.511298, 6.237273, 4.879207, 5.347008, 5.405691 + ])[None] +} + +tex_slat_normalization = { + "mean": torch.tensor([ + 3.501659, 2.212398, 2.226094, 0.251093, -0.026248, -0.687364, 0.439898, -0.928075, + 0.029398, -0.339596, -0.869527, 1.038479, -0.972385, 0.126042, -1.129303, 0.455149, + -1.209521, 2.069067, 0.544735, 2.569128, -0.323407, 2.293000, -1.925608, -1.217717, + 1.213905, 0.971588, -0.023631, 0.106750, 2.021786, 0.250524, -0.662387, -0.768862 + ])[None], + "std": torch.tensor([ + 2.665652, 2.743913, 2.765121, 2.595319, 3.037293, 2.291316, 2.144656, 2.911822, + 2.969419, 2.501689, 2.154811, 3.163343, 2.621215, 2.381943, 3.186697, 3.021588, + 2.295916, 3.234985, 3.233086, 2.260140, 2.874801, 
def smart_crop_square(
    image: torch.Tensor,
    background_color=(128, 128, 128),
):
    """
    Centre a ``(C, H, W)`` image on a square canvas of side ``max(H, W)``.

    Despite the name nothing is cropped: the shorter dimension is padded.
    Channel ``c`` of the padding is filled with ``background_color[c]``; the
    canvas keeps the dtype and device of ``image``.
    """
    channels, height, width = image.shape
    side = max(height, width)

    canvas = image.new_empty((channels, side, side))
    for channel in range(channels):
        canvas[channel].fill_(background_color[channel])

    # Paste the image into the centred window of the canvas.
    row0 = (side - height) // 2
    col0 = (side - width) // 2
    canvas.narrow(1, row0, height).narrow(2, col0, width).copy_(image)

    return canvas
class VaeDecodeShapeTrellis(IO.ComfyNode):
    """Node that decodes a shape latent into a mesh plus its subdivisions."""
    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="VaeDecodeShapeTrellis",
            category="latent/3d",
            inputs=[
                IO.Latent.Input("samples"),
                IO.Vae.Input("vae"),
                IO.Int.Input("resolution", tooltip="Shape Generation Resolution"),
            ],
            outputs=[
                IO.Mesh.Output("mesh"),
                IO.AnyType.Output("shape_subs"),
            ]
        )

    @classmethod
    def execute(cls, samples, vae, resolution):
        """
        De-normalise ``samples`` with ``shape_slat_normalization`` and decode
        them with ``vae.decode_shape_slat``.

        Returns:
            ``IO.NodeOutput(mesh, subs)`` matching the two declared outputs.
        """
        # NOTE(review): ``samples`` is used directly as the latent; if the
        # LATENT socket delivers a dict it needs unwrapping first — confirm.
        std = shape_slat_normalization["std"]
        mean = shape_slat_normalization["mean"]
        samples = samples * std + mean

        # decode_shape_slat takes (slat, resolution); the call previously
        # passed the two arguments in the opposite order.
        mesh, subs = vae.decode_shape_slat(samples, resolution)
        return IO.NodeOutput(mesh, subs)
IO.Conditioning.Output(display_name="positive"), + IO.Conditioning.Output(display_name="negative"), + ] + ) + + @classmethod + def execute(cls, clip_vision_model, image, background_color) -> IO.NodeOutput: + # could make 1024 an option + conditioning, _ = run_conditioning(clip_vision_model, image, include_1024=True, background_color=background_color) + embeds = conditioning["cond_1024"] # should add that + positive = [[conditioning["cond_512"], {embeds}]] + negative = [[conditioning["cond_neg"], {embeds}]] + return IO.NodeOutput(positive, negative) + +class EmptyLatentTrellis2(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="EmptyLatentTrellis2", + category="latent/3d", + inputs=[ + IO.Int.Input("resolution", default=3072, min=1, max=8192), + IO.Int.Input("batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."), + IO.Vae.Input("vae"), + IO.Boolean.Input("shape_generation", tooltip="Setting to false will generate texture."), + IO.MultiCombo.Input("generation_type", options=["structure_generation", "shape_generation", "texture_generation"]) + ], + outputs=[ + IO.Latent.Output(), + ] + ) + + @classmethod + def execute(cls, batch_size, coords, vae, generation_type) -> IO.NodeOutput: + # TODO: i will probably update how shape/texture is generated + # could split this too + in_channels = 32 + shape_generation = generation_type == "shape_generation" + device = comfy.model_management.intermediate_device() + if shape_generation: + latent = SparseTensor(feats=torch.randn(batch_size, in_channels).to(device), coords=coords) + else: + # coords = shape_slat in txt gen case + latent = coords.replace(feats=torch.randn(coords.coords.shape[0], in_channels - coords.feats.shape[1]).to(device)) + return IO.NodeOutput({"samples": latent, "type": "trellis2"}) + +class Trellis2Extension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + Trellis2Conditioning, + 
class EmptyShapeLatentTrellis2(IO.ComfyNode):
    """Node that creates a random shape latent on the given structure."""
    @classmethod
    def define_schema(cls):
        return IO.Schema(
            # Matches the class name, like the other nodes in this file; the
            # id "EmptyLatentTrellis2" was shared by three different nodes.
            node_id="EmptyShapeLatentTrellis2",
            category="latent/3d",
            inputs=[
                IO.Latent.Input("structure_output"),
            ],
            outputs=[
                IO.Latent.Output(),
            ]
        )

    @classmethod
    def execute(cls, structure_output):
        """
        Return a latent dict whose ``samples`` is a ``SparseTensor`` with 32
        random feature channels per coordinate row.
        """
        # Use ``.coords`` when the input has it, otherwise treat the input as
        # the coordinates themselves. ``a or a.coords`` only reached ``.coords``
        # for falsy inputs and evaluates tensor truthiness.
        # NOTE(review): the exact type arriving on this socket is not visible
        # here — confirm.
        coords = getattr(structure_output, "coords", structure_output)
        in_channels = 32
        latent = SparseTensor(feats=torch.randn(coords.shape[0], in_channels), coords=coords)
        return IO.NodeOutput({"samples": latent, "type": "trellis2"})
structure_output.replace(feats=torch.randn(structure_output.coords.shape[0], in_channels - structure_output.feats.shape[1])) + return IO.NodeOutput({"samples": latent, "type": "trellis2"}) + +class EmptyStructureLatentTrellis2(IO.ComfyNode): @classmethod def define_schema(cls): return IO.Schema( @@ -202,35 +243,26 @@ class EmptyLatentTrellis2(IO.ComfyNode): inputs=[ IO.Int.Input("resolution", default=3072, min=1, max=8192), IO.Int.Input("batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."), - IO.Vae.Input("vae"), - IO.Boolean.Input("shape_generation", tooltip="Setting to false will generate texture."), - IO.MultiCombo.Input("generation_type", options=["structure_generation", "shape_generation", "texture_generation"]) ], outputs=[ IO.Latent.Output(), ] ) - - @classmethod - def execute(cls, batch_size, coords, vae, generation_type) -> IO.NodeOutput: - # TODO: i will probably update how shape/texture is generated - # could split this too + + def execute(cls, res, batch_size): in_channels = 32 - shape_generation = generation_type == "shape_generation" - device = comfy.model_management.intermediate_device() - if shape_generation: - latent = SparseTensor(feats=torch.randn(batch_size, in_channels).to(device), coords=coords) - else: - # coords = shape_slat in txt gen case - latent = coords.replace(feats=torch.randn(coords.coords.shape[0], in_channels - coords.feats.shape[1]).to(device)) + latent = torch.randn(batch_size, in_channels, res, res, res) return IO.NodeOutput({"samples": latent, "type": "trellis2"}) + class Trellis2Extension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: return [ Trellis2Conditioning, - EmptyLatentTrellis2, + EmptyShapeLatentTrellis2, + EmptyStructureLatentTrellis2, + EmptyTextureLatentTrellis2, VaeDecodeTextureTrellis, VaeDecodeShapeTrellis ] diff --git a/nodes.py b/nodes.py index 1cb43d9e2..051e808cc 100644 --- a/nodes.py +++ b/nodes.py @@ -2433,7 +2433,8 @@ async 
def init_builtin_extra_nodes(): "nodes_image_compare.py", "nodes_zimage.py", "nodes_lora_debug.py", - "nodes_color.py" + "nodes_color.py", + "nodes_trellis2.py" ] import_failed = [] From 614b167994aa847fda2b4a718048a2cdbc84c132 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Mon, 2 Feb 2026 21:23:19 +0200 Subject: [PATCH 03/59] . --- comfy_extras/nodes_trellis2.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 304d95493..70bbbb29d 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -206,7 +206,8 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): IO.Latent.Output(), ] ) - + + @classmethod def execute(cls, structure_output): # i will see what i have to do here coords = structure_output or structure_output.coords @@ -227,7 +228,8 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): IO.Latent.Output(), ] ) - + + @classmethod def execute(cls, structure_output): # TODO in_channels = 32 @@ -248,7 +250,7 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode): IO.Latent.Output(), ] ) - + @classmethod def execute(cls, res, batch_size): in_channels = 32 latent = torch.randn(batch_size, in_channels, res, res, res) From f76e3a11b5dc4c09f2edb087c013ba7bcc5c7a6b Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Mon, 2 Feb 2026 21:27:15 +0200 Subject: [PATCH 04/59] .. 
--- comfy/ldm/trellis2/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index a0889c4dd..cdbfbf6fc 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -1,9 +1,9 @@ import torch import torch.nn.functional as F import torch.nn as nn -from vae import SparseTensor, SparseLinear, sparse_cat, VarLenTensor +from comfy.ldm.trellis2.vae import SparseTensor, SparseLinear, sparse_cat, VarLenTensor from typing import Optional, Tuple, Literal, Union, List -from attention import sparse_windowed_scaled_dot_product_self_attention, sparse_scaled_dot_product_attention +from comfy.ldm.trellis2.attention import sparse_windowed_scaled_dot_product_self_attention, sparse_scaled_dot_product_attention from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder class SparseGELU(nn.GELU): From d6573fd26d63e6fe00515e903d4c2535b02fbeea Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Tue, 3 Feb 2026 21:10:20 +0200 Subject: [PATCH 05/59] model init working --- comfy/latent_formats.py | 2 + comfy/ldm/trellis2/cumesh.py | 387 ++++++++++++++++++++++++++++++++++- comfy/ldm/trellis2/model.py | 20 +- comfy/ldm/trellis2/vae.py | 221 +++++++++++++++++++- comfy/model_base.py | 8 + comfy/supported_models.py | 13 +- 6 files changed, 639 insertions(+), 12 deletions(-) diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 4b3a3798c..fc4c4e6d3 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -746,6 +746,8 @@ class Hunyuan3Dv2_1(LatentFormat): latent_channels = 64 latent_dimensions = 1 +class Trellis2(LatentFormat): # TODO + latent_channels = 32 class Hunyuan3Dv2mini(LatentFormat): latent_channels = 64 latent_dimensions = 1 diff --git a/comfy/ldm/trellis2/cumesh.py b/comfy/ldm/trellis2/cumesh.py index be8200341..41ac35db9 100644 --- a/comfy/ldm/trellis2/cumesh.py +++ b/comfy/ldm/trellis2/cumesh.py @@ -1,8 
+1,192 @@ # will contain every cuda -> pytorch operation +import math import torch -from typing import Dict +from typing import Dict, Callable +NO_TRITION = False +try: + import triton + import triton.language as tl + heuristics = { + 'valid_kernel': lambda args: args['valid_kernel'](args['B1']), + 'valid_kernel_seg': lambda args: args['valid_kernel_seg'](args['B1']), + } + + #@triton_autotune( + # configs=config.autotune_config, + # key=['LOGN', 'Ci', 'Co', 'V', 'allow_tf32'], + #) + @triton.heuristics(heuristics) + @triton.jit + def sparse_submanifold_conv_fwd_masked_implicit_gemm_kernel( + input, + weight, + bias, + neighbor, + sorted_idx, + output, + # Tensor dimensions + N, LOGN, Ci, Co, V: tl.constexpr, + # Meta-parameters + B1: tl.constexpr, # Block size for N dimension + B2: tl.constexpr, # Block size for Co dimension + BK: tl.constexpr, # Block size for K dimension (V * Ci) + allow_tf32: tl.constexpr, # Allow TF32 precision for matmuls + # Huristic parameters + valid_kernel, + valid_kernel_seg, + ): + + block_id = tl.program_id(axis=0) + block_dim_co = tl.cdiv(Co, B2) + block_id_co = block_id % block_dim_co + block_id_n = block_id // block_dim_co + + # Create pointers for submatrices of A and B. + num_k = tl.cdiv(Ci, BK) # Number of blocks in K dimension + valid_kernel_start = tl.load(valid_kernel_seg + block_id_n) + valid_kernel_seglen = tl.load(valid_kernel_seg + block_id_n + 1) - valid_kernel_start + offset_n = block_id_n * B1 + tl.arange(0, B1) + n_mask = offset_n < N + offset_sorted_n = tl.load(sorted_idx + offset_n, mask=n_mask, other=0) # (B1,) + offset_co = (block_id_co * B2 + tl.arange(0, B2)) % Co # (B2,) + offset_k = tl.arange(0, BK) # (BK,) + + # Create a block of the output matrix C. + accumulator = tl.zeros((B1, B2), dtype=tl.float32) + + # Iterate along V*Ci dimension. + for k in range(num_k * valid_kernel_seglen): + v = k // num_k + bk = k % num_k + v = tl.load(valid_kernel + valid_kernel_start + v) + # Calculate pointers to input matrix. 
+ neighbor_offset_n = tl.load(neighbor + offset_sorted_n * V + v) # (B1,) + input_ptr = input + bk * BK + (neighbor_offset_n[:, None].to(tl.int64) * Ci + offset_k[None, :]) # (B1, BK) + # Calculate pointers to weight matrix. + weight_ptr = weight + v * Ci + bk * BK + (offset_co[None, :] * V * Ci + offset_k[:, None]) # (BK, B2) + # Load the next block of input and weight. + neigh_mask = neighbor_offset_n != 0xffffffff + k_mask = offset_k < Ci - bk * BK + input_block = tl.load(input_ptr, mask=neigh_mask[:, None] & k_mask[None, :], other=0.0) + weight_block = tl.load(weight_ptr, mask=k_mask[:, None], other=0.0) + # Accumulate along the K dimension. + accumulator = tl.dot(input_block, weight_block, accumulator, + input_precision='tf32' if allow_tf32 else 'ieee') # (B1, B2) + c = accumulator.to(input.type.element_ty) + + # add bias + if bias is not None: + bias_block = tl.load(bias + offset_co) + c += bias_block[None, :] + + # Write back the block of the output matrix with masks. + out_offset_n = offset_sorted_n + out_offset_co = block_id_co * B2 + tl.arange(0, B2) + out_ptr = output + (out_offset_n[:, None] * Co + out_offset_co[None, :]) + out_mask = n_mask[:, None] & (out_offset_co[None, :] < Co) + tl.store(out_ptr, c, mask=out_mask) + def sparse_submanifold_conv_fwd_masked_implicit_gemm_splitk( + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + neighbor: torch.Tensor, + sorted_idx: torch.Tensor, + valid_kernel: Callable[[int], torch.Tensor], + valid_kernel_seg: Callable[[int], torch.Tensor], + ) -> torch.Tensor: + N, Ci, Co, V = neighbor.shape[0], input.shape[1], weight.shape[0], weight.shape[1] + LOGN = int(math.log2(N)) + output = torch.empty((N, Co), device=input.device, dtype=input.dtype) + grid = lambda META: (triton.cdiv(Co, META['B2']) * triton.cdiv(N, META['B1']),) + sparse_submanifold_conv_fwd_masked_implicit_gemm_kernel[grid]( + input, weight, bias, neighbor, sorted_idx, output, + N, LOGN, Ci, Co, V, # + valid_kernel=valid_kernel, + 
def build_submanifold_neighbor_map(
    hashmap,
    coords: torch.Tensor,
    W, H, D,
    Kw, Kh, Kd,
    Dw, Dh, Dd,
):
    """
    Build the (M, V) neighbour table of a submanifold convolution.

    ``neighbor[i, v]`` is the row index of the voxel found at kernel offset
    ``v`` from voxel ``i``, or ``hashmap.default_value`` when none exists.
    Only the first half of the kernel is looked up; the mirrored entry
    ``V - 1 - v`` is written from the same hit, and the middle offset maps
    each voxel to itself.

    Args:
        hashmap: object exposing ``default_value`` and ``lookup_flat(keys)``.
        coords: (M, 4) integer rows ``(batch, x, y, z)``.
        W, H, D: grid extent along x, y, z.
        Kw, Kh, Kd: kernel size along x, y, z.
        Dw, Dh, Dd: dilation along x, y, z.

    Returns:
        ``torch.long`` tensor of shape ``(M, Kw * Kh * Kd)``.
    """
    device = coords.device
    M = coords.shape[0]
    V = Kw * Kh * Kd
    half_V = V // 2 + 1

    INVALID = hashmap.default_value

    neighbor = torch.full((M, V), INVALID, device=device, dtype=torch.long)

    # 64-bit so that b * W * H * D cannot overflow a narrower coords dtype.
    b = coords[:, 0].long()
    x = coords[:, 1].long()
    y = coords[:, 2].long()
    z = coords[:, 3].long()

    # Coordinates of kernel offset 0 for every voxel.
    ox = x - (Kw // 2) * Dw
    oy = y - (Kh // 2) * Dh
    oz = z - (Kd // 2) * Dd

    for v in range(half_V):
        if v == half_V - 1:
            # Middle offset: every voxel is its own neighbour.
            neighbor[:, v] = torch.arange(M, device=device)
            continue

        # Offsets are enumerated x-major, then y, then z (same order as
        # compute_kernel_offsets), so v unravels into the three kernel steps.
        vx, rest = divmod(v, Kh * Kd)
        vy, vz = divmod(rest, Kd)

        # The base is per voxel, not per offset: the previous ``ox[:, v]``
        # indexed an (M, 1) tensor and failed for every v >= 1.
        kx = ox + vx * Dw
        ky = oy + vy * Dh
        kz = oz + vz * Dd

        valid = (
            (kx >= 0) & (kx < W) &
            (ky >= 0) & (ky < H) &
            (kz >= 0) & (kz < D)
        )

        flat = (
            b * (W * H * D) +
            kx * (H * D) +
            ky * D +
            kz
        )[valid]
        idx = torch.nonzero(valid, as_tuple=False).squeeze(1)

        found = hashmap.lookup_flat(flat)

        neighbor[idx, v] = found

        # symmetric write
        valid_found = found != INVALID
        neighbor[found[valid_found], V - 1 - v] = idx[valid_found]

    return neighbor
+ + + neigh = neighbor_map.to(torch.long) + sentinel = torch.tensor(UINT32_SENTINEL, dtype=torch.long, device=device) + + + neigh_map_T = neigh.t().reshape(-1) + + neigh_mask_T = (neigh_map_T != sentinel).to(torch.int32) + + mask = (neigh != sentinel).to(torch.long) + + powers = (1 << torch.arange(V, dtype=torch.long, device=device)) + + gray_long = (mask * powers).sum(dim=1) + + gray_code = gray_long.to(torch.int32) + + binary_long = gray_long.clone() + for v in range(1, V): + binary_long ^= (gray_long >> v) + binary_code = binary_long.to(torch.int32) + + sorted_idx = torch.argsort(binary_code) + + prefix_sum_neighbor_mask = torch.cumsum(neigh_mask_T.to(torch.int32), dim=0) # (V*N,) + + total_valid_signal = int(prefix_sum_neighbor_mask[-1].item()) if prefix_sum_neighbor_mask.numel() > 0 else 0 + + if total_valid_signal > 0: + valid_signal_i = torch.empty((total_valid_signal,), dtype=torch.long, device=device) + valid_signal_o = torch.empty((total_valid_signal,), dtype=torch.long, device=device) + + pos = torch.nonzero(neigh_mask_T, as_tuple=True)[0] + + to = (prefix_sum_neighbor_mask[pos] - 1).to(torch.long) + + valid_signal_i[to] = (pos % N).to(torch.long) + + valid_signal_o[to] = neigh_map_T[pos].to(torch.long) + else: + valid_signal_i = torch.empty((0,), dtype=torch.long, device=device) + valid_signal_o = torch.empty((0,), dtype=torch.long, device=device) + + seg = torch.empty((V + 1,), dtype=torch.long, device=device) + seg[0] = 0 + if V > 0: + idxs = (torch.arange(1, V + 1, device=device, dtype=torch.long) * N) - 1 + seg[1:] = prefix_sum_neighbor_mask[idxs].to(torch.long) + else: + pass + + return gray_code, sorted_idx, valid_signal_i, valid_signal_o, seg + +def _popcount_int32_tensor(x: torch.Tensor) -> torch.Tensor: + + x = x.to(torch.int64) + + m1 = torch.tensor(0x5555555555555555, dtype=torch.int64, device=x.device) + m2 = torch.tensor(0x3333333333333333, dtype=torch.int64, device=x.device) + m4 = torch.tensor(0x0F0F0F0F0F0F0F0F, dtype=torch.int64, 
device=x.device) + h01 = torch.tensor(0x0101010101010101, dtype=torch.int64, device=x.device) + + x = x - ((x >> 1) & m1) + x = (x & m2) + ((x >> 2) & m2) + x = (x + (x >> 4)) & m4 + x = (x * h01) >> 56 + return x.to(torch.int32) + + +def neighbor_map_post_process_for_masked_implicit_gemm_2( + gray_code: torch.Tensor, # [N], int32-like (non-negative) + sorted_idx: torch.Tensor, # [N], long (indexing into gray_code) + block_size: int +): + device = gray_code.device + N = gray_code.numel() + + # num of blocks (same as CUDA) + num_blocks = (N + block_size - 1) // block_size + + # Ensure dtypes + gray_long = gray_code.to(torch.int64) # safer to OR in 64-bit then cast + sorted_idx = sorted_idx.to(torch.long) + + # 1) Group gray_code by blocks and compute OR across each block + # pad the last block with zeros if necessary so we can reshape + pad = num_blocks * block_size - N + if pad > 0: + pad_vals = torch.zeros((pad,), dtype=torch.int64, device=device) + gray_padded = torch.cat([gray_long[sorted_idx], pad_vals], dim=0) + else: + gray_padded = gray_long[sorted_idx] + + # reshape to (num_blocks, block_size) and compute bitwise_or across dim=1 + gray_blocks = gray_padded.view(num_blocks, block_size) # each row = block entries + # reduce with bitwise_or + reduced_code = gray_blocks[:, 0].clone() + for i in range(1, block_size): + reduced_code |= gray_blocks[:, i] + reduced_code = reduced_code.to(torch.int32) # match CUDA int32 + + # 2) compute seglen (popcount per reduced_code) and seg (prefix sum) + seglen_counts = _popcount_int32_tensor(reduced_code.to(torch.int64)).to(torch.int32) # [num_blocks] + # seg: length num_blocks+1, seg[0]=0, seg[i+1]=cumsum(seglen_counts) up to i + seg = torch.empty((num_blocks + 1,), dtype=torch.int32, device=device) + seg[0] = 0 + if num_blocks > 0: + seg[1:] = torch.cumsum(seglen_counts, dim=0) + + total = int(seg[-1].item()) + + # 3) scatter — produce valid_kernel_idx as concatenated ascending set-bit positions for each reduced_code row + 
if total == 0: + valid_kernel_idx = torch.empty((0,), dtype=torch.int32, device=device) + return valid_kernel_idx, seg + + max_val = int(reduced_code.max().item()) + V = max_val.bit_length() if max_val > 0 else 0 + # If you know V externally, pass it instead or set here explicitly. + + if V == 0: + # no bits set anywhere + valid_kernel_idx = torch.empty((0,), dtype=torch.int32, device=device) + return valid_kernel_idx, seg + + # build mask of shape (num_blocks, V): True where bit is set + bit_pos = torch.arange(0, V, dtype=torch.int64, device=device) # [V] + # shifted = reduced_code[:, None] >> bit_pos[None, :] + shifted = reduced_code.to(torch.int64).unsqueeze(1) >> bit_pos.unsqueeze(0) + bits = (shifted & 1).to(torch.bool) # (num_blocks, V) + + positions = bit_pos.unsqueeze(0).expand(num_blocks, V) + + valid_positions = positions[bits] + valid_kernel_idx = valid_positions.to(torch.int32).contiguous() + + return valid_kernel_idx, seg + + +def sparse_submanifold_conv3d(feats, coords, shape, weight, bias, neighbor_cache, dilation): + if len(shape) == 5: + N, C, W, H, D = shape + else: + W, H, D = shape + + Co, Kw, Kh, Kd, Ci = weight.shape + + b_stride = W * H * D + x_stride = H * D + y_stride = D + z_stride = 1 + + flat_keys = (coords[:, 0].long() * b_stride + + coords[:, 1].long() * x_stride + + coords[:, 2].long() * y_stride + + coords[:, 3].long() * z_stride) + + vals = torch.arange(coords.shape[0], dtype=torch.int32, device=coords.device) + + hashmap = TorchHashMap(flat_keys, vals, 0xFFFFFFFF) + + if neighbor_cache is None: + neighbor = build_submanifold_neighbor_map( + hashmap, coords, W, H, D, Kw, Kh, Kd, + dilation[0], dilation[1], dilation[2] + ) + else: + neighbor = neighbor_cache + + block_size = 128 + + gray_code, sorted_idx, valid_signal_i, valid_signal_o, valid_signal_seg = \ + neighbor_map_post_process_for_masked_implicit_gemm_1(neighbor) + + valid_kernel, valid_kernel_seg = \ + neighbor_map_post_process_for_masked_implicit_gemm_2(gray_code, 
sorted_idx, block_size) + + valid_kernel_fn = lambda b_size: valid_kernel + valid_kernel_seg_fn = lambda b_size: valid_kernel_seg + + weight_flat = weight.contiguous().view(Co, -1, Ci) + + out = sparse_submanifold_conv_fwd_masked_implicit_gemm_splitk( + feats, + weight_flat, + bias, + neighbor, + sorted_idx, + valid_kernel_fn, + valid_kernel_seg_fn + ) + + return out, neighbor + class Voxel: def __init__( self, diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index cdbfbf6fc..9d1a8fdb4 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -408,7 +408,7 @@ class SLatFlowModel(nn.Module): self.qk_rms_norm_cross = qk_rms_norm_cross self.dtype = dtype - self.t_embedder = TimestepEmbedder(model_channels) + self.t_embedder = TimestepEmbedder(model_channels, device=device, dtype=dtype, operations=operations) if share_mod: self.adaLN_modulation = nn.Sequential( nn.SiLU(), @@ -485,15 +485,25 @@ class Trellis2(nn.Module): qk_rms_norm = True, qk_rms_norm_cross = True, dtype=None, device=None, operations=None): + + super().__init__() args = { "out_channels":out_channels, "num_blocks":num_blocks, "cond_channels" :cond_channels, "model_channels":model_channels, "num_heads":num_heads, "mlp_ratio": mlp_ratio, "share_mod": share_mod, "qk_rms_norm": qk_rms_norm, "qk_rms_norm_cross": qk_rms_norm_cross, "device": device, "dtype": dtype, "operations": operations } # TODO: update the names/checkpoints - self.img2shape = SLatFlowModel(resolution, in_channels=in_channels, *args) - self.shape2txt = SLatFlowModel(resolution, in_channels=in_channels*2, *args) + self.img2shape = SLatFlowModel(resolution=resolution, in_channels=in_channels, **args) + self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args) self.shape_generation = True - def forward(self, x, timestep, context): - pass + def forward(self, x, timestep, context, **kwargs): + # TODO add mode + mode = kwargs.get("mode", "shape_generation") + mode = 
"texture_generation" if mode == 1 else "shape_generation" + if mode == "shape_generation": + out = self.img2shape(x, timestep, context) + if mode == "texture_generation": + out = self.shape2txt(x, timestep, context) + + return out diff --git a/comfy/ldm/trellis2/vae.py b/comfy/ldm/trellis2/vae.py index 1d564bca2..5dabf5246 100644 --- a/comfy/ldm/trellis2/vae.py +++ b/comfy/ldm/trellis2/vae.py @@ -1,3 +1,4 @@ +import math import torch import torch.nn as nn from typing import List, Any, Dict, Optional, overload, Union, Tuple @@ -5,12 +6,219 @@ from fractions import Fraction import torch.nn.functional as F from dataclasses import dataclass import numpy as np -from cumesh import TorchHashMap, Mesh, MeshWithVoxel +from cumesh import TorchHashMap, Mesh, MeshWithVoxel, sparse_submanifold_conv3d + + +class SparseConv3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None): + super(SparseConv3d, self).__init__() + sparse_conv3d_init(self, in_channels, out_channels, kernel_size, stride, dilation, padding, bias, indice_key) + + def forward(self, x): + return sparse_conv3d_forward(self, x) + + +def sparse_conv3d_init(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None): + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = tuple(kernel_size) if isinstance(kernel_size, (list, tuple)) else (kernel_size, ) * 3 + self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, ) * 3 + self.dilation = tuple(dilation) if isinstance(dilation, (list, tuple)) else (dilation, ) * 3 + + self.weight = nn.Parameter(torch.empty((out_channels, in_channels, *self.kernel_size))) + if bias: + self.bias = nn.Parameter(torch.empty(out_channels)) + else: + self.register_parameter("bias", None) + + if self.bias is not None: + fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight) + if fan_in != 0: 
+ bound = 1 / math.sqrt(fan_in) + torch.nn.init.uniform_(self.bias, -bound, bound) + + # Permute weight (Co, Ci, Kd, Kh, Kw) -> (Co, Kd, Kh, Kw, Ci) + self.weight = nn.Parameter(self.weight.permute(0, 2, 3, 4, 1).contiguous()) + + +def sparse_conv3d_forward(self, x): + # check if neighbor map is already computed + Co, Kd, Kh, Kw, Ci = self.weight.shape + neighbor_cache_key = f'SubMConv3d_neighbor_cache_{Kw}x{Kh}x{Kd}_dilation{self.dilation}' + neighbor_cache = x.get_spatial_cache(neighbor_cache_key) + + out, neighbor_cache_ = sparse_submanifold_conv3d( + x.feats, + x.coords, + torch.Size([*x.shape, *x.spatial_shape]), + self.weight, + self.bias, + neighbor_cache, + self.dilation + ) + + if neighbor_cache is None: + x.register_spatial_cache(neighbor_cache_key, neighbor_cache_) + + out = x.replace(out) + return out + +class LayerNorm32(nn.LayerNorm): + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_dtype = x.dtype + x = x.to(torch.float32) + o = super().forward(x) + return o.to(x_dtype) + +class SparseConvNeXtBlock3d(nn.Module): + def __init__( + self, + channels: int, + mlp_ratio: float = 4.0, + use_checkpoint: bool = False, + ): + super().__init__() + self.channels = channels + self.use_checkpoint = use_checkpoint + + self.norm = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.conv = SparseConv3d(channels, channels, 3) + self.mlp = nn.Sequential( + nn.Linear(channels, int(channels * mlp_ratio)), + nn.SiLU(), + nn.Linear(int(channels * mlp_ratio), channels), + ) + + def _forward(self, x): + h = self.conv(x) + h = h.replace(self.norm(h.feats)) + h = h.replace(self.mlp(h.feats)) + return h + x + + def forward(self, x): + return self._forward(x) + +class SparseSpatial2Channel(nn.Module): + def __init__(self, factor: int = 2): + super(SparseSpatial2Channel, self).__init__() + self.factor = factor + + def forward(self, x): + DIM = x.coords.shape[-1] - 1 + cache = x.get_spatial_cache(f'spatial2channel_{self.factor}') + if cache is None: + coord = 
list(x.coords.unbind(dim=-1)) + for i in range(DIM): + coord[i+1] = coord[i+1] // self.factor + subidx = x.coords[:, 1:] % self.factor + subidx = sum([subidx[..., i] * self.factor ** i for i in range(DIM)]) + + MAX = [(s + self.factor - 1) // self.factor for s in x.spatial_shape] + OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1] + code = sum([c * o for c, o in zip(coord, OFFSET)]) + code, idx = code.unique(return_inverse=True) + + new_coords = torch.stack( + [code // OFFSET[0]] + + [(code // OFFSET[i+1]) % MAX[i] for i in range(DIM)], + dim=-1 + ) + else: + new_coords, idx, subidx = cache + + new_feats = torch.zeros(new_coords.shape[0] * self.factor ** DIM, x.feats.shape[1], device=x.feats.device, dtype=x.feats.dtype) + new_feats[idx * self.factor ** DIM + subidx] = x.feats + + out = SparseTensor(new_feats.reshape(new_coords.shape[0], -1), new_coords, None if x._shape is None else torch.Size([x._shape[0], x._shape[1] * self.factor ** DIM])) + out._scale = tuple([s * self.factor for s in x._scale]) + out._spatial_cache = x._spatial_cache + + if cache is None: + x.register_spatial_cache(f'spatial2channel_{self.factor}', (new_coords, idx, subidx)) + out.register_spatial_cache(f'channel2spatial_{self.factor}', (x.coords, idx, subidx)) + out.register_spatial_cache('shape', torch.Size(MAX)) + + return out + +class SparseChannel2Spatial(nn.Module): + def __init__(self, factor: int = 2): + super(SparseChannel2Spatial, self).__init__() + self.factor = factor + + def forward(self, x, subdivision = None): + DIM = x.coords.shape[-1] - 1 + + cache = x.get_spatial_cache(f'channel2spatial_{self.factor}') + if cache is None: + if subdivision is None: + raise ValueError('Cache not found. 
Provide subdivision tensor or pair SparseChannel2Spatial with SparseSpatial2Channel.') + else: + sub = subdivision.feats # [N, self.factor ** DIM] + N_leaf = sub.sum(dim=-1) # [N] + subidx = sub.nonzero()[:, -1] + new_coords = x.coords.clone().detach() + new_coords[:, 1:] *= self.factor + new_coords = torch.repeat_interleave(new_coords, N_leaf, dim=0, output_size=subidx.shape[0]) + for i in range(DIM): + new_coords[:, i+1] += subidx // self.factor ** i % self.factor + idx = torch.repeat_interleave(torch.arange(x.coords.shape[0], device=x.device), N_leaf, dim=0, output_size=subidx.shape[0]) + else: + new_coords, idx, subidx = cache + + x_feats = x.feats.reshape(x.feats.shape[0] * self.factor ** DIM, -1) + new_feats = x_feats[idx * self.factor ** DIM + subidx] + out = SparseTensor(new_feats, new_coords, None if x._shape is None else torch.Size([x._shape[0], x._shape[1] // self.factor ** DIM])) + out._scale = tuple([s / self.factor for s in x._scale]) + if cache is not None: # only keep cache when subdiv following it + out._spatial_cache = x._spatial_cache + return out + +class SparseResBlockC2S3d(nn.Module): + def __init__( + self, + channels: int, + out_channels: Optional[int] = None, + use_checkpoint: bool = False, + pred_subdiv: bool = True, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_checkpoint = use_checkpoint + self.pred_subdiv = pred_subdiv + + self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6) + self.conv1 = SparseConv3d(channels, self.out_channels * 8, 3) + self.conv2 = SparseConv3d(self.out_channels, self.out_channels, 3) + self.skip_connection = lambda x: x.replace(x.feats.repeat_interleave(out_channels // (channels // 8), dim=1)) + if pred_subdiv: + self.to_subdiv = SparseLinear(channels, 8) + self.updown = SparseChannel2Spatial(2) + + def _forward(self, x, subdiv = None): + if 
self.pred_subdiv: + subdiv = self.to_subdiv(x) + h = x.replace(self.norm1(x.feats)) + h = h.replace(F.silu(h.feats)) + h = self.conv1(h) + subdiv_binarized = subdiv.replace(subdiv.feats > 0) if subdiv is not None else None + h = self.updown(h, subdiv_binarized) + x = self.updown(x, subdiv_binarized) + h = h.replace(self.norm2(h.feats)) + h = h.replace(F.silu(h.feats)) + h = self.conv2(h) + h = h + self.skip_connection(x) + if self.pred_subdiv: + return h, subdiv + else: + return h -# TODO: determine which conv they actually use @dataclass class config: - CONV = "none" + CONV = "flexgemm" + FLEX_GEMM_HASHMAP_RATIO = 2.0 # TODO post processing def simplify(self, target_num_faces: int, verbose: bool=False, options: dict={}): @@ -1131,6 +1339,7 @@ def flexible_dual_grid_to_mesh( class Vae(nn.Module): def __init__(self, config, operations=None): + super().__init__() operations = operations or torch.nn self.txt_dec = SparseUnetVaeDecoder( @@ -1139,7 +1348,8 @@ class Vae(nn.Module): latent_channels=32, num_blocks=[4, 16, 8, 4, 0], block_type=["SparseConvNeXtBlock3d"] * 5, - up_block_type=["SparseResBlockS2C3d"] * 4, + up_block_type=["SparseResBlockC2S3d"] * 4, + block_args=[{}, {}, {}, {}, {}], pred_subdiv=False ) @@ -1149,7 +1359,8 @@ class Vae(nn.Module): latent_channels=32, num_blocks=[4, 16, 8, 4, 0], block_type=["SparseConvNeXtBlock3d"] * 5, - up_block_type=["SparseResBlockS2C3d"] * 4, + up_block_type=["SparseResBlockC2S3d"] * 4, + block_args=[{}, {}, {}, {}, {}], ) def decode_shape_slat(self, slat, resolution: int): diff --git a/comfy/model_base.py b/comfy/model_base.py index 85acdb66a..a5fc81c4d 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -50,6 +50,7 @@ import comfy.ldm.omnigen.omnigen2 import comfy.ldm.qwen_image.model import comfy.ldm.kandinsky5.model import comfy.ldm.anima.model +import comfy.ldm.trellis2.model import comfy.model_management import comfy.patcher_extension @@ -1455,6 +1456,13 @@ class WAN22(WAN21): def scale_latent_inpaint(self, 
sigma, noise, latent_image, **kwargs): return latent_image +class Trellis2(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None, unet_model=comfy.ldm.trellis2.model.Trellis2): + super().__init__(model_config, model_type, device, unet_model) + + def extra_conds(self, **kwargs): + return super().extra_conds(**kwargs) + class Hunyuan3Dv2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index d25271d6e..6c2725b9f 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1242,6 +1242,17 @@ class WAN22_T2V(WAN21_T2V): out = model_base.WAN22(self, image_to_video=True, device=device) return out +class Trellis2(supported_models_base.BASE): + unet_config = { + "image_model": "trellis2" + } + + latent_format = latent_formats.Trellis2 + vae_key_prefix = ["vae."] + + def get_model(self, state_dict, prefix="", device=None): + return model_base.Trellis2(self, device=device) + class Hunyuan3Dv2(supported_models_base.BASE): unet_config = { "image_model": "hunyuan3d2", @@ -1596,6 +1607,6 @@ class Kandinsky5Image(Kandinsky5): return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, 
CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, Trellis2] models += [SVD_img2vid] From 66249395056ab311f4459fdec138fc49e080702d Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Tue, 3 Feb 2026 22:40:54 +0200 Subject: [PATCH 06/59] structure model --- comfy/ldm/trellis2/attention.py | 67 +++++++ comfy/ldm/trellis2/model.py | 316 +++++++++++++++++++++++++++++++- comfy/supported_models.py | 4 + 3 files changed, 382 insertions(+), 5 deletions(-) diff --git a/comfy/ldm/trellis2/attention.py b/comfy/ldm/trellis2/attention.py index 9cd7d4995..6c912c8d9 100644 --- a/comfy/ldm/trellis2/attention.py +++ b/comfy/ldm/trellis2/attention.py @@ -4,6 +4,73 @@ from comfy.ldm.modules.attention import optimized_attention from typing import Tuple, Union, List from vae import VarLenTensor +FLASH_ATTN_3_AVA = 
True +try: + import flash_attn_interface as flash_attn_3 +except: + FLASH_ATTN_3_AVA = False + +# TODO repalce with optimized attention +def scaled_dot_product_attention(*args, **kwargs): + num_all_args = len(args) + len(kwargs) + + if num_all_args == 1: + qkv = args[0] if len(args) > 0 else kwargs['qkv'] + + elif num_all_args == 2: + q = args[0] if len(args) > 0 else kwargs['q'] + kv = args[1] if len(args) > 1 else kwargs['kv'] + + elif num_all_args == 3: + q = args[0] if len(args) > 0 else kwargs['q'] + k = args[1] if len(args) > 1 else kwargs['k'] + v = args[2] if len(args) > 2 else kwargs['v'] + + if optimized_attention.__name__ == 'attention_xformers': + if 'xops' not in globals(): + import xformers.ops as xops + if num_all_args == 1: + q, k, v = qkv.unbind(dim=2) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + out = xops.memory_efficient_attention(q, k, v) + elif optimized_attention.__name__ == 'attention_flash' and not FLASH_ATTN_3_AVA: + if 'flash_attn' not in globals(): + import flash_attn + if num_all_args == 2: + out = flash_attn.flash_attn_kvpacked_func(q, kv) + elif num_all_args == 3: + out = flash_attn.flash_attn_func(q, k, v) + elif optimized_attention.__name__ == 'attention_flash': # TODO + if 'flash_attn_3' not in globals(): + import flash_attn_interface as flash_attn_3 + if num_all_args == 2: + k, v = kv.unbind(dim=2) + out = flash_attn_3.flash_attn_func(q, k, v) + elif num_all_args == 3: + out = flash_attn_3.flash_attn_func(q, k, v) + elif optimized_attention.__name__ == 'attention_pytorch': + if 'sdpa' not in globals(): + from torch.nn.functional import scaled_dot_product_attention as sdpa + if num_all_args == 1: + q, k, v = qkv.unbind(dim=2) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + q = q.permute(0, 2, 1, 3) # [N, H, L, C] + k = k.permute(0, 2, 1, 3) # [N, H, L, C] + v = v.permute(0, 2, 1, 3) # [N, H, L, C] + out = sdpa(q, k, v) # [N, H, L, C] + out = out.permute(0, 2, 1, 3) # [N, L, H, C] + elif optimized_attention.__name__ 
== 'attention_basic': + if num_all_args == 1: + q, k, v = qkv.unbind(dim=2) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + q = q.shape[2] # TODO + out = optimized_attention(q, k, v) + + return out + def sparse_windowed_scaled_dot_product_self_attention( qkv, window_size: int, diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 9d1a8fdb4..add3f21ce 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -3,7 +3,9 @@ import torch.nn.functional as F import torch.nn as nn from comfy.ldm.trellis2.vae import SparseTensor, SparseLinear, sparse_cat, VarLenTensor from typing import Optional, Tuple, Literal, Union, List -from comfy.ldm.trellis2.attention import sparse_windowed_scaled_dot_product_self_attention, sparse_scaled_dot_product_attention +from comfy.ldm.trellis2.attention import ( + sparse_windowed_scaled_dot_product_self_attention, sparse_scaled_dot_product_attention, scaled_dot_product_attention +) from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder class SparseGELU(nn.GELU): @@ -103,6 +105,18 @@ class SparseRotaryPositionEmbedder(nn.Module): k_embed = k.replace(self._rotary_embedding(k.feats, phases)) return q_embed, k_embed +class RotaryPositionEmbedder(SparseRotaryPositionEmbedder): + def forward(self, indices: torch.Tensor) -> torch.Tensor: + assert indices.shape[-1] == self.dim, f"Last dim of indices must be {self.dim}" + phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1) + if phases.shape[-1] < self.head_dim // 2: + padn = self.head_dim // 2 - phases.shape[-1] + phases = torch.cat([phases, torch.polar( + torch.ones(*phases.shape[:-1], padn, device=phases.device), + torch.zeros(*phases.shape[:-1], padn, device=phases.device) + )], dim=-1) + return phases + class SparseMultiHeadAttention(nn.Module): def __init__( self, @@ -472,6 +486,292 @@ class SLatFlowModel(nn.Module): h = self.out_layer(h) return h +class FeedForwardNet(nn.Module): + def __init__(self, channels: 
int, mlp_ratio: float = 4.0): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(channels, int(channels * mlp_ratio)), + nn.GELU(approximate="tanh"), + nn.Linear(int(channels * mlp_ratio), channels), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.mlp(x) + +class MultiHeadRMSNorm(nn.Module): + def __init__(self, dim: int, heads: int): + super().__init__() + self.scale = dim ** 0.5 + self.gamma = nn.Parameter(torch.ones(heads, dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return (F.normalize(x.float(), dim = -1) * self.gamma * self.scale).to(x.dtype) + + +class MultiHeadAttention(nn.Module): + def __init__( + self, + channels: int, + num_heads: int, + ctx_channels: Optional[int]=None, + type: Literal["self", "cross"] = "self", + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + qkv_bias: bool = True, + use_rope: bool = False, + rope_freq: Tuple[float, float] = (1.0, 10000.0), + qk_rms_norm: bool = False, + ): + super().__init__() + + self.channels = channels + self.head_dim = channels // num_heads + self.ctx_channels = ctx_channels if ctx_channels is not None else channels + self.num_heads = num_heads + self._type = type + self.attn_mode = attn_mode + self.window_size = window_size + self.shift_window = shift_window + self.use_rope = use_rope + self.qk_rms_norm = qk_rms_norm + + if self._type == "self": + self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias) + else: + self.to_q = nn.Linear(channels, channels, bias=qkv_bias) + self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias) + + if self.qk_rms_norm: + self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads) + self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads) + + self.to_out = nn.Linear(channels, channels) + + def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, phases: Optional[torch.Tensor] = None) -> 
torch.Tensor: + B, L, C = x.shape + if self._type == "self": + qkv = self.to_qkv(x) + qkv = qkv.reshape(B, L, 3, self.num_heads, -1) + + if self.attn_mode == "full": + if self.qk_rms_norm or self.use_rope: + q, k, v = qkv.unbind(dim=2) + if self.qk_rms_norm: + q = self.q_rms_norm(q) + k = self.k_rms_norm(k) + if self.use_rope: + assert phases is not None, "Phases must be provided for RoPE" + q = RotaryPositionEmbedder.apply_rotary_embedding(q, phases) + k = RotaryPositionEmbedder.apply_rotary_embedding(k, phases) + h = scaled_dot_product_attention(q, k, v) + else: + h = scaled_dot_product_attention(qkv) + else: + Lkv = context.shape[1] + q = self.to_q(x) + kv = self.to_kv(context) + q = q.reshape(B, L, self.num_heads, -1) + kv = kv.reshape(B, Lkv, 2, self.num_heads, -1) + if self.qk_rms_norm: + q = self.q_rms_norm(q) + k, v = kv.unbind(dim=2) + k = self.k_rms_norm(k) + h = scaled_dot_product_attention(q, k, v) + else: + h = scaled_dot_product_attention(q, kv) + h = h.reshape(B, L, -1) + h = self.to_out(h) + return h + +class ModulatedTransformerCrossBlock(nn.Module): + def __init__( + self, + channels: int, + ctx_channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + rope_freq: Tuple[int, int] = (1.0, 10000.0), + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + qkv_bias: bool = True, + share_mod: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.self_attn = MultiHeadAttention( + channels, + num_heads=num_heads, + type="self", + attn_mode=attn_mode, + 
window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + rope_freq=rope_freq, + qk_rms_norm=qk_rms_norm, + ) + self.cross_attn = MultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + self.mlp = FeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(channels, 6 * channels, bias=True) + ) + else: + self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5) + + def _forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor: + if self.share_mod: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) + h = self.norm1(x) + h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1) + h = self.self_attn(h, phases=phases) + h = h * gate_msa.unsqueeze(1) + x = x + h + h = self.norm2(x) + h = self.cross_attn(h, context) + x = x + h + h = self.norm3(x) + h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1) + h = self.mlp(h) + h = h * gate_mlp.unsqueeze(1) + x = x + h + return x + + def forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, phases, use_reentrant=False) + else: + return self._forward(x, mod, context, phases) + + +class SparseStructureFlowModel(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + model_channels: int, + cond_channels: int, + out_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: 
Optional[int] = 64, + mlp_ratio: float = 4, + pe_mode: Literal["ape", "rope"] = "ape", + rope_freq: Tuple[float, float] = (1.0, 10000.0), + dtype: str = 'float32', + use_checkpoint: bool = False, + share_mod: bool = False, + initialization: str = 'vanilla', + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + **kwargs + ): + super().__init__() + self.resolution = resolution + self.in_channels = in_channels + self.model_channels = model_channels + self.cond_channels = cond_channels + self.out_channels = out_channels + self.num_blocks = num_blocks + self.num_heads = num_heads or model_channels // num_head_channels + self.mlp_ratio = mlp_ratio + self.pe_mode = pe_mode + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.initialization = initialization + self.qk_rms_norm = qk_rms_norm + self.qk_rms_norm_cross = qk_rms_norm_cross + self.dtype = dtype + + self.t_embedder = TimestepEmbedder(model_channels) + if share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(model_channels, 6 * model_channels, bias=True) + ) + + pos_embedder = RotaryPositionEmbedder(self.model_channels // self.num_heads, 3) + coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution] * 3], indexing='ij') + coords = torch.stack(coords, dim=-1).reshape(-1, 3) + rope_phases = pos_embedder(coords) + self.register_buffer("rope_phases", rope_phases) + + if pe_mode != "rope": + self.rope_phases = None + + self.input_layer = nn.Linear(in_channels, model_channels) + + self.blocks = nn.ModuleList([ + ModulatedTransformerCrossBlock( + model_channels, + cond_channels, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + attn_mode='full', + use_checkpoint=self.use_checkpoint, + use_rope=(pe_mode == "rope"), + rope_freq=rope_freq, + share_mod=share_mod, + qk_rms_norm=self.qk_rms_norm, + qk_rms_norm_cross=self.qk_rms_norm_cross, + ) + for _ in range(num_blocks) + ]) + + self.out_layer = nn.Linear(model_channels, 
out_channels) + + self.initialize_weights() + self.convert_to(self.dtype) + + def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor: + assert [*x.shape] == [x.shape[0], self.in_channels, *[self.resolution] * 3], \ + f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}" + + h = x.view(*x.shape[:2], -1).permute(0, 2, 1).contiguous() + + h = self.input_layer(h) + if self.pe_mode == "ape": + h = h + self.pos_emb[None] + t_emb = self.t_embedder(t) + if self.share_mod: + t_emb = self.adaLN_modulation(t_emb) + t_emb = manual_cast(t_emb, self.dtype) + h = manual_cast(h, self.dtype) + cond = manual_cast(cond, self.dtype) + for block in self.blocks: + h = block(h, t_emb, cond, self.rope_phases) + h = manual_cast(h, x.dtype) + h = F.layer_norm(h, h.shape[-1:]) + h = self.out_layer(h) + + h = h.permute(0, 2, 1).view(h.shape[0], h.shape[2], *[self.resolution] * 3).contiguous() + + return h + class Trellis2(nn.Module): def __init__(self, resolution, in_channels = 32, @@ -492,18 +792,24 @@ class Trellis2(nn.Module): "model_channels":model_channels, "num_heads":num_heads, "mlp_ratio": mlp_ratio, "share_mod": share_mod, "qk_rms_norm": qk_rms_norm, "qk_rms_norm_cross": qk_rms_norm_cross, "device": device, "dtype": dtype, "operations": operations } - # TODO: update the names/checkpoints self.img2shape = SLatFlowModel(resolution=resolution, in_channels=in_channels, **args) self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args) - self.shape_generation = True + args.pop("out_channels") + args.pop("in_channels") + self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **args) def forward(self, x, timestep, context, **kwargs): # TODO add mode mode = kwargs.get("mode", "shape_generation") - mode = "texture_generation" if mode == 1 else "shape_generation" + if mode != 0: + mode = "texture_generation" if mode == 2 else 
"shape_generation" + else: + mode = "structure_generation" if mode == "shape_generation": out = self.img2shape(x, timestep, context) - if mode == "texture_generation": + elif mode == "texture_generation": out = self.shape2txt(x, timestep, context) + else: + out = self.structure_model(x, timestep, context) return out diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 6c2725b9f..9e2f17149 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1247,6 +1247,10 @@ class Trellis2(supported_models_base.BASE): "image_model": "trellis2" } + sampling_settings = { + "shift": 3.0, + } + latent_format = latent_formats.Trellis2 vae_key_prefix = ["vae."] From 3002708fe390b1dd8b75a725cb5ae14ce8ceb7a3 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Wed, 4 Feb 2026 14:15:00 +0200 Subject: [PATCH 07/59] needed updates --- comfy/ldm/trellis2/model.py | 29 +++++-- comfy/ldm/trellis2/vae.py | 148 +++++++++++++++++++++++++++++++++ comfy/model_base.py | 5 +- comfy_extras/nodes_trellis2.py | 46 +++++++--- 4 files changed, 209 insertions(+), 19 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index add3f21ce..1dbbc4955 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -7,6 +7,7 @@ from comfy.ldm.trellis2.attention import ( sparse_windowed_scaled_dot_product_self_attention, sparse_scaled_dot_product_attention, scaled_dot_product_attention ) from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder +from comfy.nested_tensor import NestedTensor class SparseGELU(nn.GELU): def forward(self, input: VarLenTensor) -> VarLenTensor: @@ -772,6 +773,11 @@ class SparseStructureFlowModel(nn.Module): return h +def timestep_reshift(t_shifted, old_shift=3.0, new_shift=5.0): + t_linear = t_shifted / (old_shift - t_shifted * (old_shift - 1)) + t_new = (new_shift * t_linear) / (1 + (new_shift - 1) * t_linear) + return t_new + class 
Trellis2(nn.Module): def __init__(self, resolution, in_channels = 32, @@ -798,18 +804,25 @@ class Trellis2(nn.Module): args.pop("in_channels") self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **args) - def forward(self, x, timestep, context, **kwargs): - # TODO add mode - mode = kwargs.get("mode", "shape_generation") - if mode != 0: - mode = "texture_generation" if mode == 2 else "shape_generation" - else: + def forward(self, x: NestedTensor, timestep, context, **kwargs): + x = x.tensors[0] + embeds = kwargs.get("embeds") + if not hasattr(x, "feats"): mode = "structure_generation" + else: + if x.feats.shape[1] == 32: + mode = "shape_generation" + else: + mode = "texture_generation" if mode == "shape_generation": - out = self.img2shape(x, timestep, context) + # TODO + out = self.img2shape(x, timestep, torch.cat([embeds, torch.empty_like(embeds)])) elif mode == "texture_generation": out = self.shape2txt(x, timestep, context) - else: + else: # structure + timestep = timestep_reshift(timestep) out = self.structure_model(x, timestep, context) + out = NestedTensor([out]) + out.generation_mode = mode return out diff --git a/comfy/ldm/trellis2/vae.py b/comfy/ldm/trellis2/vae.py index 5dabf5246..584fa91ae 100644 --- a/comfy/ldm/trellis2/vae.py +++ b/comfy/ldm/trellis2/vae.py @@ -9,6 +9,17 @@ import numpy as np from cumesh import TorchHashMap, Mesh, MeshWithVoxel, sparse_submanifold_conv3d +def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor: + """ + 3D pixel shuffle. 
+ """ + B, C, H, W, D = x.shape + C_ = C // scale_factor**3 + x = x.reshape(B, C_, scale_factor, scale_factor, scale_factor, H, W, D) + x = x.permute(0, 1, 5, 2, 6, 3, 7, 4) + x = x.reshape(B, C_, H*scale_factor, W*scale_factor, D*scale_factor) + return x + class SparseConv3d(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None): super(SparseConv3d, self).__init__() @@ -1337,6 +1348,135 @@ def flexible_dual_grid_to_mesh( return mesh_vertices, mesh_triangles +class ChannelLayerNorm32(LayerNorm32): + def forward(self, x: torch.Tensor) -> torch.Tensor: + DIM = x.dim() + x = x.permute(0, *range(2, DIM), 1).contiguous() + x = super().forward(x) + x = x.permute(0, DIM-1, *range(1, DIM-1)).contiguous() + return x + +class UpsampleBlock3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + mode = "conv", + ): + assert mode in ["conv", "nearest"], f"Invalid mode {mode}" + + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + if mode == "conv": + self.conv = nn.Conv3d(in_channels, out_channels*8, 3, padding=1) + elif mode == "nearest": + assert in_channels == out_channels, "Nearest mode requires in_channels to be equal to out_channels" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if hasattr(self, "conv"): + x = self.conv(x) + return pixel_shuffle_3d(x, 2) + else: + return F.interpolate(x, scale_factor=2, mode="nearest") + +def norm_layer(norm_type: str, *args, **kwargs) -> nn.Module: + return ChannelLayerNorm32(*args, **kwargs) + +class ResBlock3d(nn.Module): + def __init__( + self, + channels: int, + out_channels: Optional[int] = None, + norm_type = "layer", + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.norm1 = norm_layer(norm_type, channels) + self.norm2 = norm_layer(norm_type, self.out_channels) + self.conv1 = nn.Conv3d(channels, self.out_channels, 3, 
padding=1) + self.conv2 = nn.Conv3d(self.out_channels, self.out_channels, 3, padding=1) + self.skip_connection = nn.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = self.norm1(x) + h = F.silu(h) + h = self.conv1(h) + h = self.norm2(h) + h = F.silu(h) + h = self.conv2(h) + h = h + self.skip_connection(x) + return h + + +class SparseStructureDecoder(nn.Module): + def __init__( + self, + out_channels: int, + latent_channels: int, + num_res_blocks: int, + channels: List[int], + num_res_blocks_middle: int = 2, + norm_type = "layer", + use_fp16: bool = False, + ): + super().__init__() + self.out_channels = out_channels + self.latent_channels = latent_channels + self.num_res_blocks = num_res_blocks + self.channels = channels + self.num_res_blocks_middle = num_res_blocks_middle + self.norm_type = norm_type + self.use_fp16 = use_fp16 + self.dtype = torch.float16 if use_fp16 else torch.float32 + + self.input_layer = nn.Conv3d(latent_channels, channels[0], 3, padding=1) + + self.middle_block = nn.Sequential(*[ + ResBlock3d(channels[0], channels[0]) + for _ in range(num_res_blocks_middle) + ]) + + self.blocks = nn.ModuleList([]) + for i, ch in enumerate(channels): + self.blocks.extend([ + ResBlock3d(ch, ch) + for _ in range(num_res_blocks) + ]) + if i < len(channels) - 1: + self.blocks.append( + UpsampleBlock3d(ch, channels[i+1]) + ) + + self.out_layer = nn.Sequential( + norm_layer(norm_type, channels[-1]), + nn.SiLU(), + nn.Conv3d(channels[-1], out_channels, 3, padding=1) + ) + + if use_fp16: + self.convert_to_fp16() + + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = self.input_layer(x) + + h = h.type(self.dtype) + + h = self.middle_block(h) + for block in self.blocks: + h = block(h) + + h = h.type(x.dtype) + h = self.out_layer(h) + return h + class Vae(nn.Module): def 
__init__(self, config, operations=None): super().__init__() @@ -1363,6 +1503,14 @@ class Vae(nn.Module): block_args=[{}, {}, {}, {}, {}], ) + self.struct_dec = SparseStructureDecoder( + out_channels=1, + latent_channels=8, + num_res_blocks=2, + num_res_blocks_middle=2, + channels=[512, 128, 32], + ) + def decode_shape_slat(self, slat, resolution: int): self.shape_dec.set_resolution(resolution) return self.shape_dec(slat, return_subs=True) diff --git a/comfy/model_base.py b/comfy/model_base.py index a5fc81c4d..6bf4fadc9 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1461,7 +1461,10 @@ class Trellis2(BaseModel): super().__init__(model_config, model_type, device, unet_model) def extra_conds(self, **kwargs): - return super().extra_conds(**kwargs) + out = super().extra_conds(**kwargs) + embeds = kwargs.get("embeds") + out["embeds"] = comfy.conds.CONDRegular(embeds) + return out class Hunyuan3Dv2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 70bbbb29d..4a36e2fee 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -6,6 +6,7 @@ import comfy.model_management from PIL import Image import PIL import numpy as np +from comfy.nested_tensor import NestedTensor shape_slat_normalization = { "mean": torch.tensor([ @@ -131,7 +132,8 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): ) @classmethod - def execute(cls, samples, vae, resolution): + def execute(cls, samples: NestedTensor, vae, resolution): + samples = samples.tensors[0] std = shape_slat_normalization["std"] mean = shape_slat_normalization["mean"] samples = samples * std + mean @@ -157,9 +159,7 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): @classmethod def execute(cls, samples, vae, shape_subs): - if shape_subs is None: - raise ValueError("Shape subs must be provided for texture generation") - + samples = samples.tensors[0] std = tex_slat_normalization["std"] 
mean = tex_slat_normalization["mean"] samples = samples * std + mean @@ -167,6 +167,28 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): mesh = vae.decode_tex_slat(samples, shape_subs) return mesh +class VaeDecodeStructureTrellis2(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="VaeDecodeStructureTrellis2", + category="latent/3d", + inputs=[ + IO.Latent.Input("samples"), + IO.Vae.Input("vae"), + ], + outputs=[ + IO.Mesh.Output("structure_output"), + ] + ) + + @classmethod + def execute(cls, samples, vae): + decoder = vae.struct_dec + decoded = decoder(samples)>0 + coords = torch.argwhere(decoded)[:, [0, 2, 3, 4]].int() + return coords + class Trellis2Conditioning(IO.ComfyNode): @classmethod def define_schema(cls): @@ -189,8 +211,8 @@ class Trellis2Conditioning(IO.ComfyNode): # could make 1024 an option conditioning, _ = run_conditioning(clip_vision_model, image, include_1024=True, background_color=background_color) embeds = conditioning["cond_1024"] # should add that - positive = [[conditioning["cond_512"], {embeds}]] - negative = [[conditioning["cond_neg"], {embeds}]] + positive = [[conditioning["cond_512"], {"embeds": embeds}]] + negative = [[conditioning["cond_neg"], {"embeds": embeds}]] return IO.NodeOutput(positive, negative) class EmptyShapeLatentTrellis2(IO.ComfyNode): @@ -200,7 +222,7 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): node_id="EmptyLatentTrellis2", category="latent/3d", inputs=[ - IO.Latent.Input("structure_output"), + IO.Mesh.Input("structure_output"), ], outputs=[ IO.Latent.Output(), @@ -210,9 +232,10 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): @classmethod def execute(cls, structure_output): # i will see what i have to do here - coords = structure_output or structure_output.coords + coords = structure_output # or structure_output.coords in_channels = 32 latent = SparseTensor(feats=torch.randn(coords.shape[0], in_channels), coords=coords) + latent = NestedTensor([latent]) return 
IO.NodeOutput({"samples": latent, "type": "trellis2"}) class EmptyTextureLatentTrellis2(IO.ComfyNode): @@ -222,7 +245,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): node_id="EmptyLatentTrellis2", category="latent/3d", inputs=[ - IO.Latent.Input("structure_output"), + IO.Mesh.Input("structure_output"), ], outputs=[ IO.Latent.Output(), @@ -234,6 +257,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): # TODO in_channels = 32 latent = structure_output.replace(feats=torch.randn(structure_output.coords.shape[0], in_channels - structure_output.feats.shape[1])) + latent = NestedTensor([latent]) return IO.NodeOutput({"samples": latent, "type": "trellis2"}) class EmptyStructureLatentTrellis2(IO.ComfyNode): @@ -254,6 +278,7 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode): def execute(cls, res, batch_size): in_channels = 32 latent = torch.randn(batch_size, in_channels, res, res, res) + latent = NestedTensor([latent]) return IO.NodeOutput({"samples": latent, "type": "trellis2"}) @@ -266,7 +291,8 @@ class Trellis2Extension(ComfyExtension): EmptyStructureLatentTrellis2, EmptyTextureLatentTrellis2, VaeDecodeTextureTrellis, - VaeDecodeShapeTrellis + VaeDecodeShapeTrellis, + VaeDecodeStructureTrellis2 ] From cd0f7ba64e6d0a195fdc53132cc58b053569dacb Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Thu, 5 Feb 2026 02:34:08 +0200 Subject: [PATCH 08/59] apply rope and optimized attention --- comfy/ldm/trellis2/attention.py | 29 +++++++------ comfy/ldm/trellis2/model.py | 76 ++++++++++++++++----------------- 2 files changed, 52 insertions(+), 53 deletions(-) diff --git a/comfy/ldm/trellis2/attention.py b/comfy/ldm/trellis2/attention.py index 6c912c8d9..edc85ce83 100644 --- a/comfy/ldm/trellis2/attention.py +++ b/comfy/ldm/trellis2/attention.py @@ -26,21 +26,21 @@ def scaled_dot_product_attention(*args, **kwargs): k = args[1] if len(args) > 1 else kwargs['k'] v = args[2] if len(args) > 2 else kwargs['v'] + # TODO verify + heads 
= q or qkv + heads = heads.shape[2] + if optimized_attention.__name__ == 'attention_xformers': - if 'xops' not in globals(): - import xformers.ops as xops if num_all_args == 1: q, k, v = qkv.unbind(dim=2) elif num_all_args == 2: k, v = kv.unbind(dim=2) - out = xops.memory_efficient_attention(q, k, v) + #out = xops.memory_efficient_attention(q, k, v) + out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True) elif optimized_attention.__name__ == 'attention_flash' and not FLASH_ATTN_3_AVA: - if 'flash_attn' not in globals(): - import flash_attn if num_all_args == 2: - out = flash_attn.flash_attn_kvpacked_func(q, kv) - elif num_all_args == 3: - out = flash_attn.flash_attn_func(q, k, v) + k, v = kv.unbind(dim=2) + out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True) elif optimized_attention.__name__ == 'attention_flash': # TODO if 'flash_attn_3' not in globals(): import flash_attn_interface as flash_attn_3 @@ -59,15 +59,14 @@ def scaled_dot_product_attention(*args, **kwargs): q = q.permute(0, 2, 1, 3) # [N, H, L, C] k = k.permute(0, 2, 1, 3) # [N, H, L, C] v = v.permute(0, 2, 1, 3) # [N, H, L, C] - out = sdpa(q, k, v) # [N, H, L, C] + out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True) out = out.permute(0, 2, 1, 3) # [N, L, H, C] elif optimized_attention.__name__ == 'attention_basic': if num_all_args == 1: q, k, v = qkv.unbind(dim=2) elif num_all_args == 2: k, v = kv.unbind(dim=2) - q = q.shape[2] # TODO - out = optimized_attention(q, k, v) + out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True) return out @@ -86,19 +85,21 @@ def sparse_windowed_scaled_dot_product_self_attention( fwd_indices, bwd_indices, seq_lens, attn_func_args = serialization_spatial_cache qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C] + heads = qkv_feats.shape[2] if optimized_attention.__name__ == 'attention_xformers': - if 'xops' not in globals(): - import 
xformers.ops as xops q, k, v = qkv_feats.unbind(dim=1) q = q.unsqueeze(0) # [1, M, H, C] k = k.unsqueeze(0) # [1, M, H, C] v = v.unsqueeze(0) # [1, M, H, C] - out = xops.memory_efficient_attention(q, k, v, **attn_func_args)[0] # [M, H, C] + #out = xops.memory_efficient_attention(q, k, v, **attn_func_args)[0] # [M, H, C] + out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True) elif optimized_attention.__name__ == 'attention_flash': if 'flash_attn' not in globals(): import flash_attn out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, **attn_func_args) # [M, H, C] + else: + out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True) out = out[bwd_indices] # [T, H, C] diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 1dbbc4955..484622d76 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -8,6 +8,7 @@ from comfy.ldm.trellis2.attention import ( ) from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder from comfy.nested_tensor import NestedTensor +from comfy.ldm.flux.math import apply_rope, apply_rope1 class SparseGELU(nn.GELU): def forward(self, input: VarLenTensor) -> VarLenTensor: @@ -52,7 +53,6 @@ class SparseMultiHeadRMSNorm(nn.Module): x = F.normalize(x, dim=-1) * self.gamma * self.scale return x.to(x_type) -# TODO: replace with apply_rope1 class SparseRotaryPositionEmbedder(nn.Module): def __init__( self, @@ -61,7 +61,6 @@ class SparseRotaryPositionEmbedder(nn.Module): rope_freq: Tuple[float, float] = (1.0, 10000.0) ): super().__init__() - assert head_dim % 2 == 0, "Head dim must be divisible by 2" self.head_dim = head_dim self.dim = dim self.rope_freq = rope_freq @@ -69,46 +68,48 @@ class SparseRotaryPositionEmbedder(nn.Module): self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim self.freqs = rope_freq[0] / (rope_freq[1] ** (self.freqs)) - def _get_phases(self, indices: torch.Tensor) -> torch.Tensor: - 
self.freqs = self.freqs.to(indices.device) - phases = torch.outer(indices, self.freqs) - phases = torch.polar(torch.ones_like(phases), phases) - return phases + def _get_freqs_cis(self, coords: torch.Tensor) -> torch.Tensor: + phases_list = [] + for i in range(self.dim): + phases_list.append(torch.outer(coords[..., i], self.freqs.to(coords.device))) - def _rotary_embedding(self, x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor: - x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) - x_rotated = x_complex * phases.unsqueeze(-2) - x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype) - return x_embed + phases = torch.cat(phases_list, dim=-1) + + if phases.shape[-1] < self.head_dim // 2: + padn = self.head_dim // 2 - phases.shape[-1] + phases = torch.cat([phases, torch.zeros(*phases.shape[:-1], padn, device=phases.device)], dim=-1) + + cos = torch.cos(phases) + sin = torch.sin(phases) + + f_cis_0 = torch.stack([cos, sin], dim=-1) + f_cis_1 = torch.stack([-sin, cos], dim=-1) + freqs_cis = torch.stack([f_cis_0, f_cis_1], dim=-1) + + return freqs_cis + + def forward(self, q, k=None): + cache_name = f'rope_cis_{self.dim}d_f{self.rope_freq[1]}_hd{self.head_dim}' + freqs_cis = q.get_spatial_cache(cache_name) + + if freqs_cis is None: + coords = q.coords[..., 1:].to(torch.float32) + freqs_cis = self._get_freqs_cis(coords) + q.register_spatial_cache(cache_name, freqs_cis) + + if q.feats.ndim == 3: + f_cis = freqs_cis.unsqueeze(1) + else: + f_cis = freqs_cis - def forward(self, q: SparseTensor, k: Optional[SparseTensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - q (SparseTensor): [..., N, H, D] tensor of queries - k (SparseTensor): [..., N, H, D] tensor of keys - """ - assert q.coords.shape[-1] == self.dim + 1, "Last dimension of coords must be equal to dim+1" - phases_cache_name = f'rope_phase_{self.dim}d_freq{self.rope_freq[0]}-{self.rope_freq[1]}_hd{self.head_dim}' - phases = 
q.get_spatial_cache(phases_cache_name) - if phases is None: - coords = q.coords[..., 1:] - phases = self._get_phases(coords.reshape(-1)).reshape(*coords.shape[:-1], -1) - if phases.shape[-1] < self.head_dim // 2: - padn = self.head_dim // 2 - phases.shape[-1] - phases = torch.cat([phases, torch.polar( - torch.ones(*phases.shape[:-1], padn, device=phases.device), - torch.zeros(*phases.shape[:-1], padn, device=phases.device) - )], dim=-1) - q.register_spatial_cache(phases_cache_name, phases) - q_embed = q.replace(self._rotary_embedding(q.feats, phases)) if k is None: - return q_embed - k_embed = k.replace(self._rotary_embedding(k.feats, phases)) - return q_embed, k_embed + return q.replace(apply_rope1(q.feats, f_cis)) + + q_feats, k_feats = apply_rope(q.feats, k.feats, f_cis) + return q.replace(q_feats), k.replace(k_feats) class RotaryPositionEmbedder(SparseRotaryPositionEmbedder): def forward(self, indices: torch.Tensor) -> torch.Tensor: - assert indices.shape[-1] == self.dim, f"Last dim of indices must be {self.dim}" phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1) if phases.shape[-1] < self.head_dim // 2: padn = self.head_dim // 2 - phases.shape[-1] @@ -228,9 +229,6 @@ class SparseMultiHeadAttention(nn.Module): return h class ModulatedSparseTransformerBlock(nn.Module): - """ - Sparse Transformer block (MSA + FFN) with adaptive layer norm conditioning. - """ def __init__( self, channels: int, From f2c0320fe84e533ad1e0173e0eb25c934027b216 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Thu, 5 Feb 2026 17:19:57 +0200 Subject: [PATCH 09/59] fixes to vae and cumesh impl. 
--- comfy/ldm/trellis2/cumesh.py | 189 ++++++++++++++++++++++++++--------- comfy/ldm/trellis2/vae.py | 2 +- 2 files changed, 143 insertions(+), 48 deletions(-) diff --git a/comfy/ldm/trellis2/cumesh.py b/comfy/ldm/trellis2/cumesh.py index 41ac35db9..fe7e80e15 100644 --- a/comfy/ldm/trellis2/cumesh.py +++ b/comfy/ldm/trellis2/cumesh.py @@ -5,6 +5,10 @@ import torch from typing import Dict, Callable NO_TRITION = False +try: + allow_tf32 = torch.cuda.is_tf32_supported +except Exception: + allow_tf32 = False try: import triton import triton.language as tl @@ -102,10 +106,13 @@ try: grid = lambda META: (triton.cdiv(Co, META['B2']) * triton.cdiv(N, META['B1']),) sparse_submanifold_conv_fwd_masked_implicit_gemm_kernel[grid]( input, weight, bias, neighbor, sorted_idx, output, - N, LOGN, Ci, Co, V, # + N, LOGN, Ci, Co, V, + B1=128, + B2=64, + BK=32, valid_kernel=valid_kernel, valid_kernel_seg=valid_kernel_seg, - allow_tf32=torch.cuda.is_tf32_supported(), + allow_tf32=allow_tf32, ) return output except: @@ -140,16 +147,16 @@ def build_submanifold_neighbor_map( neighbor = torch.full((M, V), INVALID, device=device, dtype=torch.long) - b = coords[:, 0] - x = coords[:, 1] - y = coords[:, 2] - z = coords[:, 3] + b = coords[:, 0].long() + x = coords[:, 1].long() + y = coords[:, 2].long() + z = coords[:, 3].long() offsets = compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device) - ox = x[:, None] - (Kw // 2) * Dw - oy = y[:, None] - (Kh // 2) * Dh - oz = z[:, None] - (Kd // 2) * Dd + ox = x - (Kw // 2) * Dw + oy = y - (Kh // 2) * Dh + oz = z - (Kd // 2) * Dd for v in range(half_V): if v == half_V - 1: @@ -158,10 +165,11 @@ def build_submanifold_neighbor_map( dx, dy, dz = offsets[v] - kx = ox[:, v] + dx - ky = oy[:, v] + dy - kz = oz[:, v] + dz + kx = ox + dx + ky = oy + dy + kz = oz + dz + # Check spatial bounds valid = ( (kx >= 0) & (kx < W) & (ky >= 0) & (ky < H) & @@ -169,22 +177,22 @@ def build_submanifold_neighbor_map( ) flat = ( - b * (W * H * D) + - kx * (H * D) + - ky * D + 
- kz + b[valid] * (W * H * D) + + kx[valid] * (H * D) + + ky[valid] * D + + kz[valid] ) - flat = flat[valid] - idx = torch.nonzero(valid, as_tuple=False).squeeze(1) + if flat.numel() > 0: + found = hashmap.lookup_flat(flat) + idx_in_M = torch.where(valid)[0] + neighbor[idx_in_M, v] = found - found = hashmap.lookup_flat(flat) - - neighbor[idx, v] = found - - # symmetric write - valid_found = found != INVALID - neighbor[found[valid_found], V - 1 - v] = idx[valid_found] + valid_found_mask = (found != INVALID) + if valid_found_mask.any(): + src_points = idx_in_M[valid_found_mask] + dst_points = found[valid_found_mask] + neighbor[dst_points, V - 1 - v] = src_points return neighbor @@ -461,31 +469,118 @@ class Mesh: def cpu(self): return self.to('cpu') - # TODO could be an option + # could make this into a new node def fill_holes(self, max_hole_perimeter=3e-2): - import cumesh - vertices = self.vertices.cuda() - faces = self.faces.cuda() - mesh = cumesh.CuMesh() - mesh.init(vertices, faces) - mesh.get_edges() - mesh.get_boundary_info() - if mesh.num_boundaries == 0: - return - mesh.get_vertex_edge_adjacency() - mesh.get_vertex_boundary_adjacency() - mesh.get_manifold_boundary_adjacency() - mesh.read_manifold_boundary_adjacency() - mesh.get_boundary_connected_components() - mesh.get_boundary_loops() - if mesh.num_boundary_loops == 0: - return - mesh.fill_holes(max_hole_perimeter=max_hole_perimeter) - new_vertices, new_faces = mesh.read() + device = self.vertices.device + vertices = self.vertices + faces = self.faces + + edges = torch.cat([ + faces[:, [0, 1]], + faces[:, [1, 2]], + faces[:, [2, 0]] + ], dim=0) + + edges_sorted, _ = torch.sort(edges, dim=1) + + unique_edges, counts = torch.unique(edges_sorted, dim=0, return_counts=True) + + boundary_mask = counts == 1 + boundary_edges_sorted = unique_edges[boundary_mask] + + if boundary_edges_sorted.shape[0] == 0: + return + max_idx = vertices.shape[0] + + _, inverse_indices, counts_packed = torch.unique( + 
torch.sort(edges, dim=1).values[:, 0] * max_idx + torch.sort(edges, dim=1).values[:, 1], + return_inverse=True, return_counts=True + ) + + boundary_packed_mask = counts_packed == 1 + is_boundary_edge = boundary_packed_mask[inverse_indices] + + active_boundary_edges = edges[is_boundary_edge] + + adj = {} + edges_np = active_boundary_edges.cpu().numpy() + for u, v in edges_np: + adj[u] = v + + loops = [] + visited_edges = set() + + possible_starts = list(adj.keys()) + + processed_nodes = set() + + for start_node in possible_starts: + if start_node in processed_nodes: + continue + + current_loop = [] + curr = start_node + + while curr in adj: + next_node = adj[curr] + if (curr, next_node) in visited_edges: + break + + visited_edges.add((curr, next_node)) + processed_nodes.add(curr) + current_loop.append(curr) + + curr = next_node + + if curr == start_node: + loops.append(current_loop) + break + + if len(current_loop) > len(edges_np): + break + + if not loops: + return + + new_faces = [] + + v_offset = vertices.shape[0] + + valid_new_verts = [] + + for loop_indices in loops: + if len(loop_indices) < 3: + continue + + loop_tensor = torch.tensor(loop_indices, dtype=torch.long, device=device) + loop_verts = vertices[loop_tensor] + + diffs = loop_verts - torch.roll(loop_verts, shifts=-1, dims=0) + perimeter = torch.norm(diffs, dim=1).sum() + + if perimeter > max_hole_perimeter: + continue + + center = loop_verts.mean(dim=0) + valid_new_verts.append(center) + + c_idx = v_offset + v_offset += 1 + + num_v = len(loop_indices) + for i in range(num_v): + v_curr = loop_indices[i] + v_next = loop_indices[(i + 1) % num_v] + new_faces.append([v_curr, v_next, c_idx]) + + if len(valid_new_verts) > 0: + added_vertices = torch.stack(valid_new_verts, dim=0) + added_faces = torch.tensor(new_faces, dtype=torch.long, device=device) + + self.vertices = torch.cat([self.vertices, added_vertices], dim=0) + self.faces = torch.cat([self.faces, added_faces], dim=0) - self.vertices = 
new_vertices.to(self.device) - self.faces = new_faces.to(self.device) # TODO could be an option def simplify(self, target=1000000, verbose: bool=False, options: dict={}): diff --git a/comfy/ldm/trellis2/vae.py b/comfy/ldm/trellis2/vae.py index 584fa91ae..2bbfa938c 100644 --- a/comfy/ldm/trellis2/vae.py +++ b/comfy/ldm/trellis2/vae.py @@ -208,7 +208,7 @@ class SparseResBlockC2S3d(nn.Module): self.to_subdiv = SparseLinear(channels, 8) self.updown = SparseChannel2Spatial(2) - def _forward(self, x, subdiv = None): + def forward(self, x, subdiv = None): if self.pred_subdiv: subdiv = self.to_subdiv(x) h = x.replace(self.norm1(x.feats)) From cdd7ced1e8c2353538506e66d0cd306c9b462c3a Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 6 Feb 2026 19:28:49 +0200 Subject: [PATCH 10/59] model fixes --- comfy/ldm/trellis2/model.py | 34 +++++++++++++++++++++------------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 484622d76..2367fc42c 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -26,10 +26,9 @@ class SparseFeedForwardNet(nn.Module): def forward(self, x: VarLenTensor) -> VarLenTensor: return self.mlp(x) -def manual_cast(tensor, dtype): - if not torch.is_autocast_enabled(): - return tensor.type(dtype) - return tensor +def manual_cast(obj, dtype): + return obj.to(dtype=dtype) + class LayerNorm32(nn.LayerNorm): def forward(self, x: torch.Tensor) -> torch.Tensor: x_dtype = x.dtype @@ -88,6 +87,12 @@ class SparseRotaryPositionEmbedder(nn.Module): return freqs_cis + def _get_phases(self, indices: torch.Tensor) -> torch.Tensor: + self.freqs = self.freqs.to(indices.device) + phases = torch.outer(indices, self.freqs) + phases = torch.polar(torch.ones_like(phases), phases) + return phases + def forward(self, q, k=None): cache_name = f'rope_cis_{self.dim}d_f{self.rope_freq[1]}_hd{self.head_dim}' freqs_cis = 
q.get_spatial_cache(cache_name) @@ -111,11 +116,15 @@ class SparseRotaryPositionEmbedder(nn.Module): class RotaryPositionEmbedder(SparseRotaryPositionEmbedder): def forward(self, indices: torch.Tensor) -> torch.Tensor: phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1) + if torch.is_complex(phases): + phases = phases.to(torch.complex64) + else: + phases = phases.to(torch.float32) if phases.shape[-1] < self.head_dim // 2: padn = self.head_dim // 2 - phases.shape[-1] phases = torch.cat([phases, torch.polar( - torch.ones(*phases.shape[:-1], padn, device=phases.device), - torch.zeros(*phases.shape[:-1], padn, device=phases.device) + torch.ones(*phases.shape[:-1], padn, device=phases.device, dtype=torch.float32), + torch.zeros(*phases.shape[:-1], padn, device=phases.device, dtype=torch.float32) )], dim=-1) return phases @@ -468,7 +477,7 @@ class SLatFlowModel(nn.Module): h = self.input_layer(x) h = manual_cast(h, self.dtype) - t_emb = self.t_embedder(t) + t_emb = self.t_embedder(t, out_dtype = t.dtype) if self.share_mod: t_emb = self.adaLN_modulation(t_emb) t_emb = manual_cast(t_emb, self.dtype) @@ -687,9 +696,12 @@ class SparseStructureFlowModel(nn.Module): initialization: str = 'vanilla', qk_rms_norm: bool = False, qk_rms_norm_cross: bool = False, + operations=None, + device = None, **kwargs ): super().__init__() + self.device = device self.resolution = resolution self.in_channels = in_channels self.model_channels = model_channels @@ -706,7 +718,7 @@ class SparseStructureFlowModel(nn.Module): self.qk_rms_norm_cross = qk_rms_norm_cross self.dtype = dtype - self.t_embedder = TimestepEmbedder(model_channels) + self.t_embedder = TimestepEmbedder(model_channels, operations=operations) if share_mod: self.adaLN_modulation = nn.Sequential( nn.SiLU(), @@ -743,9 +755,6 @@ class SparseStructureFlowModel(nn.Module): self.out_layer = nn.Linear(model_channels, out_channels) - self.initialize_weights() - self.convert_to(self.dtype) - def forward(self, x: 
torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor: assert [*x.shape] == [x.shape[0], self.in_channels, *[self.resolution] * 3], \ f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}" @@ -755,7 +764,7 @@ class SparseStructureFlowModel(nn.Module): h = self.input_layer(h) if self.pe_mode == "ape": h = h + self.pos_emb[None] - t_emb = self.t_embedder(t) + t_emb = self.t_embedder(t, out_dtype = t.dtype) if self.share_mod: t_emb = self.adaLN_modulation(t_emb) t_emb = manual_cast(t_emb, self.dtype) @@ -799,7 +808,6 @@ class Trellis2(nn.Module): self.img2shape = SLatFlowModel(resolution=resolution, in_channels=in_channels, **args) self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args) args.pop("out_channels") - args.pop("in_channels") self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **args) def forward(self, x: NestedTensor, timestep, context, **kwargs): From 64a52f5585d46a3e804567640da2a9627f48257e Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 6 Feb 2026 20:35:33 +0200 Subject: [PATCH 11/59] checkpoint --- comfy/ldm/trellis2/model.py | 4 +++- comfy/ldm/trellis2/vae.py | 23 ++++++++++++----------- comfy/model_detection.py | 14 ++++++++++++++ comfy/sd.py | 7 +++++++ 4 files changed, 36 insertions(+), 12 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 2367fc42c..8ca112b13 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -797,6 +797,7 @@ class Trellis2(nn.Module): share_mod = True, qk_rms_norm = True, qk_rms_norm_cross = True, + init_txt_model=False, # for now dtype=None, device=None, operations=None): super().__init__() @@ -806,7 +807,8 @@ class Trellis2(nn.Module): "qk_rms_norm": qk_rms_norm, "qk_rms_norm_cross": qk_rms_norm_cross, "device": device, "dtype": dtype, "operations": operations } 
self.img2shape = SLatFlowModel(resolution=resolution, in_channels=in_channels, **args) - self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args) + if init_txt_model: + self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args) args.pop("out_channels") self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **args) diff --git a/comfy/ldm/trellis2/vae.py b/comfy/ldm/trellis2/vae.py index 2bbfa938c..d997bbc41 100644 --- a/comfy/ldm/trellis2/vae.py +++ b/comfy/ldm/trellis2/vae.py @@ -1481,17 +1481,18 @@ class Vae(nn.Module): def __init__(self, config, operations=None): super().__init__() operations = operations or torch.nn - - self.txt_dec = SparseUnetVaeDecoder( - out_channels=6, - model_channels=[1024, 512, 256, 128, 64], - latent_channels=32, - num_blocks=[4, 16, 8, 4, 0], - block_type=["SparseConvNeXtBlock3d"] * 5, - up_block_type=["SparseResBlockC2S3d"] * 4, - block_args=[{}, {}, {}, {}, {}], - pred_subdiv=False - ) + init_txt_model = config.get("init_txt_model", False) + if init_txt_model: + self.txt_dec = SparseUnetVaeDecoder( + out_channels=6, + model_channels=[1024, 512, 256, 128, 64], + latent_channels=32, + num_blocks=[4, 16, 8, 4, 0], + block_type=["SparseConvNeXtBlock3d"] * 5, + up_block_type=["SparseResBlockC2S3d"] * 4, + block_args=[{}, {}, {}, {}, {}], + pred_subdiv=False + ) self.shape_dec = FlexiDualGridVaeDecoder( resolution=256, diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 8cea16e50..4f5542af5 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -106,6 +106,20 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): unet_config['block_repeat'] = [[1, 1, 1, 1], [2, 2, 2, 2]] return unet_config + if '{}img2shape.blocks.1.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys: + unet_config = {} + unet_config["image_model"] = "trellis2" + if 
'{}model.shape2txt.blocks.29.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys: + unet_config["init_txt_model"] = True + else: + unet_config["init_txt_model"] = False + if metadata is not None: + if metadata["is_512"] is True: + unet_config["resolution"] = 32 + else: + unet_config["resolution"] = 64 + return unet_config + if '{}transformer.rotary_pos_emb.inv_freq'.format(key_prefix) in state_dict_keys: #stable audio dit unet_config = {} unet_config["audio_model"] = "dit1.0" diff --git a/comfy/sd.py b/comfy/sd.py index fd0ac85e7..be3d1b4f0 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -14,6 +14,7 @@ import comfy.ldm.genmo.vae.model import comfy.ldm.lightricks.vae.causal_video_autoencoder import comfy.ldm.cosmos.vae import comfy.ldm.wan.vae +import comfy.ldm.trellis2.vae import comfy.ldm.wan.vae2_2 import comfy.ldm.hunyuan3d.vae import comfy.ldm.ace.vae.music_dcae_pipeline @@ -492,6 +493,12 @@ class VAE: self.first_stage_model = StageC_coder() self.downscale_ratio = 32 self.latent_channels = 16 + elif "shape_dec.blocks.1.16.to_subdiv.weight" in sd: # trellis2 + if "txt_dec.blocks.1.16.norm1.weight" in sd: + config["init_txt_model"] = True + else: + config["init_txt_model"] = False + self.first_stage_model = comfy.ldm.trellis2.vae.Vae(config) elif "decoder.conv_in.weight" in sd: if sd['decoder.conv_in.weight'].shape[1] == 64: ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True} From 955c00ee38356fd35d4c347f617ef15cb10659ec Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 6 Feb 2026 23:54:27 +0200 Subject: [PATCH 12/59] post-process node --- comfy/ldm/trellis2/cumesh.py | 127 ---------------------- comfy/ldm/trellis2/vae.py | 22 ---- comfy_extras/nodes_trellis2.py | 187 ++++++++++++++++++++++++++++++++- 3 files changed, 186 insertions(+), 
150 deletions(-) diff --git a/comfy/ldm/trellis2/cumesh.py b/comfy/ldm/trellis2/cumesh.py index fe7e80e15..972fb13c3 100644 --- a/comfy/ldm/trellis2/cumesh.py +++ b/comfy/ldm/trellis2/cumesh.py @@ -469,133 +469,6 @@ class Mesh: def cpu(self): return self.to('cpu') - # could make this into a new node - def fill_holes(self, max_hole_perimeter=3e-2): - - device = self.vertices.device - vertices = self.vertices - faces = self.faces - - edges = torch.cat([ - faces[:, [0, 1]], - faces[:, [1, 2]], - faces[:, [2, 0]] - ], dim=0) - - edges_sorted, _ = torch.sort(edges, dim=1) - - unique_edges, counts = torch.unique(edges_sorted, dim=0, return_counts=True) - - boundary_mask = counts == 1 - boundary_edges_sorted = unique_edges[boundary_mask] - - if boundary_edges_sorted.shape[0] == 0: - return - max_idx = vertices.shape[0] - - _, inverse_indices, counts_packed = torch.unique( - torch.sort(edges, dim=1).values[:, 0] * max_idx + torch.sort(edges, dim=1).values[:, 1], - return_inverse=True, return_counts=True - ) - - boundary_packed_mask = counts_packed == 1 - is_boundary_edge = boundary_packed_mask[inverse_indices] - - active_boundary_edges = edges[is_boundary_edge] - - adj = {} - edges_np = active_boundary_edges.cpu().numpy() - for u, v in edges_np: - adj[u] = v - - loops = [] - visited_edges = set() - - possible_starts = list(adj.keys()) - - processed_nodes = set() - - for start_node in possible_starts: - if start_node in processed_nodes: - continue - - current_loop = [] - curr = start_node - - while curr in adj: - next_node = adj[curr] - if (curr, next_node) in visited_edges: - break - - visited_edges.add((curr, next_node)) - processed_nodes.add(curr) - current_loop.append(curr) - - curr = next_node - - if curr == start_node: - loops.append(current_loop) - break - - if len(current_loop) > len(edges_np): - break - - if not loops: - return - - new_faces = [] - - v_offset = vertices.shape[0] - - valid_new_verts = [] - - for loop_indices in loops: - if len(loop_indices) < 3: - 
continue - - loop_tensor = torch.tensor(loop_indices, dtype=torch.long, device=device) - loop_verts = vertices[loop_tensor] - - diffs = loop_verts - torch.roll(loop_verts, shifts=-1, dims=0) - perimeter = torch.norm(diffs, dim=1).sum() - - if perimeter > max_hole_perimeter: - continue - - center = loop_verts.mean(dim=0) - valid_new_verts.append(center) - - c_idx = v_offset - v_offset += 1 - - num_v = len(loop_indices) - for i in range(num_v): - v_curr = loop_indices[i] - v_next = loop_indices[(i + 1) % num_v] - new_faces.append([v_curr, v_next, c_idx]) - - if len(valid_new_verts) > 0: - added_vertices = torch.stack(valid_new_verts, dim=0) - added_faces = torch.tensor(new_faces, dtype=torch.long, device=device) - - self.vertices = torch.cat([self.vertices, added_vertices], dim=0) - self.faces = torch.cat([self.faces, added_faces], dim=0) - - - # TODO could be an option - def simplify(self, target=1000000, verbose: bool=False, options: dict={}): - import cumesh - vertices = self.vertices.cuda() - faces = self.faces.cuda() - - mesh = cumesh.CuMesh() - mesh.init(vertices, faces) - mesh.simplify(target, verbose=verbose, options=options) - new_vertices, new_faces = mesh.read() - - self.vertices = new_vertices.to(self.device) - self.faces = new_faces.to(self.device) - class MeshWithVoxel(Mesh, Voxel): def __init__(self, vertices: torch.Tensor, diff --git a/comfy/ldm/trellis2/vae.py b/comfy/ldm/trellis2/vae.py index d997bbc41..1d26986cc 100644 --- a/comfy/ldm/trellis2/vae.py +++ b/comfy/ldm/trellis2/vae.py @@ -231,27 +231,6 @@ class config: CONV = "flexgemm" FLEX_GEMM_HASHMAP_RATIO = 2.0 -# TODO post processing -def simplify(self, target_num_faces: int, verbose: bool=False, options: dict={}): - - num_face = self.cu_mesh.num_faces() - if num_face <= target_num_faces: - return - - thresh = options.get('thresh', 1e-8) - lambda_edge_length = options.get('lambda_edge_length', 1e-2) - lambda_skinny = options.get('lambda_skinny', 1e-3) - while True: - new_num_vert, new_num_face = 
self.cu_mesh.simplify_step(lambda_edge_length, lambda_skinny, thresh, False) - - if new_num_face <= target_num_faces: - break - - del_num_face = num_face - new_num_face - if del_num_face / num_face < 1e-2: - thresh *= 10 - num_face = new_num_face - class VarLenTensor: def __init__(self, feats: torch.Tensor, layout: List[slice]=None): @@ -1530,7 +1509,6 @@ class Vae(nn.Module): tex_voxels = self.decode_tex_slat(tex_slat, subs) out_mesh = [] for m, v in zip(meshes, tex_voxels): - m.fill_holes() # TODO out_mesh.append( MeshWithVoxel( m.vertices, m.faces, diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 4a36e2fee..8497b83e2 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -281,6 +281,190 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode): latent = NestedTensor([latent]) return IO.NodeOutput({"samples": latent, "type": "trellis2"}) +def simplify_fn(vertices, faces, target=100000): + + if vertices.shape[0] <= target: + return + + min_feat = vertices.min(dim=0)[0] + max_feat = vertices.max(dim=0)[0] + extent = (max_feat - min_feat).max() + + grid_resolution = int(torch.sqrt(torch.tensor(target)).item() * 1.5) + voxel_size = extent / grid_resolution + + quantized_coords = ((vertices - min_feat) / voxel_size).long() + + unique_coords, inverse_indices = torch.unique(quantized_coords, dim=0, return_inverse=True) + + num_new_verts = unique_coords.shape[0] + new_vertices = torch.zeros((num_new_verts, 3), dtype=vertices.dtype, device=vertices.device) + + counts = torch.zeros((num_new_verts, 1), dtype=vertices.dtype, device=vertices.device) + + new_vertices.scatter_add_(0, inverse_indices.unsqueeze(1).expand(-1, 3), vertices) + counts.scatter_add_(0, inverse_indices.unsqueeze(1), torch.ones_like(vertices[:, :1])) + + new_vertices = new_vertices / counts.clamp(min=1) + + new_faces = inverse_indices[faces] + + v0 = new_faces[:, 0] + v1 = new_faces[:, 1] + v2 = new_faces[:, 2] + + valid_mask = (v0 != v1) & (v1 != 
v2) & (v2 != v0) + new_faces = new_faces[valid_mask] + + unique_face_indices, inv_face = torch.unique(new_faces.reshape(-1), return_inverse=True) + final_vertices = new_vertices[unique_face_indices] + final_faces = inv_face.reshape(-1, 3) + + return final_vertices, final_faces + +def fill_holes_fn(vertices, faces, max_hole_perimeter=3e-2): + + device = vertices.device + orig_vertices = vertices + orig_faces = faces + + edges = torch.cat([ + faces[:, [0, 1]], + faces[:, [1, 2]], + faces[:, [2, 0]] + ], dim=0) + + edges_sorted, _ = torch.sort(edges, dim=1) + + unique_edges, counts = torch.unique(edges_sorted, dim=0, return_counts=True) + + boundary_mask = counts == 1 + boundary_edges_sorted = unique_edges[boundary_mask] + + if boundary_edges_sorted.shape[0] == 0: + return + max_idx = vertices.shape[0] + + _, inverse_indices, counts_packed = torch.unique( + torch.sort(edges, dim=1).values[:, 0] * max_idx + torch.sort(edges, dim=1).values[:, 1], + return_inverse=True, return_counts=True + ) + + boundary_packed_mask = counts_packed == 1 + is_boundary_edge = boundary_packed_mask[inverse_indices] + + active_boundary_edges = edges[is_boundary_edge] + + adj = {} + edges_np = active_boundary_edges.cpu().numpy() + for u, v in edges_np: + adj[u] = v + + loops = [] + visited_edges = set() + + possible_starts = list(adj.keys()) + + processed_nodes = set() + + for start_node in possible_starts: + if start_node in processed_nodes: + continue + + current_loop = [] + curr = start_node + + while curr in adj: + next_node = adj[curr] + if (curr, next_node) in visited_edges: + break + + visited_edges.add((curr, next_node)) + processed_nodes.add(curr) + current_loop.append(curr) + + curr = next_node + + if curr == start_node: + loops.append(current_loop) + break + + if len(current_loop) > len(edges_np): + break + + if not loops: + return + + new_faces = [] + + v_offset = vertices.shape[0] + + valid_new_verts = [] + + for loop_indices in loops: + if len(loop_indices) < 3: + continue + + 
loop_tensor = torch.tensor(loop_indices, dtype=torch.long, device=device) + loop_verts = vertices[loop_tensor] + + diffs = loop_verts - torch.roll(loop_verts, shifts=-1, dims=0) + perimeter = torch.norm(diffs, dim=1).sum() + + if perimeter > max_hole_perimeter: + continue + + center = loop_verts.mean(dim=0) + valid_new_verts.append(center) + + c_idx = v_offset + v_offset += 1 + + num_v = len(loop_indices) + for i in range(num_v): + v_curr = loop_indices[i] + v_next = loop_indices[(i + 1) % num_v] + new_faces.append([v_curr, v_next, c_idx]) + + if len(valid_new_verts) > 0: + added_vertices = torch.stack(valid_new_verts, dim=0) + added_faces = torch.tensor(new_faces, dtype=torch.long, device=device) + + vertices_f = torch.cat([orig_vertices, added_vertices], dim=0) + faces_f = torch.cat([orig_faces, added_faces], dim=0) + + return vertices_f, faces_f + +class PostProcessMesh(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="PostProcessMesh", + category="latent/3d", + inputs=[ + IO.Mesh.Input("mesh"), + IO.Int.Input("simplify", default=100_000, min=0), # max? 
+ IO.Float.Input("fill_holes_perimeter", default=0.03, min=0.0) + ], + outputs=[ + IO.Mesh.Output("output_mesh"), + ] + ) + @classmethod + def execute(cls, mesh, simplify, fill_holes_perimeter): + verts, faces = mesh.vertices, mesh.faces + + if fill_holes_perimeter != 0.0: + verts, faces = fill_holes_fn(verts, faces, max_hole_perimeter=fill_holes_perimeter) + + if simplify != 0: + verts, faces = simplify_fn(verts, faces, simplify) + + + mesh.vertices = verts + mesh.faces = faces + + return mesh class Trellis2Extension(ComfyExtension): @override @@ -292,7 +476,8 @@ class Trellis2Extension(ComfyExtension): EmptyTextureLatentTrellis2, VaeDecodeTextureTrellis, VaeDecodeShapeTrellis, - VaeDecodeStructureTrellis2 + VaeDecodeStructureTrellis2, + PostProcessMesh ] From 704e1b54621a78b30d529b5367a2951383bd890c Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Mon, 9 Feb 2026 00:41:01 +0200 Subject: [PATCH 13/59] small bug fixes --- comfy/ldm/trellis2/attention.py | 2 +- comfy/ldm/trellis2/model.py | 5 ++++- comfy/ldm/trellis2/vae.py | 5 ++--- comfy/model_detection.py | 12 +++++++----- comfy/sd.py | 7 +++---- comfy_extras/nodes_trellis2.py | 8 ++++---- 6 files changed, 21 insertions(+), 18 deletions(-) diff --git a/comfy/ldm/trellis2/attention.py b/comfy/ldm/trellis2/attention.py index edc85ce83..3038f4023 100644 --- a/comfy/ldm/trellis2/attention.py +++ b/comfy/ldm/trellis2/attention.py @@ -2,7 +2,7 @@ import torch import math from comfy.ldm.modules.attention import optimized_attention from typing import Tuple, Union, List -from vae import VarLenTensor +from comfy.ldm.trellis2.vae import VarLenTensor FLASH_ATTN_3_AVA = True try: diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 8ca112b13..9aab045c7 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -798,9 +798,12 @@ class Trellis2(nn.Module): qk_rms_norm = True, qk_rms_norm_cross = True, init_txt_model=False, # for now - 
dtype=None, device=None, operations=None): + dtype=None, device=None, operations=None, **kwargs): super().__init__() + self.dtype = dtype + # for some reason it passes num_heads = -1 + num_heads = 12 args = { "out_channels":out_channels, "num_blocks":num_blocks, "cond_channels" :cond_channels, "model_channels":model_channels, "num_heads":num_heads, "mlp_ratio": mlp_ratio, "share_mod": share_mod, diff --git a/comfy/ldm/trellis2/vae.py b/comfy/ldm/trellis2/vae.py index 1d26986cc..6e13afd8d 100644 --- a/comfy/ldm/trellis2/vae.py +++ b/comfy/ldm/trellis2/vae.py @@ -6,7 +6,7 @@ from fractions import Fraction import torch.nn.functional as F from dataclasses import dataclass import numpy as np -from cumesh import TorchHashMap, Mesh, MeshWithVoxel, sparse_submanifold_conv3d +from comfy.ldm.trellis2.cumesh import TorchHashMap, Mesh, MeshWithVoxel, sparse_submanifold_conv3d def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor: @@ -1457,10 +1457,9 @@ class SparseStructureDecoder(nn.Module): return h class Vae(nn.Module): - def __init__(self, config, operations=None): + def __init__(self, init_txt_model, operations=None): super().__init__() operations = operations or torch.nn - init_txt_model = config.get("init_txt_model", False) if init_txt_model: self.txt_dec = SparseUnetVaeDecoder( out_channels=6, diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 4f5542af5..004adbf71 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -109,15 +109,17 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if '{}img2shape.blocks.1.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys: unet_config = {} unet_config["image_model"] = "trellis2" + + unet_config["init_txt_model"] = False if '{}model.shape2txt.blocks.29.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys: unet_config["init_txt_model"] = True - else: - unet_config["init_txt_model"] = False + + unet_config["resolution"] = 64 if metadata 
is not None: - if metadata["is_512"] is True: + if "is_512" in metadata and metadata["metadata"]: unet_config["resolution"] = 32 - else: - unet_config["resolution"] = 64 + + unet_config["num_heads"] = 12 return unet_config if '{}transformer.rotary_pos_emb.inv_freq'.format(key_prefix) in state_dict_keys: #stable audio dit diff --git a/comfy/sd.py b/comfy/sd.py index be3d1b4f0..25fd3ba7b 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -494,11 +494,10 @@ class VAE: self.downscale_ratio = 32 self.latent_channels = 16 elif "shape_dec.blocks.1.16.to_subdiv.weight" in sd: # trellis2 + init_txt_model = False if "txt_dec.blocks.1.16.norm1.weight" in sd: - config["init_txt_model"] = True - else: - config["init_txt_model"] = False - self.first_stage_model = comfy.ldm.trellis2.vae.Vae(config) + init_txt_model = True + self.first_stage_model = comfy.ldm.trellis2.vae.Vae(init_txt_model) elif "decoder.conv_in.weight" in sd: if sd['decoder.conv_in.weight'].shape[1] == 64: ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True} diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 8497b83e2..17ba94ec8 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -198,7 +198,7 @@ class Trellis2Conditioning(IO.ComfyNode): inputs=[ IO.ClipVision.Input("clip_vision_model"), IO.Image.Input("image"), - IO.MultiCombo.Input("background_color", options=["black", "gray", "white"], default="black") + IO.Combo.Input("background_color", options=["black", "gray", "white"], default="black") ], outputs=[ IO.Conditioning.Output(display_name="positive"), @@ -219,7 +219,7 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): @classmethod def define_schema(cls): return IO.Schema( - node_id="EmptyLatentTrellis2", + node_id="EmptyShapeLatentTrellis2", category="latent/3d", inputs=[ 
IO.Mesh.Input("structure_output"), @@ -242,7 +242,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): @classmethod def define_schema(cls): return IO.Schema( - node_id="EmptyLatentTrellis2", + node_id="EmptyTextureLatentTrellis2", category="latent/3d", inputs=[ IO.Mesh.Input("structure_output"), @@ -264,7 +264,7 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode): @classmethod def define_schema(cls): return IO.Schema( - node_id="EmptyLatentTrellis2", + node_id="EmptyStructureLatentTrellis2", category="latent/3d", inputs=[ IO.Int.Input("resolution", default=3072, min=1, max=8192), From 2eef826def23599d395bb1854d6b95cb657c8385 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Mon, 9 Feb 2026 22:47:50 +0200 Subject: [PATCH 14/59] multiple fixes --- comfy/clip_vision.py | 4 +++ comfy/image_encoders/dino3.py | 32 +++++++++++++++++------ comfy/image_encoders/dino3_large.json | 11 ++++---- comfy/supported_models.py | 6 +++++ comfy_extras/nodes_trellis2.py | 37 +++++++++++++++------------ 5 files changed, 60 insertions(+), 30 deletions(-) diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 1691fca81..71f2200b7 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -9,6 +9,7 @@ import comfy.model_management import comfy.utils import comfy.clip_model import comfy.image_encoders.dino2 +import comfy.image_encoders.dino3 class Output: def __getitem__(self, key): @@ -23,6 +24,7 @@ IMAGE_ENCODERS = { "siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection, "siglip2_vision_model": comfy.clip_model.CLIPVisionModelProjection, "dinov2": comfy.image_encoders.dino2.Dinov2Model, + "dinov3": comfy.image_encoders.dino3.DINOv3ViTModel } class ClipVisionModel(): @@ -134,6 +136,8 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False): json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json") elif 'encoder.layer.23.layer_scale2.lambda1' in 
sd: json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_large.json") + elif 'layer.9.attention.o_proj.bias' in sd: # dinov3 + json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino3_large.json") else: return None diff --git a/comfy/image_encoders/dino3.py b/comfy/image_encoders/dino3.py index d07c2c5b8..b27b95b5f 100644 --- a/comfy/image_encoders/dino3.py +++ b/comfy/image_encoders/dino3.py @@ -4,7 +4,19 @@ import torch.nn as nn from comfy.ldm.modules.attention import optimized_attention_for_device from comfy.ldm.flux.math import apply_rope -from dino2 import Dinov2MLP as DINOv3ViTMLP, LayerScale as DINOv3ViTLayerScale +from comfy.image_encoders.dino2 import LayerScale as DINOv3ViTLayerScale + +class DINOv3ViTMLP(nn.Module): + def __init__(self, hidden_size, intermediate_size, mlp_bias, device, dtype, operations): + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype) + self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=mlp_bias, device=device, dtype=dtype) + self.act_fn = torch.nn.GELU() + + def forward(self, x): + return self.down_proj(self.act_fn(self.up_proj(x))) class DINOv3ViTAttention(nn.Module): def __init__(self, hidden_size, num_attention_heads, device, dtype, operations): @@ -90,6 +102,7 @@ class DINOv3ViTRopePositionEmbedding(nn.Module): self.head_dim = hidden_size // num_attention_heads self.num_patches_h = image_size // patch_size self.num_patches_w = image_size // patch_size + self.patch_size = patch_size inv_freq = 1 / self.base ** torch.arange(0, 1, 4 / self.head_dim, dtype=torch.float32, device=device) self.register_buffer("inv_freq", inv_freq, persistent=False) @@ -106,6 +119,7 @@ class DINOv3ViTRopePositionEmbedding(nn.Module): num_patches_h, 
num_patches_w, dtype=torch.float32, device=device ) + self.inv_freq = self.inv_freq.to(device) angles = 2 * math.pi * patch_coords[:, :, None] * self.inv_freq[None, None, :] angles = angles.flatten(1, 2) angles = angles.tile(2) @@ -140,27 +154,30 @@ class DINOv3ViTEmbeddings(nn.Module): cls_token = self.cls_token.expand(batch_size, -1, -1) register_tokens = self.register_tokens.expand(batch_size, -1, -1) + device = patch_embeddings + cls_token = cls_token.to(device) + register_tokens = register_tokens.to(device) embeddings = torch.cat([cls_token, register_tokens, patch_embeddings], dim=1) return embeddings class DINOv3ViTLayer(nn.Module): - def __init__(self, hidden_size, layer_norm_eps, use_gated_mlp, layerscale_value, mlp_bias, intermediate_size, num_attention_heads, + def __init__(self, hidden_size, layer_norm_eps, use_gated_mlp, mlp_bias, intermediate_size, num_attention_heads, device, dtype, operations): super().__init__() self.norm1 = operations.LayerNorm(hidden_size, eps=layer_norm_eps) self.attention = DINOv3ViTAttention(hidden_size, num_attention_heads, device=device, dtype=dtype, operations=operations) - self.layer_scale1 = DINOv3ViTLayerScale(hidden_size, layerscale_value, device=device, dtype=dtype) + self.layer_scale1 = DINOv3ViTLayerScale(hidden_size, device=device, dtype=dtype, operations=None) self.norm2 = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype) if use_gated_mlp: self.mlp = DINOv3ViTGatedMLP(hidden_size, intermediate_size, mlp_bias, device=device, dtype=dtype, operations=operations) else: - self.mlp = DINOv3ViTMLP(hidden_size, device=device, dtype=dtype, operations=operations) - self.layer_scale2 = DINOv3ViTLayerScale(hidden_size, layerscale_value, device=device, dtype=dtype) + self.mlp = DINOv3ViTMLP(hidden_size, intermediate_size=intermediate_size, mlp_bias=mlp_bias, device=device, dtype=dtype, operations=operations) + self.layer_scale2 = DINOv3ViTLayerScale(hidden_size, device=device, dtype=dtype, 
operations=None) def forward( self, @@ -188,7 +205,7 @@ class DINOv3ViTLayer(nn.Module): class DINOv3ViTModel(nn.Module): - def __init__(self, config, device, dtype, operations): + def __init__(self, config, dtype, device, operations): super().__init__() num_hidden_layers = config["num_hidden_layers"] hidden_size = config["hidden_size"] @@ -196,7 +213,6 @@ class DINOv3ViTModel(nn.Module): num_register_tokens = config["num_register_tokens"] intermediate_size = config["intermediate_size"] layer_norm_eps = config["layer_norm_eps"] - layerscale_value = config["layerscale_value"] num_channels = config["num_channels"] patch_size = config["patch_size"] rope_theta = config["rope_theta"] @@ -208,7 +224,7 @@ class DINOv3ViTModel(nn.Module): rope_theta, hidden_size, num_attention_heads, image_size=512, patch_size=patch_size, dtype=dtype, device=device ) self.layer = nn.ModuleList( - [DINOv3ViTLayer(hidden_size, layer_norm_eps, use_gated_mlp=False, layerscale_value=layerscale_value, mlp_bias=True, + [DINOv3ViTLayer(hidden_size, layer_norm_eps, use_gated_mlp=False, mlp_bias=True, intermediate_size=intermediate_size,num_attention_heads = num_attention_heads, dtype=dtype, device=device, operations=operations) for _ in range(num_hidden_layers)]) diff --git a/comfy/image_encoders/dino3_large.json b/comfy/image_encoders/dino3_large.json index 96263f0d6..53f761a25 100644 --- a/comfy/image_encoders/dino3_large.json +++ b/comfy/image_encoders/dino3_large.json @@ -1,16 +1,15 @@ { - - "hidden_size": 384, + "model_type": "dinov3", + "hidden_size": 1024, "image_size": 224, "initializer_range": 0.02, - "intermediate_size": 1536, + "intermediate_size": 4096, "key_bias": false, "layer_norm_eps": 1e-05, - "layerscale_value": 1.0, "mlp_bias": true, - "num_attention_heads": 6, + "num_attention_heads": 16, "num_channels": 3, - "num_hidden_layers": 12, + "num_hidden_layers": 24, "num_register_tokens": 4, "patch_size": 16, "pos_embed_rescale": 2.0, diff --git a/comfy/supported_models.py 
b/comfy/supported_models.py index 9e2f17149..3373f78a2 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1251,12 +1251,18 @@ class Trellis2(supported_models_base.BASE): "shift": 3.0, } + memory_usage_factor = 3.5 + latent_format = latent_formats.Trellis2 vae_key_prefix = ["vae."] + clip_vision_prefix = "conditioner.main_image_encoder.model." def get_model(self, state_dict, prefix="", device=None): return model_base.Trellis2(self, device=device) + def clip_target(self, state_dict={}): + return None + class Hunyuan3Dv2(supported_models_base.BASE): unet_config = { "image_model": "hunyuan3d2", diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 17ba94ec8..f53d36736 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -3,10 +3,8 @@ from comfy_api.latest import ComfyExtension, IO import torch from comfy.ldm.trellis2.model import SparseTensor import comfy.model_management -from PIL import Image -import PIL -import numpy as np from comfy.nested_tensor import NestedTensor +from torchvision.transforms import ToPILImage, ToTensor, Resize, InterpolationMode shape_slat_normalization = { "mean": torch.tensor([ @@ -76,23 +74,30 @@ def run_conditioning( # Convert image to PIL if image.dim() == 4: - pil_image = (image[0] * 255).clip(0, 255).astype(torch.uint8) + pil_image = (image[0] * 255).clip(0, 255).to(torch.uint8) else: - pil_image = (image * 255).clip(0, 255).astype(torch.uint8) + pil_image = (image * 255).clip(0, 255).to(torch.uint8) + pil_image = pil_image.movedim(-1, 0) pil_image = smart_crop_square(pil_image, background_color=bg_color) model.image_size = 512 def set_image_size(image, image_size=512): - image = PIL.from_array(image) - image = [i.resize((image_size, image_size), Image.LANCZOS) for i in image] - image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image] - image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image] - image = 
torch.stack(image).to(torch_device) - return image + if image.ndim == 3: + image = image.unsqueeze(0) - pil_image = set_image_size(image, 512) - cond_512 = model([pil_image]) + to_pil = ToPILImage() + to_tensor = ToTensor() + resizer = Resize((image_size, image_size), interpolation=InterpolationMode.LANCZOS) + + pil_img = to_pil(image.squeeze(0)) + resized_pil = resizer(pil_img) + image = to_tensor(resized_pil).unsqueeze(0) + + return image.to(torch_device).float() + + pil_image = set_image_size(pil_image, 512) + cond_512 = model(pil_image) cond_1024 = None if include_1024: @@ -267,7 +272,7 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode): node_id="EmptyStructureLatentTrellis2", category="latent/3d", inputs=[ - IO.Int.Input("resolution", default=3072, min=1, max=8192), + IO.Int.Input("resolution", default=256, min=1, max=8192), IO.Int.Input("batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."), ], outputs=[ @@ -275,9 +280,9 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode): ] ) @classmethod - def execute(cls, res, batch_size): + def execute(cls, resolution, batch_size): in_channels = 32 - latent = torch.randn(batch_size, in_channels, res, res, res) + latent = torch.randn(batch_size, in_channels, resolution, resolution, resolution) latent = NestedTensor([latent]) return IO.NodeOutput({"samples": latent, "type": "trellis2"}) From f4059c189e92f7a1a6c65014c4901f9ad67b6b25 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Wed, 11 Feb 2026 01:27:54 +0200 Subject: [PATCH 15/59] dinov3 fixes + other --- comfy/image_encoders/dino3.py | 39 +++++++++++++++++++++++++++------- comfy/ldm/trellis2/model.py | 17 +++++++++++---- comfy_extras/nodes_trellis2.py | 14 ++++++------ 3 files changed, 51 insertions(+), 19 deletions(-) diff --git a/comfy/image_encoders/dino3.py b/comfy/image_encoders/dino3.py index b27b95b5f..ef04556da 100644 --- a/comfy/image_encoders/dino3.py +++ 
b/comfy/image_encoders/dino3.py @@ -24,7 +24,6 @@ class DINOv3ViTAttention(nn.Module): self.embed_dim = hidden_size self.num_heads = num_attention_heads self.head_dim = self.embed_dim // self.num_heads - self.is_causal = False self.scaling = self.head_dim**-0.5 self.is_causal = False @@ -53,18 +52,41 @@ class DINOv3ViTAttention(nn.Module): key_states = key_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2) - cos, sin = position_embeddings - position_embeddings = torch.stack([cos, sin], dim = -1) - query_states, key_states = apply_rope(query_states, key_states, position_embeddings) + if position_embeddings is not None: + cos, sin = position_embeddings - attn_output, attn_weights = optimized_attention_for_device( - query_states, key_states, value_states, attention_mask, skip_reshape=True, skip_output_reshape=True + num_tokens = query_states.shape[-2] + num_patches = cos.shape[-2] + num_prefix_tokens = num_tokens - num_patches + + q_prefix, q_patches = query_states.split((num_prefix_tokens, num_patches), dim=-2) + k_prefix, k_patches = key_states.split((num_prefix_tokens, num_patches), dim=-2) + + cos = cos[..., :self.head_dim // 2] + sin = sin[..., :self.head_dim // 2] + + f_cis_0 = torch.stack([cos, sin], dim=-1) + f_cis_1 = torch.stack([-sin, cos], dim=-1) + freqs_cis = torch.stack([f_cis_0, f_cis_1], dim=-1) + + while freqs_cis.ndim < q_patches.ndim + 1: + freqs_cis = freqs_cis.unsqueeze(0) + + q_patches, k_patches = apply_rope(q_patches, k_patches, freqs_cis) + + query_states = torch.cat((q_prefix, q_patches), dim=-2) + key_states = torch.cat((k_prefix, k_patches), dim=-2) + + attn = optimized_attention_for_device(query_states.device, mask=False) + + attn_output = attn( + query_states, key_states, value_states, self.num_heads, attention_mask, skip_reshape=True, skip_output_reshape=True ) attn_output = attn_output.reshape(batch_size, 
patches, -1).contiguous() attn_output = self.o_proj(attn_output) - return attn_output, attn_weights + return attn_output class DINOv3ViTGatedMLP(nn.Module): def __init__(self, hidden_size, intermediate_size, mlp_bias, device, dtype, operations): @@ -187,7 +209,7 @@ class DINOv3ViTLayer(nn.Module): ) -> torch.Tensor: residual = hidden_states hidden_states = self.norm1(hidden_states) - hidden_states, _ = self.attention( + hidden_states = self.attention( hidden_states, attention_mask=attention_mask, position_embeddings=position_embeddings, @@ -250,6 +272,7 @@ class DINOv3ViTModel(nn.Module): position_embeddings=position_embeddings, ) + self.norm = self.norm.to(hidden_states.device) sequence_output = self.norm(hidden_states) pooled_output = sequence_output[:, 0, :] diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 9aab045c7..8bc8e8f7a 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -113,6 +113,13 @@ class SparseRotaryPositionEmbedder(nn.Module): q_feats, k_feats = apply_rope(q.feats, k.feats, f_cis) return q.replace(q_feats), k.replace(k_feats) + @staticmethod + def apply_rotary_embedding(x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor: + x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + x_rotated = x_complex * phases.unsqueeze(-2) + x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype) + return x_embed + class RotaryPositionEmbedder(SparseRotaryPositionEmbedder): def forward(self, indices: torch.Tensor) -> torch.Tensor: phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1) @@ -559,6 +566,7 @@ class MultiHeadAttention(nn.Module): def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, phases: Optional[torch.Tensor] = None) -> torch.Tensor: B, L, C = x.shape if self._type == "self": + x = x.to(next(self.to_qkv.parameters()).dtype) qkv = self.to_qkv(x) qkv = qkv.reshape(B, L, 3, self.num_heads, -1) @@ 
-688,7 +696,7 @@ class SparseStructureFlowModel(nn.Module): num_heads: Optional[int] = None, num_head_channels: Optional[int] = 64, mlp_ratio: float = 4, - pe_mode: Literal["ape", "rope"] = "ape", + pe_mode: Literal["ape", "rope"] = "rope", rope_freq: Tuple[float, float] = (1.0, 10000.0), dtype: str = 'float32', use_checkpoint: bool = False, @@ -756,14 +764,14 @@ class SparseStructureFlowModel(nn.Module): self.out_layer = nn.Linear(model_channels, out_channels) def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor: + x = x.view(x.shape[0], self.in_channels, *[self.resolution] * 3) assert [*x.shape] == [x.shape[0], self.in_channels, *[self.resolution] * 3], \ f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}" h = x.view(*x.shape[:2], -1).permute(0, 2, 1).contiguous() + h = h.to(next(self.input_layer.parameters()).dtype) h = self.input_layer(h) - if self.pe_mode == "ape": - h = h + self.pos_emb[None] t_emb = self.t_embedder(t, out_dtype = t.dtype) if self.share_mod: t_emb = self.adaLN_modulation(t_emb) @@ -816,7 +824,8 @@ class Trellis2(nn.Module): self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **args) def forward(self, x: NestedTensor, timestep, context, **kwargs): - x = x.tensors[0] + if isinstance(x, NestedTensor): + x = x.tensors[0] embeds = kwargs.get("embeds") if not hasattr(x, "feats"): mode = "structure_generation" diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index f53d36736..4eff2dbc3 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -97,13 +97,13 @@ def run_conditioning( return image.to(torch_device).float() pil_image = set_image_size(pil_image, 512) - cond_512 = model(pil_image) + cond_512 = model(pil_image)[0] cond_1024 = None if include_1024: model.image_size = 1024 pil_image = set_image_size(pil_image, 1024) - cond_1024 = model([pil_image]) + cond_1024 = 
model(pil_image)[0] neg_cond = torch.zeros_like(cond_512) @@ -115,7 +115,7 @@ def run_conditioning( conditioning['cond_1024'] = cond_1024.to(device) preprocessed_tensor = pil_image.to(torch.float32) / 255.0 - preprocessed_tensor = torch.from_numpy(preprocessed_tensor).unsqueeze(0) + preprocessed_tensor = preprocessed_tensor.unsqueeze(0) return conditioning, preprocessed_tensor @@ -217,7 +217,7 @@ class Trellis2Conditioning(IO.ComfyNode): conditioning, _ = run_conditioning(clip_vision_model, image, include_1024=True, background_color=background_color) embeds = conditioning["cond_1024"] # should add that positive = [[conditioning["cond_512"], {"embeds": embeds}]] - negative = [[conditioning["cond_neg"], {"embeds": embeds}]] + negative = [[conditioning["neg_cond"], {"embeds": embeds}]] return IO.NodeOutput(positive, negative) class EmptyShapeLatentTrellis2(IO.ComfyNode): @@ -272,7 +272,6 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode): node_id="EmptyStructureLatentTrellis2", category="latent/3d", inputs=[ - IO.Int.Input("resolution", default=256, min=1, max=8192), IO.Int.Input("batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."), ], outputs=[ @@ -280,8 +279,9 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode): ] ) @classmethod - def execute(cls, resolution, batch_size): - in_channels = 32 + def execute(cls, batch_size): + in_channels = 8 + resolution = 16 latent = torch.randn(batch_size, in_channels, resolution, resolution, resolution) latent = NestedTensor([latent]) return IO.NodeOutput({"samples": latent, "type": "trellis2"}) From b7764479c263c9d41bd077c7453b8d4c15551a34 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Wed, 11 Feb 2026 20:33:59 +0200 Subject: [PATCH 16/59] debugging --- comfy/ldm/trellis2/attention.py | 7 ++++-- comfy/ldm/trellis2/model.py | 8 +++---- comfy/ldm/trellis2/vae.py | 41 +++++++++++++++++++++++++++------ comfy/sd.py | 4 ++++ 
comfy_extras/nodes_trellis2.py | 30 +++++++++++++++--------- 5 files changed, 65 insertions(+), 25 deletions(-) diff --git a/comfy/ldm/trellis2/attention.py b/comfy/ldm/trellis2/attention.py index 3038f4023..e6aa50842 100644 --- a/comfy/ldm/trellis2/attention.py +++ b/comfy/ldm/trellis2/attention.py @@ -14,6 +14,7 @@ except: def scaled_dot_product_attention(*args, **kwargs): num_all_args = len(args) + len(kwargs) + q = None if num_all_args == 1: qkv = args[0] if len(args) > 0 else kwargs['qkv'] @@ -26,8 +27,10 @@ def scaled_dot_product_attention(*args, **kwargs): k = args[1] if len(args) > 1 else kwargs['k'] v = args[2] if len(args) > 2 else kwargs['v'] - # TODO verify - heads = q or qkv + if q is not None: + heads = q + else: + heads = qkv heads = heads.shape[2] if optimized_attention.__name__ == 'attention_xformers': diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 8bc8e8f7a..17286a553 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -7,7 +7,6 @@ from comfy.ldm.trellis2.attention import ( sparse_windowed_scaled_dot_product_self_attention, sparse_scaled_dot_product_attention, scaled_dot_product_attention ) from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder -from comfy.nested_tensor import NestedTensor from comfy.ldm.flux.math import apply_rope, apply_rope1 class SparseGELU(nn.GELU): @@ -586,6 +585,7 @@ class MultiHeadAttention(nn.Module): else: Lkv = context.shape[1] q = self.to_q(x) + context = context.to(next(self.to_kv.parameters()).dtype) kv = self.to_kv(context) q = q.reshape(B, L, self.num_heads, -1) kv = kv.reshape(B, Lkv, 2, self.num_heads, -1) @@ -782,6 +782,7 @@ class SparseStructureFlowModel(nn.Module): h = block(h, t_emb, cond, self.rope_phases) h = manual_cast(h, x.dtype) h = F.layer_norm(h, h.shape[-1:]) + h = h.to(next(self.out_layer.parameters()).dtype) h = self.out_layer(h) h = h.permute(0, 2, 1).view(h.shape[0], h.shape[2], *[self.resolution] * 3).contiguous() @@ -823,9 +824,7 
@@ class Trellis2(nn.Module): args.pop("out_channels") self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **args) - def forward(self, x: NestedTensor, timestep, context, **kwargs): - if isinstance(x, NestedTensor): - x = x.tensors[0] + def forward(self, x, timestep, context, **kwargs): embeds = kwargs.get("embeds") if not hasattr(x, "feats"): mode = "structure_generation" @@ -843,6 +842,5 @@ class Trellis2(nn.Module): timestep = timestep_reshift(timestep) out = self.structure_model(x, timestep, context) - out = NestedTensor([out]) out.generation_mode = mode return out diff --git a/comfy/ldm/trellis2/vae.py b/comfy/ldm/trellis2/vae.py index 6e13afd8d..57bf78346 100644 --- a/comfy/ldm/trellis2/vae.py +++ b/comfy/ldm/trellis2/vae.py @@ -10,9 +10,6 @@ from comfy.ldm.trellis2.cumesh import TorchHashMap, Mesh, MeshWithVoxel, sparse_ def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor: - """ - 3D pixel shuffle. - """ B, C, H, W, D = x.shape C_ = C // scale_factor**3 x = x.reshape(B, C_, scale_factor, scale_factor, scale_factor, H, W, D) @@ -967,6 +964,25 @@ class SparseLinear(nn.Linear): return input.replace(super().forward(input.feats)) +MIX_PRECISION_MODULES = ( + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nn.ConvTranspose1d, + nn.ConvTranspose2d, + nn.ConvTranspose3d, + nn.Linear, + SparseConv3d, + SparseLinear, +) + + +def convert_module_to_f16(l): + if isinstance(l, MIX_PRECISION_MODULES): + for p in l.parameters(): + p.data = p.data.half() + + class SparseUnetVaeEncoder(nn.Module): """ @@ -1381,8 +1397,12 @@ class ResBlock3d(nn.Module): self.skip_connection = nn.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity() def forward(self, x: torch.Tensor) -> torch.Tensor: + self.norm1 = self.norm1.to(torch.float32) + self.norm2 = self.norm2.to(torch.float32) h = self.norm1(x) h = F.silu(h) + dtype = next(self.conv1.parameters()).dtype + h = h.to(dtype) h = self.conv1(h) h = 
self.norm2(h) h = F.silu(h) @@ -1400,7 +1420,7 @@ class SparseStructureDecoder(nn.Module): channels: List[int], num_res_blocks_middle: int = 2, norm_type = "layer", - use_fp16: bool = False, + use_fp16: bool = True, ): super().__init__() self.out_channels = out_channels @@ -1439,20 +1459,27 @@ class SparseStructureDecoder(nn.Module): if use_fp16: self.convert_to_fp16() - @property def device(self) -> torch.device: return next(self.parameters()).device + def convert_to_fp16(self) -> None: + self.use_fp16 = True + self.dtype = torch.float16 + self.blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + def forward(self, x: torch.Tensor) -> torch.Tensor: + dtype = next(self.input_layer.parameters()).dtype + x = x.to(dtype) h = self.input_layer(x) h = h.type(self.dtype) - h = self.middle_block(h) for block in self.blocks: h = block(h) - h = h.type(x.dtype) + h = h.to(torch.float32) + self.out_layer = self.out_layer.to(torch.float32) h = self.out_layer(h) return h diff --git a/comfy/sd.py b/comfy/sd.py index 25fd3ba7b..276e87d2a 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -497,6 +497,10 @@ class VAE: init_txt_model = False if "txt_dec.blocks.1.16.norm1.weight" in sd: init_txt_model = True + self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] + # TODO + self.memory_used_decode = lambda shape, dtype: (6500 * shape[2] * shape[3]) * model_management.dtype_size(dtype) + self.memory_used_encode = lambda shape, dtype: (6500 * shape[2] * shape[3]) * model_management.dtype_size(dtype) self.first_stage_model = comfy.ldm.trellis2.vae.Vae(init_txt_model) elif "decoder.conv_in.weight" in sd: if sd['decoder.conv_in.weight'].shape[1] == 64: diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 4eff2dbc3..c735469be 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -3,7 +3,7 @@ from comfy_api.latest import ComfyExtension, IO import torch from comfy.ldm.trellis2.model import 
SparseTensor import comfy.model_management -from comfy.nested_tensor import NestedTensor +import comfy.model_patcher from torchvision.transforms import ToPILImage, ToTensor, Resize, InterpolationMode shape_slat_normalization = { @@ -137,14 +137,15 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): ) @classmethod - def execute(cls, samples: NestedTensor, vae, resolution): - samples = samples.tensors[0] + def execute(cls, samples, vae, resolution): + vae = vae.first_stage_model + samples = samples["samples"] std = shape_slat_normalization["std"] mean = shape_slat_normalization["mean"] samples = samples * std + mean mesh, subs = vae.decode_shape_slat(resolution, samples) - return mesh, subs + return IO.NodeOutput(mesh, subs) class VaeDecodeTextureTrellis(IO.ComfyNode): @classmethod @@ -164,13 +165,14 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): @classmethod def execute(cls, samples, vae, shape_subs): - samples = samples.tensors[0] + vae = vae.first_stage_model + samples = samples["samples"] std = tex_slat_normalization["std"] mean = tex_slat_normalization["mean"] samples = samples * std + mean mesh = vae.decode_tex_slat(samples, shape_subs) - return mesh + return IO.NodeOutput(mesh) class VaeDecodeStructureTrellis2(IO.ComfyNode): @classmethod @@ -189,10 +191,19 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode): @classmethod def execute(cls, samples, vae): + vae = vae.first_stage_model decoder = vae.struct_dec + load_device = comfy.model_management.get_torch_device() + decoder = comfy.model_patcher.ModelPatcher( + decoder, load_device=load_device, offload_device=comfy.model_management.vae_offload_device() + ) + comfy.model_management.load_model_gpu(decoder) + decoder = decoder.model + samples = samples["samples"] + samples = samples.to(load_device) decoded = decoder(samples)>0 coords = torch.argwhere(decoded)[:, [0, 2, 3, 4]].int() - return coords + return IO.NodeOutput(coords) class Trellis2Conditioning(IO.ComfyNode): @classmethod @@ -240,7 +251,6 @@ class 
EmptyShapeLatentTrellis2(IO.ComfyNode): coords = structure_output # or structure_output.coords in_channels = 32 latent = SparseTensor(feats=torch.randn(coords.shape[0], in_channels), coords=coords) - latent = NestedTensor([latent]) return IO.NodeOutput({"samples": latent, "type": "trellis2"}) class EmptyTextureLatentTrellis2(IO.ComfyNode): @@ -262,7 +272,6 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): # TODO in_channels = 32 latent = structure_output.replace(feats=torch.randn(structure_output.coords.shape[0], in_channels - structure_output.feats.shape[1])) - latent = NestedTensor([latent]) return IO.NodeOutput({"samples": latent, "type": "trellis2"}) class EmptyStructureLatentTrellis2(IO.ComfyNode): @@ -283,7 +292,6 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode): in_channels = 8 resolution = 16 latent = torch.randn(batch_size, in_channels, resolution, resolution, resolution) - latent = NestedTensor([latent]) return IO.NodeOutput({"samples": latent, "type": "trellis2"}) def simplify_fn(vertices, faces, target=100000): @@ -469,7 +477,7 @@ class PostProcessMesh(IO.ComfyNode): mesh.vertices = verts mesh.faces = faces - return mesh + return IO.NodeOutput(mesh) class Trellis2Extension(ComfyExtension): @override From 8e90bdc1ccad930527cbf3cd5170590ff3eb7902 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Thu, 12 Feb 2026 00:30:51 +0200 Subject: [PATCH 17/59] small error fixes --- comfy_extras/nodes_trellis2.py | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index c735469be..9fd257785 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -1,5 +1,5 @@ from typing_extensions import override -from comfy_api.latest import ComfyExtension, IO +from comfy_api.latest import ComfyExtension, IO, Types import torch from comfy.ldm.trellis2.model import SparseTensor import 
comfy.model_management @@ -185,7 +185,7 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode): IO.Vae.Input("vae"), ], outputs=[ - IO.Mesh.Output("structure_output"), + IO.Voxel.Output("structure_output"), ] ) @@ -194,16 +194,15 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode): vae = vae.first_stage_model decoder = vae.struct_dec load_device = comfy.model_management.get_torch_device() - decoder = comfy.model_patcher.ModelPatcher( - decoder, load_device=load_device, offload_device=comfy.model_management.vae_offload_device() - ) - comfy.model_management.load_model_gpu(decoder) - decoder = decoder.model + offload_device = comfy.model_management.vae_offload_device() + decoder = decoder.to(load_device) samples = samples["samples"] samples = samples.to(load_device) decoded = decoder(samples)>0 - coords = torch.argwhere(decoded)[:, [0, 2, 3, 4]].int() - return IO.NodeOutput(coords) + decoder.to(offload_device) + comfy.model_management.get_offload_stream + out = Types.VOXEL(decoded.squeeze(1).float()) + return IO.NodeOutput(out) class Trellis2Conditioning(IO.ComfyNode): @classmethod @@ -238,7 +237,7 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): node_id="EmptyShapeLatentTrellis2", category="latent/3d", inputs=[ - IO.Mesh.Input("structure_output"), + IO.Voxel.Input("structure_output"), ], outputs=[ IO.Latent.Output(), @@ -247,8 +246,8 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): @classmethod def execute(cls, structure_output): - # i will see what i have to do here - coords = structure_output # or structure_output.coords + decoded = structure_output.data + coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int().unsqueeze(1) in_channels = 32 latent = SparseTensor(feats=torch.randn(coords.shape[0], in_channels), coords=coords) return IO.NodeOutput({"samples": latent, "type": "trellis2"}) @@ -260,7 +259,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): node_id="EmptyTextureLatentTrellis2", category="latent/3d", inputs=[ - IO.Mesh.Input("structure_output"), + 
IO.Voxel.Input("structure_output"), ], outputs=[ IO.Latent.Output(), @@ -271,7 +270,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): def execute(cls, structure_output): # TODO in_channels = 32 - latent = structure_output.replace(feats=torch.randn(structure_output.coords.shape[0], in_channels - structure_output.feats.shape[1])) + latent = structure_output.replace(feats=torch.randn(structure_output.data.shape[0], in_channels - structure_output.feats.shape[1])) return IO.NodeOutput({"samples": latent, "type": "trellis2"}) class EmptyStructureLatentTrellis2(IO.ComfyNode): From 0e239dc39b878b1bc3357fadeece4c76ff335892 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Thu, 12 Feb 2026 23:35:57 +0200 Subject: [PATCH 18/59] fixed attn (couldn't use apply_rope for dino3) --- comfy/image_encoders/dino3.py | 45 +++++++++++++++++------------------ 1 file changed, 22 insertions(+), 23 deletions(-) diff --git a/comfy/image_encoders/dino3.py b/comfy/image_encoders/dino3.py index ef04556da..9cb231e28 100644 --- a/comfy/image_encoders/dino3.py +++ b/comfy/image_encoders/dino3.py @@ -3,7 +3,6 @@ import torch import torch.nn as nn from comfy.ldm.modules.attention import optimized_attention_for_device -from comfy.ldm.flux.math import apply_rope from comfy.image_encoders.dino2 import LayerScale as DINOv3ViTLayerScale class DINOv3ViTMLP(nn.Module): @@ -18,6 +17,26 @@ class DINOv3ViTMLP(nn.Module): def forward(self, x): return self.down_proj(self.act_fn(self.up_proj(x))) +def rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) +def apply_rotary_pos_emb(q, k, cos, sin, **kwargs): + num_tokens = q.shape[-2] + num_patches = sin.shape[-2] + num_prefix_tokens = num_tokens - num_patches + + q_prefix_tokens, q_patches = q.split((num_prefix_tokens, num_patches), dim=-2) + k_prefix_tokens, k_patches = k.split((num_prefix_tokens, num_patches), dim=-2) + + q_patches = 
(q_patches * cos) + (rotate_half(q_patches) * sin) + k_patches = (k_patches * cos) + (rotate_half(k_patches) * sin) + + q = torch.cat((q_prefix_tokens, q_patches), dim=-2) + k = torch.cat((k_prefix_tokens, k_patches), dim=-2) + + return q, k + class DINOv3ViTAttention(nn.Module): def __init__(self, hidden_size, num_attention_heads, device, dtype, operations): super().__init__() @@ -54,28 +73,7 @@ class DINOv3ViTAttention(nn.Module): if position_embeddings is not None: cos, sin = position_embeddings - - num_tokens = query_states.shape[-2] - num_patches = cos.shape[-2] - num_prefix_tokens = num_tokens - num_patches - - q_prefix, q_patches = query_states.split((num_prefix_tokens, num_patches), dim=-2) - k_prefix, k_patches = key_states.split((num_prefix_tokens, num_patches), dim=-2) - - cos = cos[..., :self.head_dim // 2] - sin = sin[..., :self.head_dim // 2] - - f_cis_0 = torch.stack([cos, sin], dim=-1) - f_cis_1 = torch.stack([-sin, cos], dim=-1) - freqs_cis = torch.stack([f_cis_0, f_cis_1], dim=-1) - - while freqs_cis.ndim < q_patches.ndim + 1: - freqs_cis = freqs_cis.unsqueeze(0) - - q_patches, k_patches = apply_rope(q_patches, k_patches, freqs_cis) - - query_states = torch.cat((q_prefix, q_patches), dim=-2) - key_states = torch.cat((k_prefix, k_patches), dim=-2) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) attn = optimized_attention_for_device(query_states.device, mask=False) @@ -83,6 +81,7 @@ class DINOv3ViTAttention(nn.Module): query_states, key_states, value_states, self.num_heads, attention_mask, skip_reshape=True, skip_output_reshape=True ) + attn_output = attn_output.transpose(1, 2) attn_output = attn_output.reshape(batch_size, patches, -1).contiguous() attn_output = self.o_proj(attn_output) From 0e51bee64ff909f5dff90a2782e4973818196624 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 13 Feb 2026 00:10:25 +0200 Subject: [PATCH 19/59] more reliable detection --- 
comfy/ldm/trellis2/model.py | 8 +------- comfy_extras/nodes_trellis2.py | 3 +++ 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 17286a553..760372f5c 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -826,13 +826,7 @@ class Trellis2(nn.Module): def forward(self, x, timestep, context, **kwargs): embeds = kwargs.get("embeds") - if not hasattr(x, "feats"): - mode = "structure_generation" - else: - if x.feats.shape[1] == 32: - mode = "shape_generation" - else: - mode = "texture_generation" + mode = x.generation_mode if mode == "shape_generation": # TODO out = self.img2shape(x, timestep, torch.cat([embeds, torch.empty_like(embeds)])) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 9fd257785..a5c387c1d 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -250,6 +250,7 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int().unsqueeze(1) in_channels = 32 latent = SparseTensor(feats=torch.randn(coords.shape[0], in_channels), coords=coords) + latent.generation_mode = "shape_generation" return IO.NodeOutput({"samples": latent, "type": "trellis2"}) class EmptyTextureLatentTrellis2(IO.ComfyNode): @@ -271,6 +272,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): # TODO in_channels = 32 latent = structure_output.replace(feats=torch.randn(structure_output.data.shape[0], in_channels - structure_output.feats.shape[1])) + latent.generation_mode = "texture_generation" return IO.NodeOutput({"samples": latent, "type": "trellis2"}) class EmptyStructureLatentTrellis2(IO.ComfyNode): @@ -291,6 +293,7 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode): in_channels = 8 resolution = 16 latent = torch.randn(batch_size, in_channels, resolution, resolution, resolution) + latent.generation_mode = "structure_generation" return IO.NodeOutput({"samples": latent, "type": 
"trellis2"}) def simplify_fn(vertices, faces, target=100000): From 92aa058a587900c24ff163e89288295f9d334acf Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 13 Feb 2026 21:05:59 +0200 Subject: [PATCH 20/59] . --- comfy/image_encoders/dino3.py | 2 ++ comfy/ldm/trellis2/model.py | 18 +++++++++++------- comfy_extras/nodes_trellis2.py | 9 +++------ 3 files changed, 16 insertions(+), 13 deletions(-) diff --git a/comfy/image_encoders/dino3.py b/comfy/image_encoders/dino3.py index 9cb231e28..ce6b2edd9 100644 --- a/comfy/image_encoders/dino3.py +++ b/comfy/image_encoders/dino3.py @@ -228,6 +228,8 @@ class DINOv3ViTLayer(nn.Module): class DINOv3ViTModel(nn.Module): def __init__(self, config, dtype, device, operations): super().__init__() + if dtype == torch.float16: + dtype = torch.bfloat16 num_hidden_layers = config["num_hidden_layers"] hidden_size = config["hidden_size"] num_attention_heads = config["num_attention_heads"] diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 760372f5c..76fe8ad19 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -678,10 +678,7 @@ class ModulatedTransformerCrossBlock(nn.Module): return x def forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor: - if self.use_checkpoint: - return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, phases, use_reentrant=False) - else: - return self._forward(x, mod, context, phases) + return self._forward(x, mod, context, phases) class SparseStructureFlowModel(nn.Module): @@ -823,18 +820,25 @@ class Trellis2(nn.Module): self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args) args.pop("out_channels") self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **args) + self.guidance_interval = [0.6, 1.0] + self.guidance_interval_txt = [0.6, 0.9] def 
forward(self, x, timestep, context, **kwargs): embeds = kwargs.get("embeds") - mode = x.generation_mode + mode = kwargs.get("generation_mode") + sigmas = kwargs.get("sigmas")[0].item() + cond = context.chunk(2) + shape_rule = sigmas < self.guidance_interval[0] or sigmas > self.guidance_interval[1] + txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1] + if mode == "shape_generation": # TODO out = self.img2shape(x, timestep, torch.cat([embeds, torch.empty_like(embeds)])) elif mode == "texture_generation": - out = self.shape2txt(x, timestep, context) + out = self.shape2txt(x, timestep, context if not txt_rule else cond) else: # structure timestep = timestep_reshift(timestep) - out = self.structure_model(x, timestep, context) + out = self.structure_model(x, timestep, context if not shape_rule else cond) out.generation_mode = mode return out diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index a5c387c1d..560751091 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -250,8 +250,7 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int().unsqueeze(1) in_channels = 32 latent = SparseTensor(feats=torch.randn(coords.shape[0], in_channels), coords=coords) - latent.generation_mode = "shape_generation" - return IO.NodeOutput({"samples": latent, "type": "trellis2"}) + return IO.NodeOutput({"samples": latent, "type": "trellis2", "generation_mode": "shape_generation"}) class EmptyTextureLatentTrellis2(IO.ComfyNode): @classmethod @@ -272,8 +271,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): # TODO in_channels = 32 latent = structure_output.replace(feats=torch.randn(structure_output.data.shape[0], in_channels - structure_output.feats.shape[1])) - latent.generation_mode = "texture_generation" - return IO.NodeOutput({"samples": latent, "type": "trellis2"}) + return IO.NodeOutput({"samples": latent, "type": "trellis2", 
"generation_mode": "texture_generation"}) class EmptyStructureLatentTrellis2(IO.ComfyNode): @classmethod @@ -293,8 +291,7 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode): in_channels = 8 resolution = 16 latent = torch.randn(batch_size, in_channels, resolution, resolution, resolution) - latent.generation_mode = "structure_generation" - return IO.NodeOutput({"samples": latent, "type": "trellis2"}) + return IO.NodeOutput({"samples": latent, "type": "trellis2", "generation_mode": "structure_generation"}) def simplify_fn(vertices, faces, target=100000): From 91fa563b21a745041c50bb7f7e5038330e01ae38 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Mon, 16 Feb 2026 01:53:53 +0200 Subject: [PATCH 21/59] rewriting conditioning logic + model code addition --- comfy/ldm/trellis2/model.py | 10 ++- comfy_extras/nodes_trellis2.py | 125 ++++++++++++++++----------------- 2 files changed, 70 insertions(+), 65 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 76fe8ad19..4c398294a 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -826,8 +826,11 @@ class Trellis2(nn.Module): def forward(self, x, timestep, context, **kwargs): embeds = kwargs.get("embeds") mode = kwargs.get("generation_mode") - sigmas = kwargs.get("sigmas")[0].item() - cond = context.chunk(2) + transformer_options = kwargs.get("transformer_options") + sigmas = transformer_options.get("sigmas")[0].item() + if sigmas < 1.00001: + timestep *= 1000.0 + cond = context.chunk(2)[1] shape_rule = sigmas < self.guidance_interval[0] or sigmas > self.guidance_interval[1] txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1] @@ -838,6 +841,9 @@ class Trellis2(nn.Module): out = self.shape2txt(x, timestep, context if not txt_rule else cond) else: # structure timestep = timestep_reshift(timestep) + if shape_rule: + x = x[0].unsqueeze(0) + timestep = timestep[0] out = 
self.structure_model(x, timestep, context if not shape_rule else cond) out.generation_mode = mode diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 560751091..4d97129eb 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -4,7 +4,6 @@ import torch from comfy.ldm.trellis2.model import SparseTensor import comfy.model_management import comfy.model_patcher -from torchvision.transforms import ToPILImage, ToTensor, Resize, InterpolationMode shape_slat_normalization = { "mean": torch.tensor([ @@ -36,86 +35,85 @@ tex_slat_normalization = { ])[None] } -def smart_crop_square( - image: torch.Tensor, - background_color=(128, 128, 128), -): +dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1) +dino_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1) + +def smart_crop_square(image, mask, margin_ratio=0.1, bg_color=(128, 128, 128)): + nz = torch.nonzero(mask[0] > 0.5) + if nz.shape[0] == 0: + C, H, W = image.shape + side = max(H, W) + canvas = torch.full((C, side, side), 0.5, device=image.device) # Gray + canvas[:, (side-H)//2:(side-H)//2+H, (side-W)//2:(side-W)//2+W] = image + return canvas + + y_min, x_min = nz.min(dim=0)[0] + y_max, x_max = nz.max(dim=0)[0] + + obj_w, obj_h = x_max - x_min, y_max - y_min + center_x, center_y = (x_min + x_max) / 2, (y_min + y_max) / 2 + + side = int(max(obj_w, obj_h) * (1 + margin_ratio * 2)) + half_side = side / 2 + + x1, y1 = int(center_x - half_side), int(center_y - half_side) + x2, y2 = x1 + side, y1 + side + C, H, W = image.shape - size = max(H, W) - canvas = torch.empty( - (C, size, size), - dtype=image.dtype, - device=image.device - ) + canvas = torch.ones((C, side, side), device=image.device) for c in range(C): - canvas[c].fill_(background_color[c]) - top = (size - H) // 2 - left = (size - W) // 2 - canvas[:, top:top + H, left:left + W] = image + canvas[c] *= (bg_color[c] / 255.0) + + src_x1, src_y1 = max(0, x1), max(0, y1) + src_x2, src_y2 = min(W, x2), 
min(H, y2) + + dst_x1, dst_y1 = max(0, -x1), max(0, -y1) + dst_x2 = dst_x1 + (src_x2 - src_x1) + dst_y2 = dst_y1 + (src_y2 - src_y1) + + canvas[:, dst_y1:dst_y2, dst_x1:dst_x2] = image[:, src_y1:src_y2, src_x1:src_x2] return canvas -def run_conditioning( - model, - image: torch.Tensor, - include_1024: bool = True, - background_color: str = "black", -): - # TODO: should check if normalization was applied in these steps - model = model.model - device = comfy.model_management.intermediate_device() # replaces .cpu() - torch_device = comfy.model_management.get_torch_device() # replaces .cuda() - bg_colors = { - "black": (0, 0, 0), - "gray": (128, 128, 128), - "white": (255, 255, 255), - } - bg_color = bg_colors.get(background_color, (128, 128, 128)) +def run_conditioning(model, image, mask, include_1024 = True, background_color = "black"): + model_internal = model.model + device = comfy.model_management.intermediate_device() + torch_device = comfy.model_management.get_torch_device() - # Convert image to PIL - if image.dim() == 4: - pil_image = (image[0] * 255).clip(0, 255).to(torch.uint8) - else: - pil_image = (image * 255).clip(0, 255).to(torch.uint8) + bg_colors = {"black": (0, 0, 0), "gray": (128, 128, 128), "white": (255, 255, 255)} + bg_rgb = bg_colors.get(background_color, (128, 128, 128)) - pil_image = pil_image.movedim(-1, 0) - pil_image = smart_crop_square(pil_image, background_color=bg_color) + img_t = image[0].movedim(-1, 0).to(torch_device).float() + mask_t = mask[0].to(torch_device).float() + if mask_t.ndim == 2: + mask_t = mask_t.unsqueeze(0) - model.image_size = 512 - def set_image_size(image, image_size=512): - if image.ndim == 3: - image = image.unsqueeze(0) + cropped_img = smart_crop_square(img_t, mask_t, bg_color=bg_rgb) - to_pil = ToPILImage() - to_tensor = ToTensor() - resizer = Resize((image_size, image_size), interpolation=InterpolationMode.LANCZOS) + def prepare_tensor(img, size): + resized = torch.nn.functional.interpolate( + img.unsqueeze(0), 
size=(size, size), mode='bicubic', align_corners=False + ) + return (resized - dino_mean.to(torch_device)) / dino_std.to(torch_device) - pil_img = to_pil(image.squeeze(0)) - resized_pil = resizer(pil_img) - image = to_tensor(resized_pil).unsqueeze(0) - - return image.to(torch_device).float() - - pil_image = set_image_size(pil_image, 512) - cond_512 = model(pil_image)[0] + model_internal.image_size = 512 + input_512 = prepare_tensor(cropped_img, 512) + cond_512 = model_internal(input_512)[0] cond_1024 = None if include_1024: - model.image_size = 1024 - pil_image = set_image_size(pil_image, 1024) - cond_1024 = model(pil_image)[0] - - neg_cond = torch.zeros_like(cond_512) + model_internal.image_size = 1024 + input_1024 = prepare_tensor(cropped_img, 1024) + cond_1024 = model_internal(input_1024)[0] conditioning = { 'cond_512': cond_512.to(device), - 'neg_cond': neg_cond.to(device), + 'neg_cond': torch.zeros_like(cond_512).to(device), } if cond_1024 is not None: conditioning['cond_1024'] = cond_1024.to(device) - preprocessed_tensor = pil_image.to(torch.float32) / 255.0 - preprocessed_tensor = preprocessed_tensor.unsqueeze(0) + preprocessed_tensor = cropped_img.movedim(0, -1).unsqueeze(0).cpu() return conditioning, preprocessed_tensor @@ -213,6 +211,7 @@ class Trellis2Conditioning(IO.ComfyNode): inputs=[ IO.ClipVision.Input("clip_vision_model"), IO.Image.Input("image"), + IO.Mask.Input("mask"), IO.Combo.Input("background_color", options=["black", "gray", "white"], default="black") ], outputs=[ @@ -222,9 +221,9 @@ class Trellis2Conditioning(IO.ComfyNode): ) @classmethod - def execute(cls, clip_vision_model, image, background_color) -> IO.NodeOutput: + def execute(cls, clip_vision_model, image, mask, background_color) -> IO.NodeOutput: # could make 1024 an option - conditioning, _ = run_conditioning(clip_vision_model, image, include_1024=True, background_color=background_color) + conditioning, _ = run_conditioning(clip_vision_model, image, mask, include_1024=True, 
background_color=background_color) embeds = conditioning["cond_1024"] # should add that positive = [[conditioning["cond_512"], {"embeds": embeds}]] negative = [[conditioning["neg_cond"], {"embeds": embeds}]] From c14317d6e0d574253fcda8184470683d4ea9ded6 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Tue, 17 Feb 2026 00:10:48 +0200 Subject: [PATCH 22/59] postprocessing node fixes + model small fixes --- comfy/ldm/trellis2/model.py | 7 ++- comfy_extras/nodes_trellis2.py | 87 ++++++++++++++++------------------ 2 files changed, 47 insertions(+), 47 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 4c398294a..eb410fe8b 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -787,8 +787,10 @@ class SparseStructureFlowModel(nn.Module): return h def timestep_reshift(t_shifted, old_shift=3.0, new_shift=5.0): + t_shifted /= 1000.0 t_linear = t_shifted / (old_shift - t_shifted * (old_shift - 1)) t_new = (new_shift * t_linear) / (1 + (new_shift - 1) * t_linear) + t_new *= 1000.0 return t_new class Trellis2(nn.Module): @@ -841,10 +843,13 @@ class Trellis2(nn.Module): out = self.shape2txt(x, timestep, context if not txt_rule else cond) else: # structure timestep = timestep_reshift(timestep) + orig_bsz = x.shape[0] if shape_rule: x = x[0].unsqueeze(0) - timestep = timestep[0] + timestep = timestep[0].unsqueeze(0) out = self.structure_model(x, timestep, context if not shape_rule else cond) + if shape_rule: + out = out.repeat(orig_bsz, 1, 1, 1, 1) out.generation_mode = mode return out diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 4d97129eb..ad9881db7 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -295,7 +295,7 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode): def simplify_fn(vertices, faces, target=100000): if vertices.shape[0] <= target: - return + return vertices, faces min_feat = 
vertices.min(dim=0)[0] max_feat = vertices.max(dim=0)[0] @@ -334,6 +334,19 @@ def simplify_fn(vertices, faces, target=100000): return final_vertices, final_faces def fill_holes_fn(vertices, faces, max_hole_perimeter=3e-2): + is_batched = vertices.ndim == 3 + if is_batched: + batch_size = vertices.shape[0] + if batch_size > 1: + v_out, f_out = [], [] + for i in range(batch_size): + v, f = fill_holes_fn(vertices[i], faces[i], max_hole_perimeter) + v_out.append(v) + f_out.append(f) + return torch.stack(v_out), torch.stack(f_out) + + vertices = vertices.squeeze(0) + faces = faces.squeeze(0) device = vertices.device orig_vertices = vertices @@ -346,24 +359,23 @@ def fill_holes_fn(vertices, faces, max_hole_perimeter=3e-2): ], dim=0) edges_sorted, _ = torch.sort(edges, dim=1) - unique_edges, counts = torch.unique(edges_sorted, dim=0, return_counts=True) - boundary_mask = counts == 1 boundary_edges_sorted = unique_edges[boundary_mask] if boundary_edges_sorted.shape[0] == 0: - return + if is_batched: + return orig_vertices.unsqueeze(0), orig_faces.unsqueeze(0) + return orig_vertices, orig_faces + max_idx = vertices.shape[0] - _, inverse_indices, counts_packed = torch.unique( - torch.sort(edges, dim=1).values[:, 0] * max_idx + torch.sort(edges, dim=1).values[:, 1], - return_inverse=True, return_counts=True - ) + packed_edges_all = torch.sort(edges, dim=1).values + packed_edges_all = packed_edges_all[:, 0] * max_idx + packed_edges_all[:, 1] - boundary_packed_mask = counts_packed == 1 - is_boundary_edge = boundary_packed_mask[inverse_indices] + packed_boundary = boundary_edges_sorted[:, 0] * max_idx + boundary_edges_sorted[:, 1] + is_boundary_edge = torch.isin(packed_edges_all, packed_boundary) active_boundary_edges = edges[is_boundary_edge] adj = {} @@ -373,78 +385,61 @@ def fill_holes_fn(vertices, faces, max_hole_perimeter=3e-2): loops = [] visited_edges = set() - - possible_starts = list(adj.keys()) - processed_nodes = set() - - for start_node in possible_starts: - if 
start_node in processed_nodes: - continue - - current_loop = [] - curr = start_node - + for start_node in list(adj.keys()): + if start_node in processed_nodes: continue + current_loop, curr = [], start_node while curr in adj: next_node = adj[curr] - if (curr, next_node) in visited_edges: - break - + if (curr, next_node) in visited_edges: break visited_edges.add((curr, next_node)) processed_nodes.add(curr) current_loop.append(curr) - curr = next_node - if curr == start_node: loops.append(current_loop) break - - if len(current_loop) > len(edges_np): - break + if len(current_loop) > len(edges_np): break if not loops: - return + if is_batched: return orig_vertices.unsqueeze(0), orig_faces.unsqueeze(0) + return orig_vertices, orig_faces new_faces = [] - v_offset = vertices.shape[0] - valid_new_verts = [] for loop_indices in loops: - if len(loop_indices) < 3: - continue - + if len(loop_indices) < 3: continue loop_tensor = torch.tensor(loop_indices, dtype=torch.long, device=device) loop_verts = vertices[loop_tensor] - diffs = loop_verts - torch.roll(loop_verts, shifts=-1, dims=0) perimeter = torch.norm(diffs, dim=1).sum() - if perimeter > max_hole_perimeter: - continue + if perimeter > max_hole_perimeter: continue center = loop_verts.mean(dim=0) valid_new_verts.append(center) - c_idx = v_offset v_offset += 1 num_v = len(loop_indices) for i in range(num_v): - v_curr = loop_indices[i] - v_next = loop_indices[(i + 1) % num_v] + v_curr, v_next = loop_indices[i], loop_indices[(i + 1) % num_v] new_faces.append([v_curr, v_next, c_idx]) if len(valid_new_verts) > 0: added_vertices = torch.stack(valid_new_verts, dim=0) added_faces = torch.tensor(new_faces, dtype=torch.long, device=device) + vertices = torch.cat([orig_vertices, added_vertices], dim=0) + faces = torch.cat([orig_faces, added_faces], dim=0) + else: + vertices, faces = orig_vertices, orig_faces - vertices_f = torch.cat([orig_vertices, added_vertices], dim=0) - faces_f = torch.cat([orig_faces, added_faces], dim=0) + if 
is_batched: + return vertices.unsqueeze(0), faces.unsqueeze(0) - return vertices_f, faces_f + return vertices, faces class PostProcessMesh(IO.ComfyNode): @classmethod @@ -454,8 +449,8 @@ class PostProcessMesh(IO.ComfyNode): category="latent/3d", inputs=[ IO.Mesh.Input("mesh"), - IO.Int.Input("simplify", default=100_000, min=0), # max? - IO.Float.Input("fill_holes_perimeter", default=0.03, min=0.0) + IO.Int.Input("simplify", default=100_000, min=0, max=50_000_000), # max? + IO.Float.Input("fill_holes_perimeter", default=0.003, min=0.0, step=0.0001) ], outputs=[ IO.Mesh.Output("output_mesh"), From ff04ef555854f524bce57c8f932500b67282844d Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Wed, 18 Feb 2026 21:52:40 +0200 Subject: [PATCH 23/59] fix run_conditioning --- comfy_extras/nodes_trellis2.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index ad9881db7..bf457135c 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -71,7 +71,14 @@ def smart_crop_square(image, mask, margin_ratio=0.1, bg_color=(128, 128, 128)): dst_x2 = dst_x1 + (src_x2 - src_x1) dst_y2 = dst_y1 + (src_y2 - src_y1) - canvas[:, dst_y1:dst_y2, dst_x1:dst_x2] = image[:, src_y1:src_y2, src_x1:src_x2] + img_crop = image[:, src_y1:src_y2, src_x1:src_x2] + mask_crop = mask[0, src_y1:src_y2, src_x1:src_x2] + + bg_val = torch.tensor(bg_color, device=image.device, dtype=image.dtype).view(-1, 1, 1) / 255.0 + + masked_crop = img_crop * mask_crop + bg_val * (1.0 - mask_crop) + + canvas[:, dst_y1:dst_y2, dst_x1:dst_x2] = masked_crop return canvas From 0a49718194d3f3bf42330cb88f8a1bbb7ade55fe Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Wed, 18 Feb 2026 21:54:05 +0200 Subject: [PATCH 24/59] .. 
--- comfy_extras/nodes_trellis2.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index bf457135c..817769d08 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -394,11 +394,13 @@ def fill_holes_fn(vertices, faces, max_hole_perimeter=3e-2): visited_edges = set() processed_nodes = set() for start_node in list(adj.keys()): - if start_node in processed_nodes: continue + if start_node in processed_nodes: + continue current_loop, curr = [], start_node while curr in adj: next_node = adj[curr] - if (curr, next_node) in visited_edges: break + if (curr, next_node) in visited_edges: + break visited_edges.add((curr, next_node)) processed_nodes.add(curr) current_loop.append(curr) @@ -406,10 +408,12 @@ def fill_holes_fn(vertices, faces, max_hole_perimeter=3e-2): if curr == start_node: loops.append(current_loop) break - if len(current_loop) > len(edges_np): break + if len(current_loop) > len(edges_np): + break if not loops: - if is_batched: return orig_vertices.unsqueeze(0), orig_faces.unsqueeze(0) + if is_batched: + return orig_vertices.unsqueeze(0), orig_faces.unsqueeze(0) return orig_vertices, orig_faces new_faces = [] @@ -417,13 +421,15 @@ def fill_holes_fn(vertices, faces, max_hole_perimeter=3e-2): valid_new_verts = [] for loop_indices in loops: - if len(loop_indices) < 3: continue + if len(loop_indices) < 3: + continue loop_tensor = torch.tensor(loop_indices, dtype=torch.long, device=device) loop_verts = vertices[loop_tensor] diffs = loop_verts - torch.roll(loop_verts, shifts=-1, dims=0) perimeter = torch.norm(diffs, dim=1).sum() - if perimeter > max_hole_perimeter: continue + if perimeter > max_hole_perimeter: + continue center = loop_verts.mean(dim=0) valid_new_verts.append(center) From b5feac202c45f5106ac91cffb361dcb0c411fd5e Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Wed, 18 Feb 
2026 22:01:09 +0200 Subject: [PATCH 25/59] . --- comfy/ldm/trellis2/attention.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/comfy/ldm/trellis2/attention.py b/comfy/ldm/trellis2/attention.py index e6aa50842..e8e401fd7 100644 --- a/comfy/ldm/trellis2/attention.py +++ b/comfy/ldm/trellis2/attention.py @@ -6,7 +6,7 @@ from comfy.ldm.trellis2.vae import VarLenTensor FLASH_ATTN_3_AVA = True try: - import flash_attn_interface as flash_attn_3 + import flash_attn_interface as flash_attn_3 # noqa: F401 except: FLASH_ATTN_3_AVA = False @@ -53,8 +53,6 @@ def scaled_dot_product_attention(*args, **kwargs): elif num_all_args == 3: out = flash_attn_3.flash_attn_func(q, k, v) elif optimized_attention.__name__ == 'attention_pytorch': - if 'sdpa' not in globals(): - from torch.nn.functional import scaled_dot_product_attention as sdpa if num_all_args == 1: q, k, v = qkv.unbind(dim=2) elif num_all_args == 2: From ee2b66a2f2b7d699ccc95d0a5b3bcb7fd5390814 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Thu, 19 Feb 2026 00:48:26 +0200 Subject: [PATCH 26/59] small fixes --- comfy/ldm/trellis2/model.py | 4 ++++ comfy_extras/nodes_trellis2.py | 15 +++++++-------- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index eb410fe8b..b4fc15abc 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -828,6 +828,7 @@ class Trellis2(nn.Module): def forward(self, x, timestep, context, **kwargs): embeds = kwargs.get("embeds") mode = kwargs.get("generation_mode") + coords = kwargs.get("coords") transformer_options = kwargs.get("transformer_options") sigmas = transformer_options.get("sigmas")[0].item() if sigmas < 1.00001: @@ -836,6 +837,9 @@ class Trellis2(nn.Module): shape_rule = sigmas < self.guidance_interval[0] or sigmas > self.guidance_interval[1] txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > 
self.guidance_interval_txt[1] + if mode in ["shape_generation", "texture_generation"]: + x = SparseTensor(feats=x, coords=coords) + if mode == "shape_generation": # TODO out = self.img2shape(x, timestep, torch.cat([embeds, torch.empty_like(embeds)])) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 817769d08..bd250e5f5 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -1,9 +1,7 @@ from typing_extensions import override from comfy_api.latest import ComfyExtension, IO, Types import torch -from comfy.ldm.trellis2.model import SparseTensor import comfy.model_management -import comfy.model_patcher shape_slat_normalization = { "mean": torch.tensor([ @@ -205,7 +203,6 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode): samples = samples.to(load_device) decoded = decoder(samples)>0 decoder.to(offload_device) - comfy.model_management.get_offload_stream out = Types.VOXEL(decoded.squeeze(1).float()) return IO.NodeOutput(out) @@ -253,10 +250,10 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): @classmethod def execute(cls, structure_output): decoded = structure_output.data - coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int().unsqueeze(1) + coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() in_channels = 32 - latent = SparseTensor(feats=torch.randn(coords.shape[0], in_channels), coords=coords) - return IO.NodeOutput({"samples": latent, "type": "trellis2", "generation_mode": "shape_generation"}) + latent = torch.randn(coords.shape[0], in_channels) + return IO.NodeOutput({"samples": latent, "type": "trellis2", "generation_mode": "shape_generation", "coords": coords}) class EmptyTextureLatentTrellis2(IO.ComfyNode): @classmethod @@ -275,9 +272,11 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): @classmethod def execute(cls, structure_output): # TODO + decoded = structure_output.data + coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() in_channels = 32 - latent = 
structure_output.replace(feats=torch.randn(structure_output.data.shape[0], in_channels - structure_output.feats.shape[1])) - return IO.NodeOutput({"samples": latent, "type": "trellis2", "generation_mode": "texture_generation"}) + latent = torch.randn(coords.shape[0], in_channels - structure_output.feats.shape[1]) + return IO.NodeOutput({"samples": latent, "type": "trellis2", "generation_mode": "texture_generation", "coords": coords}) class EmptyStructureLatentTrellis2(IO.ComfyNode): @classmethod From 9243ae347b59cad350e139f4c3bce5b01020f9a4 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Thu, 19 Feb 2026 02:01:26 +0200 Subject: [PATCH 27/59] added conditioning --- comfy_extras/nodes_trellis2.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index bd250e5f5..fc9a15cfa 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -2,6 +2,8 @@ from typing_extensions import override from comfy_api.latest import ComfyExtension, IO, Types import torch import comfy.model_management +from PIL import Image +import numpy as np shape_slat_normalization = { "mean": torch.tensor([ @@ -226,6 +228,31 @@ class Trellis2Conditioning(IO.ComfyNode): @classmethod def execute(cls, clip_vision_model, image, mask, background_color) -> IO.NodeOutput: + + if image.ndim == 4: + image = image[0] + + # TODO + image = Image.fromarray(image.numpy()) + max_size = max(image.size) + scale = min(1, 1024 / max_size) + if scale < 1: + image = image.resize((int(image.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS) + + output_np = np.array(image) + alpha = output_np[:, :, 3] + bbox = np.argwhere(alpha > 0.8 * 255) + bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0]) + center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2 + size = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) + size 
= int(size * 1) + bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2 + output = image.crop(bbox) # type: ignore + output = np.array(output).astype(np.float32) / 255 + output = output[:, :, :3] * output[:, :, 3:4] + + image = torch.tensor(output) + # could make 1024 an option conditioning, _ = run_conditioning(clip_vision_model, image, mask, include_1024=True, background_color=background_color) embeds = conditioning["cond_1024"] # should add that From 6191cd86bfcaa1c18df68d2f0c932212dbc0a64d Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Thu, 19 Feb 2026 22:05:33 +0200 Subject: [PATCH 28/59] trellis2conditioning and a hidden bug --- comfy/ldm/trellis2/model.py | 13 ++++++++++--- comfy_extras/nodes_trellis2.py | 21 +++++---------------- 2 files changed, 15 insertions(+), 19 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index b4fc15abc..8579b0580 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -826,7 +826,12 @@ class Trellis2(nn.Module): self.guidance_interval_txt = [0.6, 0.9] def forward(self, x, timestep, context, **kwargs): + # FIXME: should find a way to distinguish between 512/1024 models + # currently assumes 1024 embeds = kwargs.get("embeds") + _, cond = context.chunk(2) + cond = embeds.chunk(2)[0] + context = torch.cat([torch.zeros_like(cond), cond]) mode = kwargs.get("generation_mode") coords = kwargs.get("coords") transformer_options = kwargs.get("transformer_options") @@ -837,12 +842,13 @@ class Trellis2(nn.Module): shape_rule = sigmas < self.guidance_interval[0] or sigmas > self.guidance_interval[1] txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1] - if mode in ["shape_generation", "texture_generation"]: + not_struct_mode = mode in ["shape_generation", "texture_generation"] + if not_struct_mode: x = SparseTensor(feats=x, coords=coords) if mode == 
"shape_generation": # TODO - out = self.img2shape(x, timestep, torch.cat([embeds, torch.empty_like(embeds)])) + out = self.img2shape(x, timestep, context) elif mode == "texture_generation": out = self.shape2txt(x, timestep, context if not txt_rule else cond) else: # structure @@ -855,5 +861,6 @@ class Trellis2(nn.Module): if shape_rule: out = out.repeat(orig_bsz, 1, 1, 1, 1) - out.generation_mode = mode + if not_struct_mode: + out = out.feats return out diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index fc9a15cfa..1683949a3 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -233,25 +233,14 @@ class Trellis2Conditioning(IO.ComfyNode): image = image[0] # TODO - image = Image.fromarray(image.numpy()) + image = (image.cpu().numpy() * 255).clip(0, 255).astype(np.uint8) + image = Image.fromarray(image) max_size = max(image.size) scale = min(1, 1024 / max_size) if scale < 1: image = image.resize((int(image.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS) - output_np = np.array(image) - alpha = output_np[:, :, 3] - bbox = np.argwhere(alpha > 0.8 * 255) - bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0]) - center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2 - size = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) - size = int(size * 1) - bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2 - output = image.crop(bbox) # type: ignore - output = np.array(output).astype(np.float32) / 255 - output = output[:, :, :3] * output[:, :, 3:4] - - image = torch.tensor(output) + image = torch.tensor(np.array(image)).unsqueeze(0) # could make 1024 an option conditioning, _ = run_conditioning(clip_vision_model, image, mask, include_1024=True, background_color=background_color) @@ -276,7 +265,7 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): @classmethod def execute(cls, structure_output): - decoded = structure_output.data + 
decoded = structure_output.data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() in_channels = 32 latent = torch.randn(coords.shape[0], in_channels) @@ -299,7 +288,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): @classmethod def execute(cls, structure_output): # TODO - decoded = structure_output.data + decoded = structure_output.data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() in_channels = 32 latent = torch.randn(coords.shape[0], in_channels - structure_output.feats.shape[1]) From c5a750205d1a35d5ee8937f8997f8b2c10e37b10 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 20 Feb 2026 17:39:44 +0200 Subject: [PATCH 29/59] . --- comfy/ldm/trellis2/model.py | 14 +++++++++++--- comfy_extras/nodes_trellis2.py | 13 +++++++++---- 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 8579b0580..5ff2a1ce0 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -8,6 +8,7 @@ from comfy.ldm.trellis2.attention import ( ) from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder from comfy.ldm.flux.math import apply_rope, apply_rope1 +import builtins class SparseGELU(nn.GELU): def forward(self, input: VarLenTensor) -> VarLenTensor: @@ -481,6 +482,8 @@ class SLatFlowModel(nn.Module): if isinstance(cond, list): cond = VarLenTensor.from_tensor_list(cond) + dtype = next(self.input_layer.parameters()).dtype + x = x.to(dtype) h = self.input_layer(x) h = manual_cast(h, self.dtype) t_emb = self.t_embedder(t, out_dtype = t.dtype) @@ -832,8 +835,14 @@ class Trellis2(nn.Module): _, cond = context.chunk(2) cond = embeds.chunk(2)[0] context = torch.cat([torch.zeros_like(cond), cond]) - mode = kwargs.get("generation_mode") - coords = kwargs.get("coords") + mode = getattr(builtins, "TRELLIS_MODE", "structure_generation") + coords = getattr(builtins, "TRELLIS_COORDS", None) + if 
coords is not None: + x = x.squeeze(0) + not_struct_mode = True + else: + mode = "structure_generation" + not_struct_mode = False transformer_options = kwargs.get("transformer_options") sigmas = transformer_options.get("sigmas")[0].item() if sigmas < 1.00001: @@ -842,7 +851,6 @@ class Trellis2(nn.Module): shape_rule = sigmas < self.guidance_interval[0] or sigmas > self.guidance_interval[1] txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1] - not_struct_mode = mode in ["shape_generation", "texture_generation"] if not_struct_mode: x = SparseTensor(feats=x, coords=coords) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 1683949a3..c8d84fd23 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -4,6 +4,7 @@ import torch import comfy.model_management from PIL import Image import numpy as np +import builtins shape_slat_normalization = { "mean": torch.tensor([ @@ -268,8 +269,10 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): decoded = structure_output.data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() in_channels = 32 - latent = torch.randn(coords.shape[0], in_channels) - return IO.NodeOutput({"samples": latent, "type": "trellis2", "generation_mode": "shape_generation", "coords": coords}) + latent = torch.randn(1, coords.shape[0], in_channels) + builtins.TRELLIS_MODE = "shape_generation" + builtins.TRELLIS_COORDS = coords + return IO.NodeOutput({"samples": latent, "type": "trellis2"}) class EmptyTextureLatentTrellis2(IO.ComfyNode): @classmethod @@ -292,7 +295,9 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() in_channels = 32 latent = torch.randn(coords.shape[0], in_channels - structure_output.feats.shape[1]) - return IO.NodeOutput({"samples": latent, "type": "trellis2", "generation_mode": "texture_generation", "coords": coords}) + builtins.TRELLIS_MODE = 
"texture_generation" + builtins.TRELLIS_COORDS = coords + return IO.NodeOutput({"samples": latent, "type": "trellis2"}) class EmptyStructureLatentTrellis2(IO.ComfyNode): @classmethod @@ -312,7 +317,7 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode): in_channels = 8 resolution = 16 latent = torch.randn(batch_size, in_channels, resolution, resolution, resolution) - return IO.NodeOutput({"samples": latent, "type": "trellis2", "generation_mode": "structure_generation"}) + return IO.NodeOutput({"samples": latent, "type": "trellis2"}) def simplify_fn(vertices, faces, target=100000): From f3d4125e4904c427d26aa86c84d26b8dba48fe22 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 20 Feb 2026 20:16:49 +0200 Subject: [PATCH 30/59] code rabbit suggestions --- comfy/image_encoders/dino3.py | 6 +++--- comfy/image_encoders/dino3_large.json | 4 ++-- comfy/ldm/trellis2/cumesh.py | 2 +- comfy/ldm/trellis2/model.py | 8 +++----- comfy/model_detection.py | 4 ++-- 5 files changed, 11 insertions(+), 13 deletions(-) diff --git a/comfy/image_encoders/dino3.py b/comfy/image_encoders/dino3.py index ce6b2edd9..e009a7291 100644 --- a/comfy/image_encoders/dino3.py +++ b/comfy/image_encoders/dino3.py @@ -188,7 +188,7 @@ class DINOv3ViTLayer(nn.Module): device, dtype, operations): super().__init__() - self.norm1 = operations.LayerNorm(hidden_size, eps=layer_norm_eps) + self.norm1 = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype) self.attention = DINOv3ViTAttention(hidden_size, num_attention_heads, device=device, dtype=dtype, operations=operations) self.layer_scale1 = DINOv3ViTLayerScale(hidden_size, device=device, dtype=dtype, operations=None) @@ -273,8 +273,8 @@ class DINOv3ViTModel(nn.Module): position_embeddings=position_embeddings, ) - self.norm = self.norm.to(hidden_states.device) - sequence_output = self.norm(hidden_states) + norm = self.norm.to(hidden_states.device) + sequence_output = 
norm(hidden_states) pooled_output = sequence_output[:, 0, :] return sequence_output, None, pooled_output, None diff --git a/comfy/image_encoders/dino3_large.json b/comfy/image_encoders/dino3_large.json index 53f761a25..b37b61dc8 100644 --- a/comfy/image_encoders/dino3_large.json +++ b/comfy/image_encoders/dino3_large.json @@ -18,6 +18,6 @@ "rope_theta": 100.0, "use_gated_mlp": false, "value_bias": true, - "mean": [0.485, 0.456, 0.406], - "std": [0.229, 0.224, 0.225] + "image_mean": [0.485, 0.456, 0.406], + "image_std": [0.229, 0.224, 0.225] } diff --git a/comfy/ldm/trellis2/cumesh.py b/comfy/ldm/trellis2/cumesh.py index 972fb13c3..cb067a32f 100644 --- a/comfy/ldm/trellis2/cumesh.py +++ b/comfy/ldm/trellis2/cumesh.py @@ -6,7 +6,7 @@ from typing import Dict, Callable NO_TRITION = False try: - allow_tf32 = torch.cuda.is_tf32_supported + allow_tf32 = torch.cuda.is_tf32_supported() except Exception: allow_tf32 = False try: diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 5ff2a1ce0..07cf86d30 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -306,10 +306,7 @@ class ModulatedSparseTransformerBlock(nn.Module): return x def forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor: - if self.use_checkpoint: - return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False) - else: - return self._forward(x, mod) + return self._forward(x, mod) class ModulatedSparseTransformerCrossBlock(nn.Module): @@ -486,6 +483,7 @@ class SLatFlowModel(nn.Module): x = x.to(dtype) h = self.input_layer(x) h = manual_cast(h, self.dtype) + t = t.to(dtype) t_emb = self.t_embedder(t, out_dtype = t.dtype) if self.share_mod: t_emb = self.adaLN_modulation(t_emb) @@ -790,7 +788,7 @@ class SparseStructureFlowModel(nn.Module): return h def timestep_reshift(t_shifted, old_shift=3.0, new_shift=5.0): - t_shifted /= 1000.0 + t_shifted = t_shifted / 1000.0 t_linear = t_shifted / (old_shift - t_shifted * (old_shift - 1)) t_new 
= (new_shift * t_linear) / (1 + (new_shift - 1) * t_linear) t_new *= 1000.0 diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 375cb87b1..6cadc8af6 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -117,12 +117,12 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): unet_config["image_model"] = "trellis2" unet_config["init_txt_model"] = False - if '{}model.shape2txt.blocks.29.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys: + if '{}shape2txt.blocks.29.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys: unet_config["init_txt_model"] = True unet_config["resolution"] = 64 if metadata is not None: - if "is_512" in metadata and metadata["metadata"]: + if "is_512" in metadata: unet_config["resolution"] = 32 unet_config["num_heads"] = 12 From b3da8ed4c594744c5a9a8e14577b9a6898ada8f1 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 20 Feb 2026 21:13:13 +0200 Subject: [PATCH 31/59] coderabbit 2 --- comfy/ldm/trellis2/cumesh.py | 7 +++++-- comfy/ldm/trellis2/model.py | 9 ++++----- comfy_extras/nodes_trellis2.py | 37 +++++++++++++++++++++++++--------- 3 files changed, 37 insertions(+), 16 deletions(-) diff --git a/comfy/ldm/trellis2/cumesh.py b/comfy/ldm/trellis2/cumesh.py index cb067a32f..1be8408c6 100644 --- a/comfy/ldm/trellis2/cumesh.py +++ b/comfy/ldm/trellis2/cumesh.py @@ -207,11 +207,14 @@ class TorchHashMap: def lookup_flat(self, flat_keys: torch.Tensor) -> torch.Tensor: flat = flat_keys.long() + if self._n == 0: + return torch.full((flat.shape[0],), self.default_value, device=flat.device, dtype=self.sorted_vals.dtype) idx = torch.searchsorted(self.sorted_keys, flat) - found = (idx < self._n) & (self.sorted_keys[idx] == flat) + idx_safe = torch.clamp(idx, max=self._n - 1) + found = (idx < self._n) & (self.sorted_keys[idx_safe] == flat) out = torch.full((flat.shape[0],), self.default_value, device=flat.device, 
dtype=self.sorted_vals.dtype) if found.any(): - out[found] = self.sorted_vals[idx[found]] + out[found] = self.sorted_vals[idx_safe[found]] return out diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 07cf86d30..ef1c25d33 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -8,7 +8,6 @@ from comfy.ldm.trellis2.attention import ( ) from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder from comfy.ldm.flux.math import apply_rope, apply_rope1 -import builtins class SparseGELU(nn.GELU): def forward(self, input: VarLenTensor) -> VarLenTensor: @@ -829,19 +828,19 @@ class Trellis2(nn.Module): def forward(self, x, timestep, context, **kwargs): # FIXME: should find a way to distinguish between 512/1024 models # currently assumes 1024 + transformer_options = kwargs.get("transformer_options") embeds = kwargs.get("embeds") - _, cond = context.chunk(2) + #_, cond = context.chunk(2) # TODO cond = embeds.chunk(2)[0] context = torch.cat([torch.zeros_like(cond), cond]) - mode = getattr(builtins, "TRELLIS_MODE", "structure_generation") - coords = getattr(builtins, "TRELLIS_COORDS", None) + coords = transformer_options.get("coords", None) + mode = transformer_options.get("generation_mode", "structure_generation") if coords is not None: x = x.squeeze(0) not_struct_mode = True else: mode = "structure_generation" not_struct_mode = False - transformer_options = kwargs.get("transformer_options") sigmas = transformer_options.get("sigmas")[0].item() if sigmas < 1.00001: timestep *= 1000.0 diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index c8d84fd23..14f5484d6 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -4,7 +4,6 @@ import torch import comfy.model_management from PIL import Image import numpy as np -import builtins shape_slat_normalization = { "mean": torch.tensor([ @@ -258,21 +257,31 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): category="latent/3d", 
inputs=[ IO.Voxel.Input("structure_output"), + IO.Model.Input("model") ], outputs=[ IO.Latent.Output(), + IO.Model.Output() ] ) @classmethod - def execute(cls, structure_output): + def execute(cls, structure_output, model): decoded = structure_output.data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() in_channels = 32 latent = torch.randn(1, coords.shape[0], in_channels) - builtins.TRELLIS_MODE = "shape_generation" - builtins.TRELLIS_COORDS = coords - return IO.NodeOutput({"samples": latent, "type": "trellis2"}) + model = model.clone() + if "transformer_options" not in model.model_options: + model.model_options = {} + else: + model.model_options = model.model_options.copy() + + model.model_options["transformer_options"] = model.model_options["transformer_options"].copy() + + model.model_options["transformer_options"]["coords"] = coords + model.model_options["transformer_options"]["generation_mode"] = "shape_generation" + return IO.NodeOutput({"samples": latent, "type": "trellis2"}, model) class EmptyTextureLatentTrellis2(IO.ComfyNode): @classmethod @@ -285,19 +294,29 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): ], outputs=[ IO.Latent.Output(), + IO.Model.Output() ] ) @classmethod - def execute(cls, structure_output): + def execute(cls, structure_output, model): # TODO decoded = structure_output.data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() in_channels = 32 latent = torch.randn(coords.shape[0], in_channels - structure_output.feats.shape[1]) - builtins.TRELLIS_MODE = "texture_generation" - builtins.TRELLIS_COORDS = coords - return IO.NodeOutput({"samples": latent, "type": "trellis2"}) + model = model.clone() + if "transformer_options" not in model.model_options: + model.model_options = {} + else: + model.model_options = model.model_options.copy() + + model.model_options["transformer_options"] = model.model_options["transformer_options"].copy() + + 
model.model_options["transformer_options"]["coords"] = coords + model.model_options["transformer_options"]["generation_mode"] = "shape_generation" + return IO.NodeOutput({"samples": latent, "type": "trellis2"}, model) + class EmptyStructureLatentTrellis2(IO.ComfyNode): @classmethod From 1fde60b2bc67d39ab4177c1aabc828350245d5f9 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 20 Feb 2026 22:04:37 +0200 Subject: [PATCH 32/59] debugging --- comfy/image_encoders/dino3.py | 2 +- comfy/ldm/trellis2/model.py | 94 +++++++++------------------------- comfy_extras/nodes_trellis2.py | 4 +- 3 files changed, 27 insertions(+), 73 deletions(-) diff --git a/comfy/image_encoders/dino3.py b/comfy/image_encoders/dino3.py index e009a7291..3ec7f8a04 100644 --- a/comfy/image_encoders/dino3.py +++ b/comfy/image_encoders/dino3.py @@ -175,7 +175,7 @@ class DINOv3ViTEmbeddings(nn.Module): cls_token = self.cls_token.expand(batch_size, -1, -1) register_tokens = self.register_tokens.expand(batch_size, -1, -1) - device = patch_embeddings + device = patch_embeddings.device cls_token = cls_token.to(device) register_tokens = register_tokens.to(device) embeddings = torch.cat([cls_token, register_tokens, patch_embeddings], dim=1) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index ef1c25d33..7f16c4d41 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -201,6 +201,8 @@ class SparseMultiHeadAttention(nn.Module): def forward(self, x: SparseTensor, context: Optional[Union[VarLenTensor, torch.Tensor]] = None) -> SparseTensor: if self._type == "self": + dtype = next(self.to_qkv.parameters()).dtype + x = x.to(dtype) qkv = self._linear(self.to_qkv, x) qkv = self._fused_pre(qkv, num_fused=3) if self.qk_rms_norm or self.use_rope: @@ -243,71 +245,6 @@ class SparseMultiHeadAttention(nn.Module): h = self._linear(self.to_out, h) return h -class ModulatedSparseTransformerBlock(nn.Module): - def __init__( - 
self, - channels: int, - num_heads: int, - mlp_ratio: float = 4.0, - attn_mode: Literal["full", "swin"] = "full", - window_size: Optional[int] = None, - shift_window: Optional[Tuple[int, int, int]] = None, - use_checkpoint: bool = False, - use_rope: bool = False, - rope_freq: Tuple[float, float] = (1.0, 10000.0), - qk_rms_norm: bool = False, - qkv_bias: bool = True, - share_mod: bool = False, - ): - super().__init__() - self.use_checkpoint = use_checkpoint - self.share_mod = share_mod - self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) - self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) - self.attn = SparseMultiHeadAttention( - channels, - num_heads=num_heads, - attn_mode=attn_mode, - window_size=window_size, - shift_window=shift_window, - qkv_bias=qkv_bias, - use_rope=use_rope, - rope_freq=rope_freq, - qk_rms_norm=qk_rms_norm, - ) - self.mlp = SparseFeedForwardNet( - channels, - mlp_ratio=mlp_ratio, - ) - if not share_mod: - self.adaLN_modulation = nn.Sequential( - nn.SiLU(), - nn.Linear(channels, 6 * channels, bias=True) - ) - else: - self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5) - - def _forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor: - if self.share_mod: - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1) - else: - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) - h = x.replace(self.norm1(x.feats)) - h = h * (1 + scale_msa) + shift_msa - h = self.attn(h) - h = h * gate_msa - x = x + h - h = x.replace(self.norm2(x.feats)) - h = h * (1 + scale_mlp) + shift_mlp - h = self.mlp(h) - h = h * gate_mlp - x = x + h - return x - - def forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor: - return self._forward(x, mod) - - class ModulatedSparseTransformerCrossBlock(nn.Module): """ Sparse Transformer cross-attention block (MSA + MCA + FFN) 
with adaptive layer norm conditioning. @@ -483,15 +420,13 @@ class SLatFlowModel(nn.Module): h = self.input_layer(x) h = manual_cast(h, self.dtype) t = t.to(dtype) - t_emb = self.t_embedder(t, out_dtype = t.dtype) + t_embedder = self.t_embedder.to(dtype) + t_emb = t_embedder(t, out_dtype = t.dtype) if self.share_mod: t_emb = self.adaLN_modulation(t_emb) t_emb = manual_cast(t_emb, self.dtype) cond = manual_cast(cond, self.dtype) - if self.pe_mode == "ape": - pe = self.pos_embedder(h.coords[:, 1:]) - h = h + manual_cast(pe, self.dtype) for block in self.blocks: h = block(h, t_emb, cond) @@ -849,7 +784,24 @@ class Trellis2(nn.Module): txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1] if not_struct_mode: - x = SparseTensor(feats=x, coords=coords) + B, N, C = x.shape + + if mode == "shape_generation": + feats_flat = x.reshape(-1, C) + + # 3. inflate coords [N, 4] -> [B*N, 4] + coords_list = [] + for i in range(B): + c = coords.clone() + c[:, 0] = i + coords_list.append(c) + + batched_coords = torch.cat(coords_list, dim=0) + else: # TODO: texture + # may remove the else if texture doesn't require special handling + batched_coords = coords + feats_flat = x + x = SparseTensor(feats=feats_flat, coords=batched_coords) if mode == "shape_generation": # TODO @@ -868,4 +820,6 @@ class Trellis2(nn.Module): if not_struct_mode: out = out.feats + if mode == "shape_generation": + out = out.view(B, N, -1) return out diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 14f5484d6..623430b9e 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -238,9 +238,9 @@ class Trellis2Conditioning(IO.ComfyNode): max_size = max(image.size) scale = min(1, 1024 / max_size) if scale < 1: - image = image.resize((int(image.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS) + image = image.resize((int(image.width * scale), int(image.height * scale)), Image.Resampling.LANCZOS) - image = 
torch.tensor(np.array(image)).unsqueeze(0) + image = torch.tensor(np.array(image)).unsqueeze(0).float() / 255 # could make 1024 an option conditioning, _ = run_conditioning(clip_vision_model, image, mask, include_1024=True, background_color=background_color) From 253ee4c02c8396555cff8bc07c7983fa2ccd0074 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Sun, 22 Feb 2026 01:25:10 +0200 Subject: [PATCH 33/59] fixes --- comfy/image_encoders/dino3.py | 5 +---- comfy/ldm/trellis2/cumesh.py | 8 +++++--- comfy/ldm/trellis2/model.py | 11 +++++++---- comfy/ldm/trellis2/vae.py | 2 ++ comfy_extras/nodes_trellis2.py | 17 +++++++++++------ 5 files changed, 26 insertions(+), 17 deletions(-) diff --git a/comfy/image_encoders/dino3.py b/comfy/image_encoders/dino3.py index 3ec7f8a04..40ece19ed 100644 --- a/comfy/image_encoders/dino3.py +++ b/comfy/image_encoders/dino3.py @@ -44,9 +44,6 @@ class DINOv3ViTAttention(nn.Module): self.num_heads = num_attention_heads self.head_dim = self.embed_dim // self.num_heads - self.scaling = self.head_dim**-0.5 - self.is_causal = False - self.k_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=False, device=device, dtype=dtype) # key_bias = False self.v_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype) @@ -251,7 +248,7 @@ class DINOv3ViTModel(nn.Module): intermediate_size=intermediate_size,num_attention_heads = num_attention_heads, dtype=dtype, device=device, operations=operations) for _ in range(num_hidden_layers)]) - self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps, dtype=dtype, device=device) + self.norm = operations.LayerNorm(hidden_size, eps=layer_norm_eps, dtype=dtype, device=device) def get_input_embeddings(self): return self.embeddings.patch_embeddings diff --git a/comfy/ldm/trellis2/cumesh.py b/comfy/ldm/trellis2/cumesh.py index 1be8408c6..8f677ce24 100644 --- a/comfy/ldm/trellis2/cumesh.py +++ b/comfy/ldm/trellis2/cumesh.py 
@@ -4,7 +4,7 @@ import math import torch from typing import Dict, Callable -NO_TRITION = False +NO_TRITON = False try: allow_tf32 = torch.cuda.is_tf32_supported() except Exception: @@ -115,8 +115,8 @@ try: allow_tf32=allow_tf32, ) return output -except: - NO_TRITION = True +except Exception: + NO_TRITON = True def compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device): # offsets in same order as CUDA kernel @@ -364,6 +364,8 @@ def neighbor_map_post_process_for_masked_implicit_gemm_2( def sparse_submanifold_conv3d(feats, coords, shape, weight, bias, neighbor_cache, dilation): + if NO_TRITON: # TODO + raise RuntimeError("sparse_submanifold_conv3d requires Triton, which is not available.") if len(shape) == 5: N, C, W, H, D = shape else: diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 7f16c4d41..7cc1c1678 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -697,8 +697,6 @@ class SparseStructureFlowModel(nn.Module): def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor: x = x.view(x.shape[0], self.in_channels, *[self.resolution] * 3) - assert [*x.shape] == [x.shape[0], self.in_channels, *[self.resolution] * 3], \ - f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}" h = x.view(*x.shape[:2], -1).permute(0, 2, 1).contiguous() @@ -746,7 +744,8 @@ class Trellis2(nn.Module): super().__init__() self.dtype = dtype # for some reason it passes num_heads = -1 - num_heads = 12 + if num_heads == -1: + num_heads = 12 args = { "out_channels":out_channels, "num_blocks":num_blocks, "cond_channels" :cond_channels, "model_channels":model_channels, "num_heads":num_heads, "mlp_ratio": mlp_ratio, "share_mod": share_mod, @@ -763,8 +762,10 @@ class Trellis2(nn.Module): def forward(self, x, timestep, context, **kwargs): # FIXME: should find a way to distinguish between 512/1024 models # currently assumes 1024 - transformer_options = 
kwargs.get("transformer_options") + transformer_options = kwargs.get("transformer_options", {}) embeds = kwargs.get("embeds") + if embeds is None: + raise ValueError("Trellis2.forward requires 'embeds' in kwargs") #_, cond = context.chunk(2) # TODO cond = embeds.chunk(2)[0] context = torch.cat([torch.zeros_like(cond), cond]) @@ -807,6 +808,8 @@ class Trellis2(nn.Module): # TODO out = self.img2shape(x, timestep, context) elif mode == "texture_generation": + if self.shape2txt is None: + raise ValueError("Checkpoint for Trellis2 doesn't include texture generation!") out = self.shape2txt(x, timestep, context if not txt_rule else cond) else: # structure timestep = timestep_reshift(timestep) diff --git a/comfy/ldm/trellis2/vae.py b/comfy/ldm/trellis2/vae.py index 57bf78346..c6ea5deb2 100644 --- a/comfy/ldm/trellis2/vae.py +++ b/comfy/ldm/trellis2/vae.py @@ -1522,6 +1522,8 @@ class Vae(nn.Module): return self.shape_dec(slat, return_subs=True) def decode_tex_slat(self, slat, subs): + if self.txt_dec is None: + raise ValueError("Checkpoint doesn't include texture model") return self.txt_dec(slat, guide_subs=subs) * 0.5 + 0.5 @torch.no_grad() diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 623430b9e..c688c343d 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -1,9 +1,11 @@ from typing_extensions import override from comfy_api.latest import ComfyExtension, IO, Types -import torch +import torch.nn.functional as TF import comfy.model_management from PIL import Image import numpy as np +import torch +import copy shape_slat_normalization = { "mean": torch.tensor([ @@ -145,11 +147,11 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): def execute(cls, samples, vae, resolution): vae = vae.first_stage_model samples = samples["samples"] - std = shape_slat_normalization["std"] - mean = shape_slat_normalization["mean"] + std = shape_slat_normalization["std"].to(samples) + mean = shape_slat_normalization["mean"].to(samples) 
samples = samples * std + mean - mesh, subs = vae.decode_shape_slat(resolution, samples) + mesh, subs = vae.decode_shape_slat(samples, resolution) return IO.NodeOutput(mesh, subs) class VaeDecodeTextureTrellis(IO.ComfyNode): @@ -172,8 +174,8 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): def execute(cls, samples, vae, shape_subs): vae = vae.first_stage_model samples = samples["samples"] - std = tex_slat_normalization["std"] - mean = tex_slat_normalization["mean"] + std = tex_slat_normalization["std"].to(samples) + mean = tex_slat_normalization["mean"].to(samples) samples = samples * std + mean mesh = vae.decode_tex_slat(samples, shape_subs) @@ -239,6 +241,8 @@ class Trellis2Conditioning(IO.ComfyNode): scale = min(1, 1024 / max_size) if scale < 1: image = image.resize((int(image.width * scale), int(image.height * scale)), Image.Resampling.LANCZOS) + new_h, new_w = int(mask.shape[-2] * scale), int(mask.shape[-1] * scale) + mask = TF.interpolate(mask.unsqueeze(0).float(), size=(new_h, new_w), mode='nearest').squeeze(0) image = torch.tensor(np.array(image)).unsqueeze(0).float() / 255 @@ -510,6 +514,7 @@ class PostProcessMesh(IO.ComfyNode): ) @classmethod def execute(cls, mesh, simplify, fill_holes_perimeter): + mesh = copy.deepcopy(mesh) verts, faces = mesh.vertices, mesh.faces if fill_holes_perimeter != 0.0: From c9f5c788a733c68029be33141197706b07f9f669 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Sun, 22 Feb 2026 23:47:49 +0200 Subject: [PATCH 34/59] bunch of fixes --- comfy/image_encoders/dino3.py | 3 +- comfy/ldm/trellis2/attention.py | 82 ++++++++++++++++++++++----------- comfy/ldm/trellis2/cumesh.py | 11 ++++- comfy/ldm/trellis2/model.py | 2 +- comfy/ldm/trellis2/vae.py | 21 +++++---- comfy_extras/nodes_trellis2.py | 18 ++++---- 6 files changed, 86 insertions(+), 51 deletions(-) diff --git a/comfy/image_encoders/dino3.py b/comfy/image_encoders/dino3.py index 40ece19ed..1bf404498 100644 --- 
a/comfy/image_encoders/dino3.py +++ b/comfy/image_encoders/dino3.py @@ -2,6 +2,7 @@ import math import torch import torch.nn as nn +import comfy.model_management from comfy.ldm.modules.attention import optimized_attention_for_device from comfy.image_encoders.dino2 import LayerScale as DINOv3ViTLayerScale @@ -225,7 +226,7 @@ class DINOv3ViTLayer(nn.Module): class DINOv3ViTModel(nn.Module): def __init__(self, config, dtype, device, operations): super().__init__() - if dtype == torch.float16: + if dtype == torch.float16 and comfy.model_management.should_use_bf16(device, prioritize_performance=False): dtype = torch.bfloat16 num_hidden_layers = config["num_hidden_layers"] hidden_size = config["hidden_size"] diff --git a/comfy/ldm/trellis2/attention.py b/comfy/ldm/trellis2/attention.py index e8e401fd7..19de93b96 100644 --- a/comfy/ldm/trellis2/attention.py +++ b/comfy/ldm/trellis2/attention.py @@ -3,12 +3,56 @@ import math from comfy.ldm.modules.attention import optimized_attention from typing import Tuple, Union, List from comfy.ldm.trellis2.vae import VarLenTensor +import comfy.ops + + +# replica of the seedvr2 code +def var_attn_arg(kwargs): + cu_seqlens_q = kwargs.get("cu_seqlens_q", None) + max_seqlen_q = kwargs.get("max_seqlen_q", None) + cu_seqlens_k = kwargs.get("cu_seqlens_k", cu_seqlens_q) or kwargs.get("cu_seqlens_kv", cu_seqlens_q) + max_seqlen_k = kwargs.get("max_seqlen_k", max_seqlen_q) or kwargs.get("max_kv_seqlen", max_seqlen_q) + assert cu_seqlens_q is not None, "cu_seqlens_q shouldn't be None when var_length is True" + return cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k + +def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): + var_length = True + if var_length: + cu_seqlens_q, cu_seqlens_k, _, _ = var_attn_arg(kwargs) + if not skip_reshape: + # assumes 2D q, k,v [total_tokens, embed_dim] + total_tokens, embed_dim = q.shape + head_dim = embed_dim // heads + q = 
q.view(total_tokens, heads, head_dim) + k = k.view(k.shape[0], heads, head_dim) + v = v.view(v.shape[0], heads, head_dim) + + b = q.size(0) + dim_head = q.shape[-1] + q = torch.nested.nested_tensor_from_jagged(q, offsets=cu_seqlens_q.long()) + k = torch.nested.nested_tensor_from_jagged(k, offsets=cu_seqlens_k.long()) + v = torch.nested.nested_tensor_from_jagged(v, offsets=cu_seqlens_k.long()) + + mask = None + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + if mask is not None: + if mask.ndim == 2: + mask = mask.unsqueeze(0) + if mask.ndim == 3: + mask = mask.unsqueeze(1) + + out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) + if var_length: + return out.contiguous().transpose(1, 2).values() + if not skip_output_reshape: + out = ( + out.transpose(1, 2).reshape(b, -1, heads * dim_head) + ) + return out -FLASH_ATTN_3_AVA = True -try: - import flash_attn_interface as flash_attn_3 # noqa: F401 -except: - FLASH_ATTN_3_AVA = False # TODO repalce with optimized attention def scaled_dot_product_attention(*args, **kwargs): @@ -40,18 +84,10 @@ def scaled_dot_product_attention(*args, **kwargs): k, v = kv.unbind(dim=2) #out = xops.memory_efficient_attention(q, k, v) out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True) - elif optimized_attention.__name__ == 'attention_flash' and not FLASH_ATTN_3_AVA: + elif optimized_attention.__name__ == 'attention_flash': if num_all_args == 2: k, v = kv.unbind(dim=2) out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True) - elif optimized_attention.__name__ == 'attention_flash': # TODO - if 'flash_attn_3' not in globals(): - import flash_attn_interface as flash_attn_3 - if num_all_args == 2: - k, v = kv.unbind(dim=2) - out = flash_attn_3.flash_attn_func(q, k, v) - elif num_all_args == 3: - out = flash_attn_3.flash_attn_func(q, k, v) elif optimized_attention.__name__ == 'attention_pytorch': if 
num_all_args == 1: q, k, v = qkv.unbind(dim=2) @@ -238,24 +274,16 @@ def sparse_scaled_dot_product_attention(*args, **kwargs): out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) elif num_all_args == 3: out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) - elif optimized_attention.__name__ == 'flash_attn_3': # TODO - if 'flash_attn_3' not in globals(): - import flash_attn_interface as flash_attn_3 + + elif optimized_attention.__name__ == "attention_pytorch": cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device) + if num_all_args in [2, 3]: + cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device) if num_all_args == 1: q, k, v = qkv.unbind(dim=1) - cu_seqlens_kv = cu_seqlens_q.clone() - max_q_seqlen = max_kv_seqlen = max(q_seqlen) elif num_all_args == 2: k, v = kv.unbind(dim=1) - cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device) - max_q_seqlen = max(q_seqlen) - max_kv_seqlen = max(kv_seqlen) - elif num_all_args == 3: - cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device) - max_q_seqlen = max(q_seqlen) - max_kv_seqlen = max(kv_seqlen) - out = flash_attn_3.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_q_seqlen, max_kv_seqlen) + out = attention_pytorch(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) if s is not None: return s.replace(out) diff --git a/comfy/ldm/trellis2/cumesh.py b/comfy/ldm/trellis2/cumesh.py index 8f677ce24..ea069a465 100644 --- a/comfy/ldm/trellis2/cumesh.py +++ b/comfy/ldm/trellis2/cumesh.py @@ -3,6 +3,7 @@ import math import torch from typing import Dict, Callable +import logging NO_TRITON = False try: @@ -366,6 +367,10 @@ def 
neighbor_map_post_process_for_masked_implicit_gemm_2( def sparse_submanifold_conv3d(feats, coords, shape, weight, bias, neighbor_cache, dilation): if NO_TRITON: # TODO raise RuntimeError("sparse_submanifold_conv3d requires Triton, which is not available.") + if feats.shape[0] == 0: + logging.warning("Found feats to be empty!") + Co = weight.shape[0] + return torch.empty((0, Co), device=feats.device, dtype=feats.dtype), None if len(shape) == 5: N, C, W, H, D = shape else: @@ -427,9 +432,11 @@ class Voxel: voxel_size: float, coords: torch.Tensor = None, attrs: torch.Tensor = None, - layout: Dict = {}, - device: torch.device = 'cuda' + layout = None, + device: torch.device = None ): + if layout is None: + layout = {} self.origin = torch.tensor(origin, dtype=torch.float32, device=device) self.voxel_size = voxel_size self.coords = coords diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 7cc1c1678..52242b4e0 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -630,7 +630,6 @@ class SparseStructureFlowModel(nn.Module): mlp_ratio: float = 4, pe_mode: Literal["ape", "rope"] = "rope", rope_freq: Tuple[float, float] = (1.0, 10000.0), - dtype: str = 'float32', use_checkpoint: bool = False, share_mod: bool = False, initialization: str = 'vanilla', @@ -638,6 +637,7 @@ class SparseStructureFlowModel(nn.Module): qk_rms_norm_cross: bool = False, operations=None, device = None, + dtype = torch.float32, **kwargs ): super().__init__() diff --git a/comfy/ldm/trellis2/vae.py b/comfy/ldm/trellis2/vae.py index c6ea5deb2..36e2f3df5 100644 --- a/comfy/ldm/trellis2/vae.py +++ b/comfy/ldm/trellis2/vae.py @@ -1004,7 +1004,6 @@ class SparseUnetVaeEncoder(nn.Module): self.model_channels = model_channels self.num_blocks = num_blocks self.dtype = torch.float16 if use_fp16 else torch.float32 - self.dtype = torch.float16 if use_fp16 else torch.float32 self.input_layer = SparseLinear(in_channels, model_channels[0]) self.to_latent = 
SparseLinear(model_channels[-1], 2 * latent_channels) @@ -1247,24 +1246,26 @@ def flexible_dual_grid_to_mesh( hashmap_builder=None, # optional callable for building/caching a TorchHashMap ): - if not hasattr(flexible_dual_grid_to_mesh, "edge_neighbor_voxel_offset"): + device = coords.device + if not hasattr(flexible_dual_grid_to_mesh, "edge_neighbor_voxel_offset") \ + or flexible_dual_grid_to_mesh.edge_neighbor_voxel_offset.device != device: flexible_dual_grid_to_mesh.edge_neighbor_voxel_offset = torch.tensor([ [[0, 0, 0], [0, 0, 1], [0, 1, 1], [0, 1, 0]], # x-axis [[0, 0, 0], [1, 0, 0], [1, 0, 1], [0, 0, 1]], # y-axis [[0, 0, 0], [0, 1, 0], [1, 1, 0], [1, 0, 0]], # z-axis - ], dtype=torch.int, device=coords.device).unsqueeze(0) - if not hasattr(flexible_dual_grid_to_mesh, "quad_split_1"): - flexible_dual_grid_to_mesh.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=coords.device, requires_grad=False) - if not hasattr(flexible_dual_grid_to_mesh, "quad_split_2"): - flexible_dual_grid_to_mesh.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=coords.device, requires_grad=False) - if not hasattr(flexible_dual_grid_to_mesh, "quad_split_train"): - flexible_dual_grid_to_mesh.quad_split_train = torch.tensor([0, 1, 4, 1, 2, 4, 2, 3, 4, 3, 0, 4], dtype=torch.long, device=coords.device, requires_grad=False) + ], dtype=torch.int, device=device).unsqueeze(0) + if not hasattr(flexible_dual_grid_to_mesh, "quad_split_1") or flexible_dual_grid_to_mesh.quad_split_1.device != device: + flexible_dual_grid_to_mesh.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=device, requires_grad=False) + if not hasattr(flexible_dual_grid_to_mesh, "quad_split_2") or flexible_dual_grid_to_mesh.quad_split_2.device != device: + flexible_dual_grid_to_mesh.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=device, requires_grad=False) + if not hasattr(flexible_dual_grid_to_mesh, "quad_split_train") or 
flexible_dual_grid_to_mesh.quad_split_train.device != device: + flexible_dual_grid_to_mesh.quad_split_train = torch.tensor([0, 1, 4, 1, 2, 4, 2, 3, 4, 3, 0, 4], dtype=torch.long, device=device, requires_grad=False) # AABB if isinstance(aabb, (list, tuple)): aabb = np.array(aabb) if isinstance(aabb, np.ndarray): - aabb = torch.tensor(aabb, dtype=torch.float32, device=coords.device) + aabb = torch.tensor(aabb, dtype=torch.float32, device=device) # Voxel size if voxel_size is not None: diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index c688c343d..2b44a19eb 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -276,12 +276,11 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): in_channels = 32 latent = torch.randn(1, coords.shape[0], in_channels) model = model.clone() - if "transformer_options" not in model.model_options: - model.model_options = {} + model.model_options = model.model_options.copy() + if "transformer_options" in model.model_options: + model.model_options["transformer_options"] = model.model_options["transformer_options"].copy() else: - model.model_options = model.model_options.copy() - - model.model_options["transformer_options"] = model.model_options["transformer_options"].copy() + model.model_options["transformer_options"] = {} model.model_options["transformer_options"]["coords"] = coords model.model_options["transformer_options"]["generation_mode"] = "shape_generation" @@ -310,12 +309,11 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): in_channels = 32 latent = torch.randn(coords.shape[0], in_channels - structure_output.feats.shape[1]) model = model.clone() - if "transformer_options" not in model.model_options: - model.model_options = {} + model.model_options = model.model_options.copy() + if "transformer_options" in model.model_options: + model.model_options["transformer_options"] = model.model_options["transformer_options"].copy() else: - model.model_options = model.model_options.copy() - 
- model.model_options["transformer_options"] = model.model_options["transformer_options"].copy() + model.model_options["transformer_options"] = {} model.model_options["transformer_options"]["coords"] = coords model.model_options["transformer_options"]["generation_mode"] = "shape_generation" From 2a27c3b41738b375d2820ddde3dc512cb8f8d2b4 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Mon, 23 Feb 2026 01:40:56 +0200 Subject: [PATCH 35/59] progressing --- comfy/image_encoders/dino3.py | 5 ++++- comfy/ldm/trellis2/attention.py | 17 ++++++++++++++--- comfy/ldm/trellis2/model.py | 10 +++++----- comfy_extras/nodes_trellis2.py | 11 ++++++++++- 4 files changed, 33 insertions(+), 10 deletions(-) diff --git a/comfy/image_encoders/dino3.py b/comfy/image_encoders/dino3.py index 1bf404498..ff17d78d6 100644 --- a/comfy/image_encoders/dino3.py +++ b/comfy/image_encoders/dino3.py @@ -226,8 +226,11 @@ class DINOv3ViTLayer(nn.Module): class DINOv3ViTModel(nn.Module): def __init__(self, config, dtype, device, operations): super().__init__() - if dtype == torch.float16 and comfy.model_management.should_use_bf16(device, prioritize_performance=False): + use_bf16 = comfy.model_management.should_use_bf16(device, prioritize_performance=True) + if dtype == torch.float16 and use_bf16: dtype = torch.bfloat16 + elif dtype == torch.float16 and not use_bf16: + dtype = torch.float32 num_hidden_layers = config["num_hidden_layers"] hidden_size = config["hidden_size"] num_attention_heads = config["num_attention_heads"] diff --git a/comfy/ldm/trellis2/attention.py b/comfy/ldm/trellis2/attention.py index 19de93b96..0b9c12294 100644 --- a/comfy/ldm/trellis2/attention.py +++ b/comfy/ldm/trellis2/attention.py @@ -10,8 +10,8 @@ import comfy.ops def var_attn_arg(kwargs): cu_seqlens_q = kwargs.get("cu_seqlens_q", None) max_seqlen_q = kwargs.get("max_seqlen_q", None) - cu_seqlens_k = kwargs.get("cu_seqlens_k", cu_seqlens_q) or kwargs.get("cu_seqlens_kv", 
cu_seqlens_q) - max_seqlen_k = kwargs.get("max_seqlen_k", max_seqlen_q) or kwargs.get("max_kv_seqlen", max_seqlen_q) + cu_seqlens_k = kwargs.get("cu_seqlens_kv", cu_seqlens_q) + max_seqlen_k = kwargs.get("max_kv_seqlen", max_seqlen_q) assert cu_seqlens_q is not None, "cu_seqlens_q shouldn't be None when var_length is True" return cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k @@ -183,6 +183,7 @@ def calc_window_partition( def sparse_scaled_dot_product_attention(*args, **kwargs): + q=None arg_names_dict = { 1: ['qkv'], 2: ['q', 'kv'], @@ -250,6 +251,12 @@ def sparse_scaled_dot_product_attention(*args, **kwargs): k = k.reshape(N * L, H, CI) # [T_KV, H, Ci] v = v.reshape(N * L, H, CO) # [T_KV, H, Co] + # TODO: change + if q is not None: + heads = q + else: + heads = qkv + heads = heads.shape[2] if optimized_attention.__name__ == 'attention_xformers': if 'xops' not in globals(): import xformers.ops as xops @@ -279,11 +286,15 @@ def sparse_scaled_dot_product_attention(*args, **kwargs): cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device) if num_all_args in [2, 3]: cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device) + else: + cu_seqlens_kv = cu_seqlens_q if num_all_args == 1: q, k, v = qkv.unbind(dim=1) elif num_all_args == 2: k, v = kv.unbind(dim=1) - out = attention_pytorch(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) + out = attention_pytorch(q, k, v, heads=heads,cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, max_seqlen_q=max(q_seqlen), max_kv_seqlen=max(kv_seqlen), + skip_reshape=True, skip_output_reshape=True) if s is not None: return s.replace(out) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 52242b4e0..a565ec37e 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -232,6 +232,8 @@ class SparseMultiHeadAttention(nn.Module): else: q = 
self._linear(self.to_q, x) q = self._reshape_chs(q, (self.num_heads, -1)) + dtype = next(self.to_kv.parameters()).dtype + context = context.to(dtype) kv = self._linear(self.to_kv, context) kv = self._fused_pre(kv, num_fused=2) if self.qk_rms_norm: @@ -760,15 +762,13 @@ class Trellis2(nn.Module): self.guidance_interval_txt = [0.6, 0.9] def forward(self, x, timestep, context, **kwargs): - # FIXME: should find a way to distinguish between 512/1024 models - # currently assumes 1024 transformer_options = kwargs.get("transformer_options", {}) embeds = kwargs.get("embeds") if embeds is None: raise ValueError("Trellis2.forward requires 'embeds' in kwargs") - #_, cond = context.chunk(2) # TODO - cond = embeds.chunk(2)[0] - context = torch.cat([torch.zeros_like(cond), cond]) + is_1024 = self.img2shape.resolution == 1024 + if is_1024: + context = embeds coords = transformer_options.get("coords", None) mode = transformer_options.get("generation_mode", "structure_generation") if coords is not None: diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 2b44a19eb..1b43f7f62 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -2,6 +2,7 @@ from typing_extensions import override from comfy_api.latest import ComfyExtension, IO, Types import torch.nn.functional as TF import comfy.model_management +from comfy.utils import ProgressBar from PIL import Image import numpy as np import torch @@ -250,7 +251,7 @@ class Trellis2Conditioning(IO.ComfyNode): conditioning, _ = run_conditioning(clip_vision_model, image, mask, include_1024=True, background_color=background_color) embeds = conditioning["cond_1024"] # should add that positive = [[conditioning["cond_512"], {"embeds": embeds}]] - negative = [[conditioning["neg_cond"], {"embeds": embeds}]] + negative = [[conditioning["neg_cond"], {"embeds": torch.zeros_like(embeds)}]] return IO.NodeOutput(positive, negative) class EmptyShapeLatentTrellis2(IO.ComfyNode): @@ -512,15 +513,23 @@ 
class PostProcessMesh(IO.ComfyNode): ) @classmethod def execute(cls, mesh, simplify, fill_holes_perimeter): + bar = ProgressBar(2) mesh = copy.deepcopy(mesh) verts, faces = mesh.vertices, mesh.faces if fill_holes_perimeter != 0.0: verts, faces = fill_holes_fn(verts, faces, max_hole_perimeter=fill_holes_perimeter) + bar.update(1) + else: + bar.update(1) if simplify != 0: verts, faces = simplify_fn(verts, faces, simplify) + bar.update(1) + else: + bar.update(1) + # potentially adding laplacian smoothing mesh.vertices = verts mesh.faces = faces From a2c8a7aab5e33df51942454f59fa16bb19c8a9c7 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Tue, 24 Feb 2026 14:53:54 +0200 Subject: [PATCH 36/59] . --- comfy/ldm/trellis2/model.py | 4 ++-- comfy_extras/nodes_trellis2.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index a565ec37e..fb5276f94 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -772,7 +772,7 @@ class Trellis2(nn.Module): coords = transformer_options.get("coords", None) mode = transformer_options.get("generation_mode", "structure_generation") if coords is not None: - x = x.squeeze(0) + x = x.squeeze(-1).transpose(1, 2) not_struct_mode = True else: mode = "structure_generation" @@ -824,5 +824,5 @@ class Trellis2(nn.Module): if not_struct_mode: out = out.feats if mode == "shape_generation": - out = out.view(B, N, -1) + out = out.view(B, N, -1).transpose(1, 2).unsqueeze(-1) return out diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 1b43f7f62..f40ff5161 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -275,7 +275,8 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): decoded = structure_output.data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() in_channels = 32 - latent = torch.randn(1, coords.shape[0], in_channels) 
+ # image like format + latent = torch.randn(1, in_channels, coords.shape[0], 1) model = model.clone() model.model_options = model.model_options.copy() if "transformer_options" in model.model_options: From f31c2e1d1d7359f05995804f636444313b794068 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Wed, 25 Feb 2026 04:26:33 +0200 Subject: [PATCH 37/59] vae shape decode fixes --- comfy/ldm/trellis2/model.py | 2 +- comfy/ldm/trellis2/vae.py | 30 ++++++++++++++++++++---------- comfy/sd.py | 4 ++-- comfy_extras/nodes_trellis2.py | 15 ++++++++++++--- 4 files changed, 35 insertions(+), 16 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index fb5276f94..45740faea 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -802,7 +802,7 @@ class Trellis2(nn.Module): # may remove the else if texture doesn't require special handling batched_coords = coords feats_flat = x - x = SparseTensor(feats=feats_flat, coords=batched_coords) + x = SparseTensor(feats=feats_flat, coords=batched_coords.to(torch.int32)) if mode == "shape_generation": # TODO diff --git a/comfy/ldm/trellis2/vae.py b/comfy/ldm/trellis2/vae.py index 36e2f3df5..0b1975092 100644 --- a/comfy/ldm/trellis2/vae.py +++ b/comfy/ldm/trellis2/vae.py @@ -1,11 +1,12 @@ import math import torch -import torch.nn as nn -from typing import List, Any, Dict, Optional, overload, Union, Tuple -from fractions import Fraction -import torch.nn.functional as F -from dataclasses import dataclass import numpy as np +import torch.nn as nn +import comfy.model_management +import torch.nn.functional as F +from fractions import Fraction +from dataclasses import dataclass +from typing import List, Any, Dict, Optional, overload, Union, Tuple from comfy.ldm.trellis2.cumesh import TorchHashMap, Mesh, MeshWithVoxel, sparse_submanifold_conv3d @@ -55,11 +56,12 @@ def sparse_conv3d_forward(self, x): Co, Kd, Kh, Kw, Ci = self.weight.shape 
neighbor_cache_key = f'SubMConv3d_neighbor_cache_{Kw}x{Kh}x{Kd}_dilation{self.dilation}' neighbor_cache = x.get_spatial_cache(neighbor_cache_key) + x = x.to(self.weight.dtype).to(self.weight.device) out, neighbor_cache_ = sparse_submanifold_conv3d( x.feats, x.coords, - torch.Size([*x.shape, *x.spatial_shape]), + x.spatial_shape, self.weight, self.bias, neighbor_cache, @@ -100,7 +102,8 @@ class SparseConvNeXtBlock3d(nn.Module): def _forward(self, x): h = self.conv(x) - h = h.replace(self.norm(h.feats)) + norm = self.norm.to(torch.float32) + h = h.replace(norm(h.feats)) h = h.replace(self.mlp(h.feats)) return h + x @@ -208,13 +211,15 @@ class SparseResBlockC2S3d(nn.Module): def forward(self, x, subdiv = None): if self.pred_subdiv: subdiv = self.to_subdiv(x) - h = x.replace(self.norm1(x.feats)) + norm1 = self.norm1.to(torch.float32) + norm2 = self.norm2.to(torch.float32) + h = x.replace(norm1(x.feats)) h = h.replace(F.silu(h.feats)) h = self.conv1(h) subdiv_binarized = subdiv.replace(subdiv.feats > 0) if subdiv is not None else None h = self.updown(h, subdiv_binarized) x = self.updown(x, subdiv_binarized) - h = h.replace(self.norm2(h.feats)) + h = h.replace(norm2(h.feats)) h = h.replace(F.silu(h.feats)) h = self.conv2(h) h = h + self.skip_connection(x) @@ -1139,6 +1144,9 @@ class SparseUnetVaeDecoder(nn.Module): def forward(self, x: SparseTensor, guide_subs: Optional[List[SparseTensor]] = None, return_subs: bool = False) -> SparseTensor: + dtype = next(self.from_latent.parameters()).dtype + device = next(self.from_latent.parameters()).device + x.feats = x.feats.to(dtype).to(device) h = self.from_latent(x) h = h.type(self.dtype) subs = [] @@ -1152,7 +1160,7 @@ class SparseUnetVaeDecoder(nn.Module): h = block(h, subdiv=guide_subs[i] if guide_subs is not None else None) else: h = block(h) - h = h.type(x.dtype) + h = h.type(x.feats.dtype) h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) h = self.output_layer(h) if return_subs: @@ -1520,6 +1528,8 @@ class 
Vae(nn.Module): def decode_shape_slat(self, slat, resolution: int): self.shape_dec.set_resolution(resolution) + device = comfy.model_management.get_torch_device() + self.shape_dec = self.shape_dec.to(device) return self.shape_dec(slat, return_subs=True) def decode_tex_slat(self, slat, subs): diff --git a/comfy/sd.py b/comfy/sd.py index fecd16c88..f9898b0de 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -513,8 +513,8 @@ class VAE: init_txt_model = True self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] # TODO - self.memory_used_decode = lambda shape, dtype: (6500 * shape[2] * shape[3]) * model_management.dtype_size(dtype) - self.memory_used_encode = lambda shape, dtype: (6500 * shape[2] * shape[3]) * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (2500 * shape[2] * shape[3]) * model_management.dtype_size(dtype) + self.memory_used_encode = lambda shape, dtype: (2500 * shape[2] * shape[3]) * model_management.dtype_size(dtype) self.first_stage_model = comfy.ldm.trellis2.vae.Vae(init_txt_model) elif "decoder.conv_in.weight" in sd: if sd['decoder.conv_in.weight'].shape[1] == 64: diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index f40ff5161..96510e916 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -1,8 +1,9 @@ from typing_extensions import override from comfy_api.latest import ComfyExtension, IO, Types +from comfy.ldm.trellis2.vae import SparseTensor +from comfy.utils import ProgressBar import torch.nn.functional as TF import comfy.model_management -from comfy.utils import ProgressBar from PIL import Image import numpy as np import torch @@ -135,6 +136,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): category="latent/3d", inputs=[ IO.Latent.Input("samples"), + IO.Voxel.Input("structure_output"), IO.Vae.Input("vae"), IO.Int.Input("resolution", tooltip="Shape Generation Resolution"), ], @@ -145,11 +147,14 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): ) 
@classmethod - def execute(cls, samples, vae, resolution): + def execute(cls, samples, structure_output, vae, resolution): vae = vae.first_stage_model + decoded = structure_output.data.unsqueeze(1) + coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() samples = samples["samples"] std = shape_slat_normalization["std"].to(samples) mean = shape_slat_normalization["mean"].to(samples) + samples = SparseTensor(feats = samples, coords=coords) samples = samples * std + mean mesh, subs = vae.decode_shape_slat(samples, resolution) @@ -163,6 +168,7 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): category="latent/3d", inputs=[ IO.Latent.Input("samples"), + IO.Voxel.Input("structure_output"), IO.Vae.Input("vae"), IO.AnyType.Input("shape_subs"), ], @@ -172,11 +178,14 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): ) @classmethod - def execute(cls, samples, vae, shape_subs): + def execute(cls, samples, structure_output, vae, shape_subs): vae = vae.first_stage_model + decoded = structure_output.data.unsqueeze(1) + coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() samples = samples["samples"] std = tex_slat_normalization["std"].to(samples) mean = tex_slat_normalization["mean"].to(samples) + samples = SparseTensor(feats = samples, coords=coords) samples = samples * std + mean mesh = vae.decode_tex_slat(samples, shape_subs) From 39270fdca941fee1d5a4df0c930e28c90505a0ad Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Wed, 25 Feb 2026 23:54:03 +0200 Subject: [PATCH 38/59] removed unnecessary code + optimizations + progres --- comfy/ldm/trellis2/cumesh.py | 177 +++++++++------------------------ comfy/ldm/trellis2/vae.py | 144 +++------------------------ comfy_extras/nodes_trellis2.py | 18 +++- 3 files changed, 75 insertions(+), 264 deletions(-) diff --git a/comfy/ldm/trellis2/cumesh.py b/comfy/ldm/trellis2/cumesh.py index ea069a465..047e785ff 100644 --- a/comfy/ldm/trellis2/cumesh.py +++ 
b/comfy/ldm/trellis2/cumesh.py @@ -2,7 +2,7 @@ import math import torch -from typing import Dict, Callable +from typing import Callable import logging NO_TRITON = False @@ -201,13 +201,13 @@ class TorchHashMap: def __init__(self, keys: torch.Tensor, values: torch.Tensor, default_value: int): device = keys.device # use long for searchsorted - self.sorted_keys, order = torch.sort(keys.long()) - self.sorted_vals = values.long()[order] + self.sorted_keys, order = torch.sort(keys.to(torch.long)) + self.sorted_vals = values.to(torch.long)[order] self.default_value = torch.tensor(default_value, dtype=torch.long, device=device) self._n = self.sorted_keys.numel() def lookup_flat(self, flat_keys: torch.Tensor) -> torch.Tensor: - flat = flat_keys.long() + flat = flat_keys.to(torch.long) if self._n == 0: return torch.full((flat.shape[0],), self.default_value, device=flat.device, dtype=self.sorted_vals.dtype) idx = torch.searchsorted(self.sorted_keys, flat) @@ -225,44 +225,35 @@ def neighbor_map_post_process_for_masked_implicit_gemm_1(neighbor_map): device = neighbor_map.device N, V = neighbor_map.shape + sentinel = UINT32_SENTINEL - neigh = neighbor_map.to(torch.long) - sentinel = torch.tensor(UINT32_SENTINEL, dtype=torch.long, device=device) - - - neigh_map_T = neigh.t().reshape(-1) - + neigh_map_T = neighbor_map.t().reshape(-1) neigh_mask_T = (neigh_map_T != sentinel).to(torch.int32) - mask = (neigh != sentinel).to(torch.long) + mask = (neighbor_map != sentinel).to(torch.long) + gray_code = torch.zeros(N, dtype=torch.long, device=device) - powers = (1 << torch.arange(V, dtype=torch.long, device=device)) + for v in range(V): + gray_code |= (mask[:, v] << v) - gray_long = (mask * powers).sum(dim=1) - - gray_code = gray_long.to(torch.int32) - - binary_long = gray_long.clone() + binary_code = gray_code.clone() for v in range(1, V): - binary_long ^= (gray_long >> v) - binary_code = binary_long.to(torch.int32) + binary_code ^= (gray_code >> v) sorted_idx = 
torch.argsort(binary_code) - prefix_sum_neighbor_mask = torch.cumsum(neigh_mask_T.to(torch.int32), dim=0) # (V*N,) + prefix_sum_neighbor_mask = torch.cumsum(neigh_mask_T, dim=0) total_valid_signal = int(prefix_sum_neighbor_mask[-1].item()) if prefix_sum_neighbor_mask.numel() > 0 else 0 if total_valid_signal > 0: + pos = torch.nonzero(neigh_mask_T, as_tuple=True)[0] + to = (prefix_sum_neighbor_mask[pos] - 1).long() + valid_signal_i = torch.empty((total_valid_signal,), dtype=torch.long, device=device) valid_signal_o = torch.empty((total_valid_signal,), dtype=torch.long, device=device) - pos = torch.nonzero(neigh_mask_T, as_tuple=True)[0] - - to = (prefix_sum_neighbor_mask[pos] - 1).to(torch.long) - valid_signal_i[to] = (pos % N).to(torch.long) - valid_signal_o[to] = neigh_map_T[pos].to(torch.long) else: valid_signal_i = torch.empty((0,), dtype=torch.long, device=device) @@ -272,9 +263,7 @@ def neighbor_map_post_process_for_masked_implicit_gemm_1(neighbor_map): seg[0] = 0 if V > 0: idxs = (torch.arange(1, V + 1, device=device, dtype=torch.long) * N) - 1 - seg[1:] = prefix_sum_neighbor_mask[idxs].to(torch.long) - else: - pass + seg[1:] = prefix_sum_neighbor_mask[idxs] return gray_code, sorted_idx, valid_signal_i, valid_signal_o, seg @@ -295,40 +284,41 @@ def _popcount_int32_tensor(x: torch.Tensor) -> torch.Tensor: def neighbor_map_post_process_for_masked_implicit_gemm_2( - gray_code: torch.Tensor, # [N], int32-like (non-negative) - sorted_idx: torch.Tensor, # [N], long (indexing into gray_code) + gray_code: torch.Tensor, + sorted_idx: torch.Tensor, block_size: int ): device = gray_code.device N = gray_code.numel() - - # num of blocks (same as CUDA) num_blocks = (N + block_size - 1) // block_size - # Ensure dtypes - gray_long = gray_code.to(torch.int64) # safer to OR in 64-bit then cast - sorted_idx = sorted_idx.to(torch.long) - - # 1) Group gray_code by blocks and compute OR across each block - # pad the last block with zeros if necessary so we can reshape pad = 
num_blocks * block_size - N if pad > 0: - pad_vals = torch.zeros((pad,), dtype=torch.int64, device=device) - gray_padded = torch.cat([gray_long[sorted_idx], pad_vals], dim=0) + pad_vals = torch.zeros((pad,), dtype=torch.int32, device=device) + gray_padded = torch.cat([gray_code[sorted_idx], pad_vals], dim=0) else: - gray_padded = gray_long[sorted_idx] + gray_padded = gray_code[sorted_idx] - # reshape to (num_blocks, block_size) and compute bitwise_or across dim=1 - gray_blocks = gray_padded.view(num_blocks, block_size) # each row = block entries - # reduce with bitwise_or - reduced_code = gray_blocks[:, 0].clone() - for i in range(1, block_size): - reduced_code |= gray_blocks[:, i] - reduced_code = reduced_code.to(torch.int32) # match CUDA int32 + gray_blocks = gray_padded.view(num_blocks, block_size) + + reduced_code = gray_blocks + while reduced_code.shape[1] > 1: + half = reduced_code.shape[1] // 2 + remainder = reduced_code.shape[1] % 2 + + left = reduced_code[:, :half * 2:2] + right = reduced_code[:, 1:half * 2:2] + merged = left | right + + if remainder: + reduced_code = torch.cat([merged, reduced_code[:, -1:]], dim=1) + else: + reduced_code = merged + + reduced_code = reduced_code.squeeze(1) + + seglen_counts = _popcount_int32_tensor(reduced_code).to(torch.int32) - # 2) compute seglen (popcount per reduced_code) and seg (prefix sum) - seglen_counts = _popcount_int32_tensor(reduced_code.to(torch.int64)).to(torch.int32) # [num_blocks] - # seg: length num_blocks+1, seg[0]=0, seg[i+1]=cumsum(seglen_counts) up to i seg = torch.empty((num_blocks + 1,), dtype=torch.int32, device=device) seg[0] = 0 if num_blocks > 0: @@ -336,30 +326,20 @@ def neighbor_map_post_process_for_masked_implicit_gemm_2( total = int(seg[-1].item()) - # 3) scatter — produce valid_kernel_idx as concatenated ascending set-bit positions for each reduced_code row if total == 0: - valid_kernel_idx = torch.empty((0,), dtype=torch.int32, device=device) - return valid_kernel_idx, seg + return 
torch.empty((0,), dtype=torch.int32, device=device), seg - max_val = int(reduced_code.max().item()) - V = max_val.bit_length() if max_val > 0 else 0 - # If you know V externally, pass it instead or set here explicitly. + V = int(reduced_code.max().item()).bit_length() if reduced_code.max() > 0 else 0 if V == 0: - # no bits set anywhere - valid_kernel_idx = torch.empty((0,), dtype=torch.int32, device=device) - return valid_kernel_idx, seg + return torch.empty((0,), dtype=torch.int32, device=device), seg - # build mask of shape (num_blocks, V): True where bit is set - bit_pos = torch.arange(0, V, dtype=torch.int64, device=device) # [V] - # shifted = reduced_code[:, None] >> bit_pos[None, :] - shifted = reduced_code.to(torch.int64).unsqueeze(1) >> bit_pos.unsqueeze(0) - bits = (shifted & 1).to(torch.bool) # (num_blocks, V) + bit_pos = torch.arange(0, V, dtype=torch.int32, device=device) + shifted = reduced_code.unsqueeze(1) >> bit_pos.unsqueeze(0) + bits = (shifted & 1).to(torch.bool) positions = bit_pos.unsqueeze(0).expand(num_blocks, V) - - valid_positions = positions[bits] - valid_kernel_idx = valid_positions.to(torch.int32).contiguous() + valid_kernel_idx = positions[bits].to(torch.int32).contiguous() return valid_kernel_idx, seg @@ -425,35 +405,6 @@ def sparse_submanifold_conv3d(feats, coords, shape, weight, bias, neighbor_cache return out, neighbor -class Voxel: - def __init__( - self, - origin: list, - voxel_size: float, - coords: torch.Tensor = None, - attrs: torch.Tensor = None, - layout = None, - device: torch.device = None - ): - if layout is None: - layout = {} - self.origin = torch.tensor(origin, dtype=torch.float32, device=device) - self.voxel_size = voxel_size - self.coords = coords - self.attrs = attrs - self.layout = layout - self.device = device - - @property - def position(self): - return (self.coords + 0.5) * self.voxel_size + self.origin[None, :] - - def split_attrs(self): - return { - k: self.attrs[:, self.layout[k]] - for k in self.layout - } - 
class Mesh: def __init__(self, vertices, @@ -480,35 +431,3 @@ class Mesh: def cpu(self): return self.to('cpu') - -class MeshWithVoxel(Mesh, Voxel): - def __init__(self, - vertices: torch.Tensor, - faces: torch.Tensor, - origin: list, - voxel_size: float, - coords: torch.Tensor, - attrs: torch.Tensor, - voxel_shape: torch.Size, - layout: Dict = {}, - ): - self.vertices = vertices.float() - self.faces = faces.int() - self.origin = torch.tensor(origin, dtype=torch.float32, device=self.device) - self.voxel_size = voxel_size - self.coords = coords - self.attrs = attrs - self.voxel_shape = voxel_shape - self.layout = layout - - def to(self, device, non_blocking=False): - return MeshWithVoxel( - self.vertices.to(device, non_blocking=non_blocking), - self.faces.to(device, non_blocking=non_blocking), - self.origin.tolist(), - self.voxel_size, - self.coords.to(device, non_blocking=non_blocking), - self.attrs.to(device, non_blocking=non_blocking), - self.voxel_shape, - self.layout, - ) diff --git a/comfy/ldm/trellis2/vae.py b/comfy/ldm/trellis2/vae.py index 0b1975092..2a18c496a 100644 --- a/comfy/ldm/trellis2/vae.py +++ b/comfy/ldm/trellis2/vae.py @@ -7,7 +7,7 @@ import torch.nn.functional as F from fractions import Fraction from dataclasses import dataclass from typing import List, Any, Dict, Optional, overload, Union, Tuple -from comfy.ldm.trellis2.cumesh import TorchHashMap, Mesh, MeshWithVoxel, sparse_submanifold_conv3d +from comfy.ldm.trellis2.cumesh import TorchHashMap, Mesh, sparse_submanifold_conv3d def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor: @@ -210,6 +210,8 @@ class SparseResBlockC2S3d(nn.Module): def forward(self, x, subdiv = None): if self.pred_subdiv: + dtype = next(self.to_subdiv.parameters()).dtype + x = x.to(dtype) subdiv = self.to_subdiv(x) norm1 = self.norm1.to(torch.float32) norm2 = self.norm2.to(torch.float32) @@ -987,114 +989,7 @@ def convert_module_to_f16(l): for p in l.parameters(): p.data = p.data.half() - - -class 
SparseUnetVaeEncoder(nn.Module): - """ - Sparse Swin Transformer Unet VAE model. - """ - def __init__( - self, - in_channels: int, - model_channels: List[int], - latent_channels: int, - num_blocks: List[int], - block_type: List[str], - down_block_type: List[str], - block_args: List[Dict[str, Any]], - use_fp16: bool = False, - ): - super().__init__() - self.in_channels = in_channels - self.model_channels = model_channels - self.num_blocks = num_blocks - self.dtype = torch.float16 if use_fp16 else torch.float32 - - self.input_layer = SparseLinear(in_channels, model_channels[0]) - self.to_latent = SparseLinear(model_channels[-1], 2 * latent_channels) - - self.blocks = nn.ModuleList([]) - for i in range(len(num_blocks)): - self.blocks.append(nn.ModuleList([])) - for j in range(num_blocks[i]): - self.blocks[-1].append( - globals()[block_type[i]]( - model_channels[i], - **block_args[i], - ) - ) - if i < len(num_blocks) - 1: - self.blocks[-1].append( - globals()[down_block_type[i]]( - model_channels[i], - model_channels[i+1], - **block_args[i], - ) - ) - - @property - def device(self) -> torch.device: - return next(self.parameters()).device - - def forward(self, x: SparseTensor, sample_posterior=False, return_raw=False): - h = self.input_layer(x) - h = h.type(self.dtype) - for i, res in enumerate(self.blocks): - for j, block in enumerate(res): - h = block(h) - h = h.type(x.dtype) - h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) - h = self.to_latent(h) - - # Sample from the posterior distribution - mean, logvar = h.feats.chunk(2, dim=-1) - if sample_posterior: - std = torch.exp(0.5 * logvar) - z = mean + std * torch.randn_like(std) - else: - z = mean - z = h.replace(z) - - if return_raw: - return z, mean, logvar - else: - return z - - - -class FlexiDualGridVaeEncoder(SparseUnetVaeEncoder): - def __init__( - self, - model_channels: List[int], - latent_channels: int, - num_blocks: List[int], - block_type: List[str], - down_block_type: List[str], - block_args: 
List[Dict[str, Any]], - use_fp16: bool = False, - ): - super().__init__( - 6, - model_channels, - latent_channels, - num_blocks, - block_type, - down_block_type, - block_args, - use_fp16, - ) - - def forward(self, vertices: SparseTensor, intersected: SparseTensor, sample_posterior=False, return_raw=False): - x = vertices.replace(torch.cat([ - vertices.feats - 0.5, - intersected.feats.float() - 0.5, - ], dim=1)) - return super().forward(x, sample_posterior, return_raw) - class SparseUnetVaeDecoder(nn.Module): - """ - Sparse Swin Transformer Unet VAE model. - """ def __init__( self, out_channels: int, @@ -1218,10 +1113,10 @@ class FlexiDualGridVaeDecoder(SparseUnetVaeDecoder): N = coords.shape[0] # compute flat keys for all coords (prepend batch 0 same as original code) b = torch.zeros((N,), dtype=torch.long, device=device) - x, y, z = coords[:, 0].long(), coords[:, 1].long(), coords[:, 2].long() + x, y, z = coords[:, 0].to(torch.int32), coords[:, 1].to(torch.int32), coords[:, 2].to(torch.int32) W, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item()) flat_keys = b * (W * H * D) + x * (H * D) + y * D + z - values = torch.arange(N, dtype=torch.long, device=device) + values = torch.arange(N, dtype=torch.int32, device=device) DEFAULT_VAL = 0xffffffff # sentinel used in original code return TorchHashMap(flat_keys, values, DEFAULT_VAL) @@ -1295,13 +1190,12 @@ def flexible_dual_grid_to_mesh( # Extract mesh N = dual_vertices.shape[0] - mesh_vertices = (coords.float() + dual_vertices) / (2 * N) - 0.5 if hashmap_builder is None: # build local TorchHashMap device = coords.device b = torch.zeros((N,), dtype=torch.long, device=device) - x, y, z = coords[:, 0].long(), coords[:, 1].long(), coords[:, 2].long() + x, y, z = coords[:, 0].to(torch.int32), coords[:, 1].to(torch.int32), coords[:, 2].to(torch.int32) W, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item()) flat_keys = b * (W * H * D) + x * (H * D) + y * D + z 
values = torch.arange(N, dtype=torch.long, device=device) @@ -1316,9 +1210,9 @@ def flexible_dual_grid_to_mesh( M = connected_voxel.shape[0] # flatten connected voxel coords and lookup conn_flat_b = torch.zeros((M * 4,), dtype=torch.long, device=coords.device) - conn_x = connected_voxel.reshape(-1, 3)[:, 0].long() - conn_y = connected_voxel.reshape(-1, 3)[:, 1].long() - conn_z = connected_voxel.reshape(-1, 3)[:, 2].long() + conn_x = connected_voxel.reshape(-1, 3)[:, 0].to(torch.int32) + conn_y = connected_voxel.reshape(-1, 3)[:, 1].to(torch.int32) + conn_z = connected_voxel.reshape(-1, 3)[:, 2].to(torch.int32) W, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item()) conn_flat = conn_flat_b * (W * H * D) + conn_x * (H * D) + conn_y * D + conn_z @@ -1526,17 +1420,18 @@ class Vae(nn.Module): channels=[512, 128, 32], ) + @torch.no_grad() def decode_shape_slat(self, slat, resolution: int): self.shape_dec.set_resolution(resolution) - device = comfy.model_management.get_torch_device() - self.shape_dec = self.shape_dec.to(device) return self.shape_dec(slat, return_subs=True) + @torch.no_grad() def decode_tex_slat(self, slat, subs): if self.txt_dec is None: raise ValueError("Checkpoint doesn't include texture model") return self.txt_dec(slat, guide_subs=subs) * 0.5 + 0.5 + # shouldn't be called (placeholder) @torch.no_grad() def decode( self, @@ -1546,17 +1441,4 @@ class Vae(nn.Module): ): meshes, subs = self.decode_shape_slat(shape_slat, resolution) tex_voxels = self.decode_tex_slat(tex_slat, subs) - out_mesh = [] - for m, v in zip(meshes, tex_voxels): - out_mesh.append( - MeshWithVoxel( - m.vertices, m.faces, - origin = [-0.5, -0.5, -0.5], - voxel_size = 1 / resolution, - coords = v.coords[:, 1:], - attrs = v.feats, - voxel_shape = torch.Size([*v.shape, *v.spatial_shape]), - layout=self.pbr_attr_layout - ) - ) - return out_mesh + return tex_voxels diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 
96510e916..e781d35e3 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -1,7 +1,7 @@ from typing_extensions import override from comfy_api.latest import ComfyExtension, IO, Types from comfy.ldm.trellis2.vae import SparseTensor -from comfy.utils import ProgressBar +from comfy.utils import ProgressBar, lanczos import torch.nn.functional as TF import comfy.model_management from PIL import Image @@ -102,9 +102,7 @@ def run_conditioning(model, image, mask, include_1024 = True, background_color = cropped_img = smart_crop_square(img_t, mask_t, bg_color=bg_rgb) def prepare_tensor(img, size): - resized = torch.nn.functional.interpolate( - img.unsqueeze(0), size=(size, size), mode='bicubic', align_corners=False - ) + resized = lanczos(img.unsqueeze(0), size, size) return (resized - dino_mean.to(torch_device)) / dino_std.to(torch_device) model_internal.image_size = 512 @@ -148,10 +146,16 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): @classmethod def execute(cls, samples, structure_output, vae, resolution): + + patcher = vae.patcher + device = comfy.model_management.get_torch_device() + comfy.model_management.load_model_gpu(patcher) + vae = vae.first_stage_model decoded = structure_output.data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() samples = samples["samples"] + samples = samples.squeeze(-1).transpose(1, 2).to(device) std = shape_slat_normalization["std"].to(samples) mean = shape_slat_normalization["mean"].to(samples) samples = SparseTensor(feats = samples, coords=coords) @@ -179,10 +183,16 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): @classmethod def execute(cls, samples, structure_output, vae, shape_subs): + + patcher = vae.patcher + device = comfy.model_management.get_torch_device() + comfy.model_management.load_model_gpu(patcher) + vae = vae.first_stage_model decoded = structure_output.data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() samples = samples["samples"] + 
samples = samples.squeeze(-1).transpose(1, 2).to(device) std = tex_slat_normalization["std"].to(samples) mean = tex_slat_normalization["mean"].to(samples) samples = SparseTensor(feats = samples, coords=coords) From 7d444a4fcca545cf37d3bd42bc3e2d53f0994a3b Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 27 Feb 2026 22:22:07 +0200 Subject: [PATCH 39/59] resolution logit --- comfy/ldm/trellis2/model.py | 2 +- comfy_extras/nodes_trellis2.py | 23 +++++++++++++++++++---- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 45740faea..bd8309f2b 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -812,7 +812,7 @@ class Trellis2(nn.Module): raise ValueError("Checkpoint for Trellis2 doesn't include texture generation!") out = self.shape2txt(x, timestep, context if not txt_rule else cond) else: # structure - timestep = timestep_reshift(timestep) + #timestep = timestep_reshift(timestep) orig_bsz = x.shape[0] if shape_rule: x = x[0].unsqueeze(0) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index e781d35e3..739233523 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -136,7 +136,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): IO.Latent.Input("samples"), IO.Voxel.Input("structure_output"), IO.Vae.Input("vae"), - IO.Int.Input("resolution", tooltip="Shape Generation Resolution"), + IO.Combo.Input("resolution", options=["512", "1024"], default="512") ], outputs=[ IO.Mesh.Output("mesh"), @@ -147,6 +147,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): @classmethod def execute(cls, samples, structure_output, vae, resolution): + resolution = int(resolution) patcher = vae.patcher device = comfy.model_management.get_torch_device() comfy.model_management.load_model_gpu(patcher) @@ -154,14 +155,18 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): vae = vae.first_stage_model 
decoded = structure_output.data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() + samples = samples["samples"] - samples = samples.squeeze(-1).transpose(1, 2).to(device) + samples = samples.squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device) std = shape_slat_normalization["std"].to(samples) mean = shape_slat_normalization["mean"].to(samples) samples = SparseTensor(feats = samples, coords=coords) samples = samples * std + mean mesh, subs = vae.decode_shape_slat(samples, resolution) + faces = torch.stack([m.faces for m in mesh]) + verts = torch.stack([m.vertices for m in mesh]) + mesh = Types.MESH(vertices=verts, faces=faces) return IO.NodeOutput(mesh, subs) class VaeDecodeTextureTrellis(IO.ComfyNode): @@ -192,13 +197,16 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): decoded = structure_output.data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() samples = samples["samples"] - samples = samples.squeeze(-1).transpose(1, 2).to(device) + samples = samples.squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device) std = tex_slat_normalization["std"].to(samples) mean = tex_slat_normalization["mean"].to(samples) samples = SparseTensor(feats = samples, coords=coords) samples = samples * std + mean mesh = vae.decode_tex_slat(samples, shape_subs) + faces = torch.stack([m.faces for m in mesh]) + verts = torch.stack([m.vertices for m in mesh]) + mesh = Types.MESH(vertices=verts, faces=faces) return IO.NodeOutput(mesh) class VaeDecodeStructureTrellis2(IO.ComfyNode): @@ -210,6 +218,7 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode): inputs=[ IO.Latent.Input("samples"), IO.Vae.Input("vae"), + IO.Combo.Input("resolution", options=["32", "64"], default="32") ], outputs=[ IO.Voxel.Output("structure_output"), @@ -217,7 +226,8 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode): ) @classmethod - def execute(cls, samples, vae): + def execute(cls, samples, vae, resolution): + resolution = int(resolution) vae = vae.first_stage_model 
decoder = vae.struct_dec load_device = comfy.model_management.get_torch_device() @@ -227,6 +237,11 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode): samples = samples.to(load_device) decoded = decoder(samples)>0 decoder.to(offload_device) + current_res = decoded.shape[2] + + if current_res != resolution: + ratio = current_res // resolution + decoded = torch.nn.functional.max_pool3d(decoded.float(), ratio, ratio, 0) > 0.5 out = Types.VOXEL(decoded.squeeze(1).float()) return IO.NodeOutput(out) From 44adb27782ea1df23ea43ca44fde808ac8b893d2 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Tue, 10 Mar 2026 22:27:10 +0200 Subject: [PATCH 40/59] working version --- comfy/ldm/trellis2/attention.py | 2 +- comfy/ldm/trellis2/model.py | 4 +- comfy_extras/nodes_trellis2.py | 430 +++++++++++++++++--------------- 3 files changed, 232 insertions(+), 204 deletions(-) diff --git a/comfy/ldm/trellis2/attention.py b/comfy/ldm/trellis2/attention.py index 0b9c12294..681666430 100644 --- a/comfy/ldm/trellis2/attention.py +++ b/comfy/ldm/trellis2/attention.py @@ -46,7 +46,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) if var_length: - return out.contiguous().transpose(1, 2).values() + return out.transpose(1, 2).values() if not skip_output_reshape: out = ( out.transpose(1, 2).reshape(b, -1, heads * dim_head) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index bd8309f2b..4bbfbff5f 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -767,8 +767,6 @@ class Trellis2(nn.Module): if embeds is None: raise ValueError("Trellis2.forward requires 'embeds' in kwargs") is_1024 = self.img2shape.resolution == 1024 - if is_1024: - context = embeds coords = transformer_options.get("coords", None) mode = transformer_options.get("generation_mode", 
"structure_generation") if coords is not None: @@ -777,6 +775,8 @@ class Trellis2(nn.Module): else: mode = "structure_generation" not_struct_mode = False + if is_1024 and mode == "shape_generation": + context = embeds sigmas = transformer_options.get("sigmas")[0].item() if sigmas < 1.00001: timestep *= 1000.0 diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 739233523..23b2f72bb 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -1,9 +1,8 @@ from typing_extensions import override from comfy_api.latest import ComfyExtension, IO, Types from comfy.ldm.trellis2.vae import SparseTensor -from comfy.utils import ProgressBar, lanczos -import torch.nn.functional as TF import comfy.model_management +import logging from PIL import Image import numpy as np import torch @@ -39,93 +38,6 @@ tex_slat_normalization = { ])[None] } -dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1) -dino_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1) - -def smart_crop_square(image, mask, margin_ratio=0.1, bg_color=(128, 128, 128)): - nz = torch.nonzero(mask[0] > 0.5) - if nz.shape[0] == 0: - C, H, W = image.shape - side = max(H, W) - canvas = torch.full((C, side, side), 0.5, device=image.device) # Gray - canvas[:, (side-H)//2:(side-H)//2+H, (side-W)//2:(side-W)//2+W] = image - return canvas - - y_min, x_min = nz.min(dim=0)[0] - y_max, x_max = nz.max(dim=0)[0] - - obj_w, obj_h = x_max - x_min, y_max - y_min - center_x, center_y = (x_min + x_max) / 2, (y_min + y_max) / 2 - - side = int(max(obj_w, obj_h) * (1 + margin_ratio * 2)) - half_side = side / 2 - - x1, y1 = int(center_x - half_side), int(center_y - half_side) - x2, y2 = x1 + side, y1 + side - - C, H, W = image.shape - canvas = torch.ones((C, side, side), device=image.device) - for c in range(C): - canvas[c] *= (bg_color[c] / 255.0) - - src_x1, src_y1 = max(0, x1), max(0, y1) - src_x2, src_y2 = min(W, x2), min(H, y2) - - dst_x1, dst_y1 = max(0, -x1), max(0, 
-y1) - dst_x2 = dst_x1 + (src_x2 - src_x1) - dst_y2 = dst_y1 + (src_y2 - src_y1) - - img_crop = image[:, src_y1:src_y2, src_x1:src_x2] - mask_crop = mask[0, src_y1:src_y2, src_x1:src_x2] - - bg_val = torch.tensor(bg_color, device=image.device, dtype=image.dtype).view(-1, 1, 1) / 255.0 - - masked_crop = img_crop * mask_crop + bg_val * (1.0 - mask_crop) - - canvas[:, dst_y1:dst_y2, dst_x1:dst_x2] = masked_crop - - return canvas - -def run_conditioning(model, image, mask, include_1024 = True, background_color = "black"): - model_internal = model.model - device = comfy.model_management.intermediate_device() - torch_device = comfy.model_management.get_torch_device() - - bg_colors = {"black": (0, 0, 0), "gray": (128, 128, 128), "white": (255, 255, 255)} - bg_rgb = bg_colors.get(background_color, (128, 128, 128)) - - img_t = image[0].movedim(-1, 0).to(torch_device).float() - mask_t = mask[0].to(torch_device).float() - if mask_t.ndim == 2: - mask_t = mask_t.unsqueeze(0) - - cropped_img = smart_crop_square(img_t, mask_t, bg_color=bg_rgb) - - def prepare_tensor(img, size): - resized = lanczos(img.unsqueeze(0), size, size) - return (resized - dino_mean.to(torch_device)) / dino_std.to(torch_device) - - model_internal.image_size = 512 - input_512 = prepare_tensor(cropped_img, 512) - cond_512 = model_internal(input_512)[0] - - cond_1024 = None - if include_1024: - model_internal.image_size = 1024 - input_1024 = prepare_tensor(cropped_img, 1024) - cond_1024 = model_internal(input_1024)[0] - - conditioning = { - 'cond_512': cond_512.to(device), - 'neg_cond': torch.zeros_like(cond_512).to(device), - } - if cond_1024 is not None: - conditioning['cond_1024'] = cond_1024.to(device) - - preprocessed_tensor = cropped_img.movedim(0, -1).unsqueeze(0).cpu() - - return conditioning, preprocessed_tensor - class VaeDecodeShapeTrellis(IO.ComfyNode): @classmethod def define_schema(cls): @@ -245,6 +157,39 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode): out = 
Types.VOXEL(decoded.squeeze(1).float()) return IO.NodeOutput(out) +dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1) +dino_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1) + +def run_conditioning(model, cropped_img_tensor, include_1024=True): + model_internal = model.model + device = comfy.model_management.intermediate_device() + torch_device = comfy.model_management.get_torch_device() + + img_t = cropped_img_tensor.to(torch_device) + + def prepare_tensor(img, size): + resized = torch.nn.functional.interpolate(img, size=(size, size), mode='bicubic', align_corners=False).clamp(0.0, 1.0) + return (resized - dino_mean.to(torch_device)) / dino_std.to(torch_device) + + model_internal.image_size = 512 + input_512 = prepare_tensor(img_t, 512) + cond_512 = model_internal(input_512)[0] + + cond_1024 = None + if include_1024: + model_internal.image_size = 1024 + input_1024 = prepare_tensor(img_t, 1024) + cond_1024 = model_internal(input_1024)[0] + + conditioning = { + 'cond_512': cond_512.to(device), + 'neg_cond': torch.zeros_like(cond_512).to(device), + } + if cond_1024 is not None: + conditioning['cond_1024'] = cond_1024.to(device) + + return conditioning + class Trellis2Conditioning(IO.ComfyNode): @classmethod def define_schema(cls): @@ -268,22 +213,60 @@ class Trellis2Conditioning(IO.ComfyNode): if image.ndim == 4: image = image[0] + if mask.ndim == 3: + mask = mask[0] - # TODO - image = (image.cpu().numpy() * 255).clip(0, 255).astype(np.uint8) - image = Image.fromarray(image) - max_size = max(image.size) - scale = min(1, 1024 / max_size) - if scale < 1: - image = image.resize((int(image.width * scale), int(image.height * scale)), Image.Resampling.LANCZOS) - new_h, new_w = int(mask.shape[-2] * scale), int(mask.shape[-1] * scale) - mask = TF.interpolate(mask.unsqueeze(0).float(), size=(new_h, new_w), mode='nearest').squeeze(0) + img_np = (image.cpu().numpy() * 255).clip(0, 255).astype(np.uint8) + mask_np = (mask.cpu().numpy() * 255).clip(0, 
255).astype(np.uint8) - image = torch.tensor(np.array(image)).unsqueeze(0).float() / 255 + pil_img = Image.fromarray(img_np) + pil_mask = Image.fromarray(mask_np) - # could make 1024 an option - conditioning, _ = run_conditioning(clip_vision_model, image, mask, include_1024=True, background_color=background_color) - embeds = conditioning["cond_1024"] # should add that + max_size = max(pil_img.size) + scale = min(1.0, 1024 / max_size) + if scale < 1.0: + new_w, new_h = int(pil_img.width * scale), int(pil_img.height * scale) + pil_img = pil_img.resize((new_w, new_h), Image.Resampling.LANCZOS) + pil_mask = pil_mask.resize((new_w, new_h), Image.Resampling.NEAREST) + + rgba_np = np.zeros((pil_img.height, pil_img.width, 4), dtype=np.uint8) + rgba_np[:, :, :3] = np.array(pil_img) + rgba_np[:, :, 3] = np.array(pil_mask) + + alpha = rgba_np[:, :, 3] + bbox_coords = np.argwhere(alpha > 0.8 * 255) + + if len(bbox_coords) > 0: + y_min, x_min = np.min(bbox_coords[:, 0]), np.min(bbox_coords[:, 1]) + y_max, x_max = np.max(bbox_coords[:, 0]), np.max(bbox_coords[:, 1]) + + center_y, center_x = (y_min + y_max) / 2.0, (x_min + x_max) / 2.0 + size = max(y_max - y_min, x_max - x_min) + + crop_x1 = int(center_x - size // 2) + crop_y1 = int(center_y - size // 2) + crop_x2 = int(center_x + size // 2) + crop_y2 = int(center_y + size // 2) + + rgba_pil = Image.fromarray(rgba_np, 'RGBA') + cropped_rgba = rgba_pil.crop((crop_x1, crop_y1, crop_x2, crop_y2)) + cropped_np = np.array(cropped_rgba).astype(np.float32) / 255.0 + else: + logging.warning("Mask for the image is empty. 
Trellis2 requires an image with a mask for the best mesh quality.") + cropped_np = rgba_np.astype(np.float32) / 255.0 + + bg_colors = {"black": [0.0, 0.0, 0.0], "gray":[0.5, 0.5, 0.5], "white":[1.0, 1.0, 1.0]} + bg_rgb = np.array(bg_colors.get(background_color, [0.0, 0.0, 0.0]), dtype=np.float32) + + fg = cropped_np[:, :, :3] + alpha_float = cropped_np[:, :, 3:4] + composite_np = fg * alpha_float + bg_rgb * (1.0 - alpha_float) + + cropped_img_tensor = torch.from_numpy(composite_np).movedim(-1, 0).unsqueeze(0).float() + + conditioning = run_conditioning(clip_vision_model, cropped_img_tensor, include_1024=True) + + embeds = conditioning["cond_1024"] positive = [[conditioning["cond_512"], {"embeds": embeds}]] negative = [[conditioning["neg_cond"], {"embeds": torch.zeros_like(embeds)}]] return IO.NodeOutput(positive, negative) @@ -417,118 +400,168 @@ def simplify_fn(vertices, faces, target=100000): return final_vertices, final_faces -def fill_holes_fn(vertices, faces, max_hole_perimeter=3e-2): +def fill_holes_fn(vertices, faces, max_perimeter=0.03): is_batched = vertices.ndim == 3 if is_batched: - batch_size = vertices.shape[0] - if batch_size > 1: - v_out, f_out = [], [] - for i in range(batch_size): - v, f = fill_holes_fn(vertices[i], faces[i], max_hole_perimeter) - v_out.append(v) - f_out.append(f) - return torch.stack(v_out), torch.stack(f_out) - - vertices = vertices.squeeze(0) - faces = faces.squeeze(0) + v_list, f_list = [],[] + for i in range(vertices.shape[0]): + v_i, f_i = fill_holes_fn(vertices[i], faces[i], max_perimeter) + v_list.append(v_i) + f_list.append(f_i) + return torch.stack(v_list), torch.stack(f_list) device = vertices.device - orig_vertices = vertices - orig_faces = faces + v = vertices + f = faces - edges = torch.cat([ - faces[:, [0, 1]], - faces[:, [1, 2]], - faces[:, [2, 0]] - ], dim=0) + if f.shape[0] == 0: + return v, f + edges = torch.cat([f[:, [0, 1]], f[:, [1, 2]], f[:, [2, 0]]], dim=0) edges_sorted, _ = torch.sort(edges, dim=1) - 
unique_edges, counts = torch.unique(edges_sorted, dim=0, return_counts=True) + + max_v = v.shape[0] + packed_undirected = edges_sorted[:, 0].long() * max_v + edges_sorted[:, 1].long() + + unique_packed, counts = torch.unique(packed_undirected, return_counts=True) boundary_mask = counts == 1 - boundary_edges_sorted = unique_edges[boundary_mask] + boundary_packed = unique_packed[boundary_mask] - if boundary_edges_sorted.shape[0] == 0: - if is_batched: - return orig_vertices.unsqueeze(0), orig_faces.unsqueeze(0) - return orig_vertices, orig_faces + if boundary_packed.numel() == 0: + return v, f - max_idx = vertices.shape[0] - - packed_edges_all = torch.sort(edges, dim=1).values - packed_edges_all = packed_edges_all[:, 0] * max_idx + packed_edges_all[:, 1] - - packed_boundary = boundary_edges_sorted[:, 0] * max_idx + boundary_edges_sorted[:, 1] - - is_boundary_edge = torch.isin(packed_edges_all, packed_boundary) - active_boundary_edges = edges[is_boundary_edge] + packed_directed_sorted = edges_sorted[:, 0].long() * max_v + edges_sorted[:, 1].long() + is_boundary = torch.isin(packed_directed_sorted, boundary_packed) + boundary_edges_directed = edges[is_boundary] adj = {} - edges_np = active_boundary_edges.cpu().numpy() - for u, v in edges_np: - adj[u] = v + in_deg = {} + out_deg = {} + + edges_list = boundary_edges_directed.tolist() + for u, v_idx in edges_list: + if u not in adj: adj[u] = [] + adj[u].append(v_idx) + out_deg[u] = out_deg.get(u, 0) + 1 + in_deg[v_idx] = in_deg.get(v_idx, 0) + 1 + + manifold_nodes = set() + for node in set(list(in_deg.keys()) + list(out_deg.keys())): + if in_deg.get(node, 0) == 1 and out_deg.get(node, 0) == 1: + manifold_nodes.add(node) + + loops =[] + visited_nodes = set() - loops = [] - visited_edges = set() - processed_nodes = set() for start_node in list(adj.keys()): - if start_node in processed_nodes: + if start_node not in manifold_nodes or start_node in visited_nodes: continue - current_loop, curr = [], start_node - while curr in 
adj: - next_node = adj[curr] - if (curr, next_node) in visited_edges: - break - visited_edges.add((curr, next_node)) - processed_nodes.add(curr) + + curr = start_node + current_loop =[] + + while True: current_loop.append(curr) + visited_nodes.add(curr) + + next_node = adj[curr][0] + + if next_node == start_node: + if len(current_loop) >= 3: + loops.append(current_loop) + break + + if next_node not in manifold_nodes or next_node in visited_nodes: + break + curr = next_node - if curr == start_node: - loops.append(current_loop) - break - if len(current_loop) > len(edges_np): + + if len(current_loop) > len(edges_list): break - if not loops: - if is_batched: - return orig_vertices.unsqueeze(0), orig_faces.unsqueeze(0) - return orig_vertices, orig_faces + new_faces =[] + new_verts = [] + curr_v_idx = v.shape[0] - new_faces = [] - v_offset = vertices.shape[0] - valid_new_verts = [] + for loop in loops: + loop_indices = torch.tensor(loop, device=device, dtype=torch.long) + loop_points = v[loop_indices] - for loop_indices in loops: - if len(loop_indices) < 3: - continue - loop_tensor = torch.tensor(loop_indices, dtype=torch.long, device=device) - loop_verts = vertices[loop_tensor] - diffs = loop_verts - torch.roll(loop_verts, shifts=-1, dims=0) - perimeter = torch.norm(diffs, dim=1).sum() + # Calculate perimeter + p1 = loop_points + p2 = torch.roll(loop_points, shifts=-1, dims=0) + perimeter = torch.norm(p1 - p2, dim=1).sum().item() - if perimeter > max_hole_perimeter: - continue + if perimeter <= max_perimeter: + centroid = loop_points.mean(dim=0) + new_verts.append(centroid) + center_idx = curr_v_idx + curr_v_idx += 1 - center = loop_verts.mean(dim=0) - valid_new_verts.append(center) - c_idx = v_offset - v_offset += 1 + for i in range(len(loop)): + u_idx = loop[i] + v_next_idx = loop[(i + 1) % len(loop)] + new_faces.append([u_idx, v_next_idx, center_idx]) - num_v = len(loop_indices) - for i in range(num_v): - v_curr, v_next = loop_indices[i], loop_indices[(i + 1) % 
num_v] - new_faces.append([v_curr, v_next, c_idx]) + if new_faces: + v = torch.cat([v, torch.stack(new_verts)], dim=0) + f = torch.cat([f, torch.tensor(new_faces, device=device, dtype=torch.long)], dim=0) - if len(valid_new_verts) > 0: - added_vertices = torch.stack(valid_new_verts, dim=0) - added_faces = torch.tensor(new_faces, dtype=torch.long, device=device) - vertices = torch.cat([orig_vertices, added_vertices], dim=0) - faces = torch.cat([orig_faces, added_faces], dim=0) - else: - vertices, faces = orig_vertices, orig_faces + return v, f +def merge_duplicate_vertices(vertices, faces, tolerance=1e-5): + is_batched = vertices.ndim == 3 if is_batched: - return vertices.unsqueeze(0), faces.unsqueeze(0) + v_list, f_list = [],[] + for i in range(vertices.shape[0]): + v_i, f_i = merge_duplicate_vertices(vertices[i], faces[i], tolerance) + v_list.append(v_i) + f_list.append(f_i) + return torch.stack(v_list), torch.stack(f_list) + v_min = vertices.min(dim=0, keepdim=True)[0] + v_quant = ((vertices - v_min) / tolerance).round().long() + + unique_quant, inverse_indices = torch.unique(v_quant, dim=0, return_inverse=True) + + new_vertices = torch.zeros((unique_quant.shape[0], 3), dtype=vertices.dtype, device=vertices.device) + new_vertices.index_copy_(0, inverse_indices, vertices) + + new_faces = inverse_indices[faces.long()] + + valid = (new_faces[:, 0] != new_faces[:, 1]) & \ + (new_faces[:, 1] != new_faces[:, 2]) & \ + (new_faces[:, 2] != new_faces[:, 0]) + + return new_vertices, new_faces[valid] + +def fix_normals(vertices, faces): + is_batched = vertices.ndim == 3 + if is_batched: + v_list, f_list = [], [] + for i in range(vertices.shape[0]): + v_i, f_i = fix_normals(vertices[i], faces[i]) + v_list.append(v_i) + f_list.append(f_i) + return torch.stack(v_list), torch.stack(f_list) + + if faces.shape[0] == 0: + return vertices, faces + + center = vertices.mean(0) + v0 = vertices[faces[:, 0].long()] + v1 = vertices[faces[:, 1].long()] + v2 = vertices[faces[:, 2].long()] 
+ + normals = torch.cross(v1 - v0, v2 - v0, dim=1) + + face_centers = (v0 + v1 + v2) / 3.0 + dir_from_center = face_centers - center + + dot = (normals * dir_from_center).sum(1) + flip_mask = dot < 0 + + faces[flip_mask] = faces[flip_mask][:, [0, 2, 1]] return vertices, faces class PostProcessMesh(IO.ComfyNode): @@ -539,36 +572,31 @@ class PostProcessMesh(IO.ComfyNode): category="latent/3d", inputs=[ IO.Mesh.Input("mesh"), - IO.Int.Input("simplify", default=100_000, min=0, max=50_000_000), # max? - IO.Float.Input("fill_holes_perimeter", default=0.003, min=0.0, step=0.0001) + IO.Int.Input("simplify", default=100_000, min=0, max=50_000_000), + IO.Float.Input("fill_holes_perimeter", default=0.03, min=0.0, step=0.0001) ], outputs=[ IO.Mesh.Output("output_mesh"), ] ) + @classmethod def execute(cls, mesh, simplify, fill_holes_perimeter): - bar = ProgressBar(2) mesh = copy.deepcopy(mesh) verts, faces = mesh.vertices, mesh.faces - if fill_holes_perimeter != 0.0: - verts, faces = fill_holes_fn(verts, faces, max_hole_perimeter=fill_holes_perimeter) - bar.update(1) - else: - bar.update(1) + verts, faces = merge_duplicate_vertices(verts, faces, tolerance=1e-5) - if simplify != 0: - verts, faces = simplify_fn(verts, faces, simplify) - bar.update(1) - else: - bar.update(1) + if fill_holes_perimeter > 0: + verts, faces = fill_holes_fn(verts, faces, max_perimeter=fill_holes_perimeter) - # potentially adding laplacian smoothing + if simplify > 0 and faces.shape[0] > simplify: + verts, faces = simplify_fn(verts, faces, target=simplify) + + verts, faces = fix_normals(verts, faces) mesh.vertices = verts mesh.faces = faces - return IO.NodeOutput(mesh) class Trellis2Extension(ComfyExtension): From 011f624dd548540e92cb8ae08671e36ae3532b6a Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Wed, 11 Mar 2026 20:11:58 +0200 Subject: [PATCH 41/59] post-process rewrite + light texture model work --- comfy/ldm/trellis2/model.py | 4 + 
comfy_extras/nodes_trellis2.py | 214 ++++++++++++--------------------- 2 files changed, 84 insertions(+), 134 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 4bbfbff5f..651d516b6 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -810,6 +810,10 @@ class Trellis2(nn.Module): elif mode == "texture_generation": if self.shape2txt is None: raise ValueError("Checkpoint for Trellis2 doesn't include texture generation!") + slat = transformer_options.get("shape_slat") + if slat is None: + raise ValueError("shape_slat can't be None") + x = sparse_cat([x, slat]) out = self.shape2txt(x, timestep, context if not txt_rule else cond) else: # structure #timestep = timestep_reshift(timestep) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 23b2f72bb..0b94a2d0a 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -38,6 +38,13 @@ tex_slat_normalization = { ])[None] } +def shape_norm(shape_latent, coords): + std = shape_slat_normalization["std"].to(shape_latent) + mean = shape_slat_normalization["mean"].to(shape_latent) + samples = SparseTensor(feats = shape_latent, coords=coords) + samples = samples * std + mean + return samples + class VaeDecodeShapeTrellis(IO.ComfyNode): @classmethod def define_schema(cls): @@ -70,10 +77,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): samples = samples["samples"] samples = samples.squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device) - std = shape_slat_normalization["std"].to(samples) - mean = shape_slat_normalization["mean"].to(samples) - samples = SparseTensor(feats = samples, coords=coords) - samples = samples * std + mean + samples = shape_norm(samples, coords) mesh, subs = vae.decode_shape_slat(samples, resolution) faces = torch.stack([m.faces for m in mesh]) @@ -313,6 +317,8 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): category="latent/3d", inputs=[ IO.Voxel.Input("structure_output"), + 
IO.Latent.Input("shape_latent"), + IO.Model.Input("model") ], outputs=[ IO.Latent.Output(), @@ -321,11 +327,15 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): ) @classmethod - def execute(cls, structure_output, model): + def execute(cls, structure_output, shape_latent, model): # TODO decoded = structure_output.data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() in_channels = 32 + + shape_latent = shape_latent["samples"] + shape_latent = shape_norm(shape_latent, coords) + latent = torch.randn(coords.shape[0], in_channels - structure_output.feats.shape[1]) model = model.clone() model.model_options = model.model_options.copy() @@ -336,6 +346,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): model.model_options["transformer_options"]["coords"] = coords model.model_options["transformer_options"]["generation_mode"] = "shape_generation" + model.model_options["transformer_options"]["shape_slat"] = shape_latent return IO.NodeOutput({"samples": latent, "type": "trellis2"}, model) @@ -360,25 +371,34 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode): return IO.NodeOutput({"samples": latent, "type": "trellis2"}) def simplify_fn(vertices, faces, target=100000): + is_batched = vertices.ndim == 3 + if is_batched: + v_list, f_list = [], [] + for i in range(vertices.shape[0]): + v_i, f_i = simplify_fn(vertices[i], faces[i], target) + v_list.append(v_i) + f_list.append(f_i) + return torch.stack(v_list), torch.stack(f_list) - if vertices.shape[0] <= target: + if faces.shape[0] <= target: return vertices, faces - min_feat = vertices.min(dim=0)[0] - max_feat = vertices.max(dim=0)[0] - extent = (max_feat - min_feat).max() + device = vertices.device + target_v = target / 2.0 - grid_resolution = int(torch.sqrt(torch.tensor(target)).item() * 1.5) - voxel_size = extent / grid_resolution + min_v = vertices.min(dim=0)[0] + max_v = vertices.max(dim=0)[0] + extent = max_v - min_v - quantized_coords = ((vertices - min_feat) / voxel_size).long() + volume = 
(extent[0] * extent[1] * extent[2]).clamp(min=1e-8) + cell_size = (volume / target_v) ** (1/3.0) - unique_coords, inverse_indices = torch.unique(quantized_coords, dim=0, return_inverse=True) + quantized = ((vertices - min_v) / cell_size).round().long() + unique_coords, inverse_indices = torch.unique(quantized, dim=0, return_inverse=True) + num_cells = unique_coords.shape[0] - num_new_verts = unique_coords.shape[0] - new_vertices = torch.zeros((num_new_verts, 3), dtype=vertices.dtype, device=vertices.device) - - counts = torch.zeros((num_new_verts, 1), dtype=vertices.dtype, device=vertices.device) + new_vertices = torch.zeros((num_cells, 3), dtype=vertices.dtype, device=device) + counts = torch.zeros((num_cells, 1), dtype=vertices.dtype, device=device) new_vertices.scatter_add_(0, inverse_indices.unsqueeze(1).expand(-1, 3), vertices) counts.scatter_add_(0, inverse_indices.unsqueeze(1), torch.ones_like(vertices[:, :1])) @@ -387,11 +407,9 @@ def simplify_fn(vertices, faces, target=100000): new_faces = inverse_indices[faces] - v0 = new_faces[:, 0] - v1 = new_faces[:, 1] - v2 = new_faces[:, 2] - - valid_mask = (v0 != v1) & (v1 != v2) & (v2 != v0) + valid_mask = (new_faces[:, 0] != new_faces[:, 1]) & \ + (new_faces[:, 1] != new_faces[:, 2]) & \ + (new_faces[:, 2] != new_faces[:, 0]) new_faces = new_faces[valid_mask] unique_face_indices, inv_face = torch.unique(new_faces.reshape(-1), return_inverse=True) @@ -414,7 +432,7 @@ def fill_holes_fn(vertices, faces, max_perimeter=0.03): v = vertices f = faces - if f.shape[0] == 0: + if f.numel() == 0: return v, f edges = torch.cat([f[:, [0, 1]], f[:, [1, 2]], f[:, [2, 0]]], dim=0) @@ -424,145 +442,75 @@ def fill_holes_fn(vertices, faces, max_perimeter=0.03): packed_undirected = edges_sorted[:, 0].long() * max_v + edges_sorted[:, 1].long() unique_packed, counts = torch.unique(packed_undirected, return_counts=True) - boundary_mask = counts == 1 - boundary_packed = unique_packed[boundary_mask] + boundary_packed = 
unique_packed[counts == 1] if boundary_packed.numel() == 0: return v, f - packed_directed_sorted = edges_sorted[:, 0].long() * max_v + edges_sorted[:, 1].long() + packed_directed_sorted = edges[:, 0].min(edges[:, 1]).long() * max_v + edges[:, 0].max(edges[:, 1]).long() is_boundary = torch.isin(packed_directed_sorted, boundary_packed) - boundary_edges_directed = edges[is_boundary] + b_edges = edges[is_boundary] - adj = {} - in_deg = {} - out_deg = {} - - edges_list = boundary_edges_directed.tolist() - for u, v_idx in edges_list: - if u not in adj: adj[u] = [] - adj[u].append(v_idx) - out_deg[u] = out_deg.get(u, 0) + 1 - in_deg[v_idx] = in_deg.get(v_idx, 0) + 1 - - manifold_nodes = set() - for node in set(list(in_deg.keys()) + list(out_deg.keys())): - if in_deg.get(node, 0) == 1 and out_deg.get(node, 0) == 1: - manifold_nodes.add(node) + adj = {u.item(): v_idx.item() for u, v_idx in b_edges} loops =[] - visited_nodes = set() + visited = set() - for start_node in list(adj.keys()): - if start_node not in manifold_nodes or start_node in visited_nodes: + for start_node in adj.keys(): + if start_node in visited: continue curr = start_node - current_loop =[] + loop = [] - while True: - current_loop.append(curr) - visited_nodes.add(curr) + while curr not in visited: + visited.add(curr) + loop.append(curr) + curr = adj.get(curr, -1) - next_node = adj[curr][0] - - if next_node == start_node: - if len(current_loop) >= 3: - loops.append(current_loop) + if curr == -1: + loop = [] + break + if curr == start_node: + loops.append(loop) break - if next_node not in manifold_nodes or next_node in visited_nodes: - break - - curr = next_node - - if len(current_loop) > len(edges_list): - break - - new_faces =[] - new_verts = [] - curr_v_idx = v.shape[0] + new_verts =[] + new_faces = [] + v_idx = v.shape[0] for loop in loops: - loop_indices = torch.tensor(loop, device=device, dtype=torch.long) - loop_points = v[loop_indices] + loop_t = torch.tensor(loop, device=device, dtype=torch.long) + 
loop_v = v[loop_t] - # Calculate perimeter - p1 = loop_points - p2 = torch.roll(loop_points, shifts=-1, dims=0) - perimeter = torch.norm(p1 - p2, dim=1).sum().item() + diffs = loop_v - torch.roll(loop_v, shifts=-1, dims=0) + perimeter = torch.norm(diffs, dim=1).sum().item() if perimeter <= max_perimeter: - centroid = loop_points.mean(dim=0) - new_verts.append(centroid) - center_idx = curr_v_idx - curr_v_idx += 1 + new_verts.append(loop_v.mean(dim=0)) for i in range(len(loop)): - u_idx = loop[i] - v_next_idx = loop[(i + 1) % len(loop)] - new_faces.append([u_idx, v_next_idx, center_idx]) + new_faces.append([loop[i], loop[(i + 1) % len(loop)], v_idx]) + v_idx += 1 - if new_faces: + if new_verts: v = torch.cat([v, torch.stack(new_verts)], dim=0) f = torch.cat([f, torch.tensor(new_faces, device=device, dtype=torch.long)], dim=0) return v, f -def merge_duplicate_vertices(vertices, faces, tolerance=1e-5): +def make_double_sided(vertices, faces): is_batched = vertices.ndim == 3 if is_batched: - v_list, f_list = [],[] - for i in range(vertices.shape[0]): - v_i, f_i = merge_duplicate_vertices(vertices[i], faces[i], tolerance) - v_list.append(v_i) - f_list.append(f_i) - return torch.stack(v_list), torch.stack(f_list) + f_list =[] + for i in range(faces.shape[0]): + f_inv = faces[i][:,[0, 2, 1]] + f_list.append(torch.cat([faces[i], f_inv], dim=0)) + return vertices, torch.stack(f_list) - v_min = vertices.min(dim=0, keepdim=True)[0] - v_quant = ((vertices - v_min) / tolerance).round().long() - - unique_quant, inverse_indices = torch.unique(v_quant, dim=0, return_inverse=True) - - new_vertices = torch.zeros((unique_quant.shape[0], 3), dtype=vertices.dtype, device=vertices.device) - new_vertices.index_copy_(0, inverse_indices, vertices) - - new_faces = inverse_indices[faces.long()] - - valid = (new_faces[:, 0] != new_faces[:, 1]) & \ - (new_faces[:, 1] != new_faces[:, 2]) & \ - (new_faces[:, 2] != new_faces[:, 0]) - - return new_vertices, new_faces[valid] - -def 
fix_normals(vertices, faces): - is_batched = vertices.ndim == 3 - if is_batched: - v_list, f_list = [], [] - for i in range(vertices.shape[0]): - v_i, f_i = fix_normals(vertices[i], faces[i]) - v_list.append(v_i) - f_list.append(f_i) - return torch.stack(v_list), torch.stack(f_list) - - if faces.shape[0] == 0: - return vertices, faces - - center = vertices.mean(0) - v0 = vertices[faces[:, 0].long()] - v1 = vertices[faces[:, 1].long()] - v2 = vertices[faces[:, 2].long()] - - normals = torch.cross(v1 - v0, v2 - v0, dim=1) - - face_centers = (v0 + v1 + v2) / 3.0 - dir_from_center = face_centers - center - - dot = (normals * dir_from_center).sum(1) - flip_mask = dot < 0 - - faces[flip_mask] = faces[flip_mask][:, [0, 2, 1]] - return vertices, faces + faces_inv = faces[:, [0, 2, 1]] + faces_double = torch.cat([faces, faces_inv], dim=0) + return vertices, faces_double class PostProcessMesh(IO.ComfyNode): @classmethod @@ -572,7 +520,7 @@ class PostProcessMesh(IO.ComfyNode): category="latent/3d", inputs=[ IO.Mesh.Input("mesh"), - IO.Int.Input("simplify", default=100_000, min=0, max=50_000_000), + IO.Int.Input("simplify", default=1_000_000, min=0, max=50_000_000), IO.Float.Input("fill_holes_perimeter", default=0.03, min=0.0, step=0.0001) ], outputs=[ @@ -585,15 +533,13 @@ class PostProcessMesh(IO.ComfyNode): mesh = copy.deepcopy(mesh) verts, faces = mesh.vertices, mesh.faces - verts, faces = merge_duplicate_vertices(verts, faces, tolerance=1e-5) - if fill_holes_perimeter > 0: verts, faces = fill_holes_fn(verts, faces, max_perimeter=fill_holes_perimeter) if simplify > 0 and faces.shape[0] > simplify: verts, faces = simplify_fn(verts, faces, target=simplify) - verts, faces = fix_normals(verts, faces) + verts, faces = make_double_sided(verts, faces) mesh.vertices = verts mesh.faces = faces From 2d904b28da9631a756784b9bd54c4b46b8290522 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Wed, 11 Mar 2026 22:50:17 +0200 Subject: 
[PATCH 42/59] upscale node + simple node simplification --- comfy_extras/nodes_trellis2.py | 78 ++++++++++++++++++++++++++++------ 1 file changed, 65 insertions(+), 13 deletions(-) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 0b94a2d0a..86f08f8bd 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -53,7 +53,6 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): category="latent/3d", inputs=[ IO.Latent.Input("samples"), - IO.Voxel.Input("structure_output"), IO.Vae.Input("vae"), IO.Combo.Input("resolution", options=["512", "1024"], default="512") ], @@ -64,7 +63,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): ) @classmethod - def execute(cls, samples, structure_output, vae, resolution): + def execute(cls, samples, vae, resolution): resolution = int(resolution) patcher = vae.patcher @@ -72,8 +71,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): comfy.model_management.load_model_gpu(patcher) vae = vae.first_stage_model - decoded = structure_output.data.unsqueeze(1) - coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() + coords = samples["coords"] samples = samples["samples"] samples = samples.squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device) @@ -93,7 +91,6 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): category="latent/3d", inputs=[ IO.Latent.Input("samples"), - IO.Voxel.Input("structure_output"), IO.Vae.Input("vae"), IO.AnyType.Input("shape_subs"), ], @@ -103,15 +100,15 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): ) @classmethod - def execute(cls, samples, structure_output, vae, shape_subs): + def execute(cls, samples, vae, shape_subs): patcher = vae.patcher device = comfy.model_management.get_torch_device() comfy.model_management.load_model_gpu(patcher) vae = vae.first_stage_model - decoded = structure_output.data.unsqueeze(1) - coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() + coords = samples["coords"] + samples = samples["samples"] samples = 
samples.squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device) std = tex_slat_normalization["std"].to(samples) @@ -161,6 +158,56 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode): out = Types.VOXEL(decoded.squeeze(1).float()) return IO.NodeOutput(out) +class Trellis2UpsampleCascade(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="Trellis2UpsampleCascade", + category="latent/3d", + inputs=[ + IO.Latent.Input("shape_latent_512"), + IO.Vae.Input("vae"), + IO.Combo.Input("target_resolution", options=["1024", "1536"], default="1024"), + IO.Int.Input("max_tokens", default=49152, min=1024, max=100000) + ], + outputs=[ + IO.AnyType.Output("hr_coords"), + ] + ) + + @classmethod + def execute(cls, shape_latent_512, vae, target_resolution, max_tokens): + device = comfy.model_management.get_torch_device() + + feats = shape_latent_512["samples"].squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device) + coords_512 = shape_latent_512["coords"].to(device) + + slat = shape_norm(feats, coords_512) + + decoder = vae.first_stage_model.shape_dec + decoder.to(device) + + slat.feats = slat.feats.to(next(decoder.parameters()).dtype) + hr_coords = decoder.upsample(slat, upsample_times=4) + decoder.cpu() + + lr_resolution = 512 + hr_resolution = int(target_resolution) + + while True: + quant_coords = torch.cat([ + hr_coords[:, :1], + ((hr_coords[:, 1:] + 0.5) / lr_resolution * (hr_resolution // 16)).int(), + ], dim=1) + final_coords = quant_coords.unique(dim=0) + num_tokens = final_coords.shape[0] + + if num_tokens < max_tokens or hr_resolution <= 1024: + break + hr_resolution -= 128 + + return IO.NodeOutput(final_coords.cpu()) + dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1) dino_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1) @@ -282,7 +329,7 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): node_id="EmptyShapeLatentTrellis2", category="latent/3d", inputs=[ - IO.Voxel.Input("structure_output"), + 
IO.AnyType.Input("structure_or_coords"), IO.Model.Input("model") ], outputs=[ @@ -292,9 +339,13 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): ) @classmethod - def execute(cls, structure_output, model): - decoded = structure_output.data.unsqueeze(1) - coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() + def execute(cls, structure_or_coords, model): + # to accept the upscaled coords + if hasattr(structure_or_coords, "data"): + decoded = structure_or_coords.data.unsqueeze(1) + coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() + else: + coords = structure_or_coords in_channels = 32 # image like format latent = torch.randn(1, in_channels, coords.shape[0], 1) @@ -307,7 +358,7 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): model.model_options["transformer_options"]["coords"] = coords model.model_options["transformer_options"]["generation_mode"] = "shape_generation" - return IO.NodeOutput({"samples": latent, "type": "trellis2"}, model) + return IO.NodeOutput({"samples": latent, "coords": coords, "type": "trellis2"}, model) class EmptyTextureLatentTrellis2(IO.ComfyNode): @classmethod @@ -556,6 +607,7 @@ class Trellis2Extension(ComfyExtension): VaeDecodeTextureTrellis, VaeDecodeShapeTrellis, VaeDecodeStructureTrellis2, + Trellis2UpsampleCascade, PostProcessMesh ] From 5d2548822c4519b8fb22b454c15114f7ba0ea26d Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 20 Mar 2026 02:36:01 +0200 Subject: [PATCH 43/59] . 
--- comfy/ldm/trellis2/model.py | 1 + comfy_extras/nodes_trellis2.py | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 651d516b6..7a3e387c3 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -756,6 +756,7 @@ class Trellis2(nn.Module): self.img2shape = SLatFlowModel(resolution=resolution, in_channels=in_channels, **args) if init_txt_model: self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args) + self.img2shape_512 = SLatFlowModel(resolution=32, in_channels=in_channels, **args) args.pop("out_channels") self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **args) self.guidance_interval = [0.6, 1.0] diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 86f08f8bd..409d2d23c 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -54,7 +54,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): inputs=[ IO.Latent.Input("samples"), IO.Vae.Input("vae"), - IO.Combo.Input("resolution", options=["512", "1024"], default="512") + IO.Combo.Input("resolution", options=["512", "1024"], default="1024") ], outputs=[ IO.Mesh.Output("mesh"), @@ -116,7 +116,7 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): samples = SparseTensor(feats = samples, coords=coords) samples = samples * std + mean - mesh = vae.decode_tex_slat(samples, shape_subs) + mesh = vae.decode_tex_slat(samples, shape_subs) * 0.5 + 0.5 faces = torch.stack([m.faces for m in mesh]) verts = torch.stack([m.vertices for m in mesh]) mesh = Types.MESH(vertices=verts, faces=faces) @@ -541,7 +541,7 @@ def fill_holes_fn(vertices, faces, max_perimeter=0.03): new_verts.append(loop_v.mean(dim=0)) for i in range(len(loop)): - new_faces.append([loop[i], loop[(i + 1) % len(loop)], v_idx]) + new_faces.append([loop[(i + 1) % len(loop)], loop[i], v_idx]) v_idx += 1 if new_verts: From 
def8947e75b406fa87df78f9c6f4b9d0929eb6d8 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 20 Mar 2026 18:37:11 +0200 Subject: [PATCH 44/59] shape working --- comfy/ldm/trellis2/model.py | 50 +++++++++++++++++++++++++--------- comfy_extras/nodes_trellis2.py | 22 +++++++++++---- 2 files changed, 53 insertions(+), 19 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 7a3e387c3..8a0c6d8b6 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -770,14 +770,20 @@ class Trellis2(nn.Module): is_1024 = self.img2shape.resolution == 1024 coords = transformer_options.get("coords", None) mode = transformer_options.get("generation_mode", "structure_generation") + is_512_run = False + if mode == "shape_generation_512": + is_512_run = True + mode = "shape_generation" if coords is not None: x = x.squeeze(-1).transpose(1, 2) not_struct_mode = True else: mode = "structure_generation" not_struct_mode = False - if is_1024 and mode == "shape_generation": + + if is_1024 and mode == "shape_generation" and not is_512_run: context = embeds + sigmas = transformer_options.get("sigmas")[0].item() if sigmas < 1.00001: timestep *= 1000.0 @@ -786,12 +792,24 @@ class Trellis2(nn.Module): txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1] if not_struct_mode: - B, N, C = x.shape + orig_bsz = x.shape[0] + rule = txt_rule if mode == "texture_generation" else shape_rule - if mode == "shape_generation": - feats_flat = x.reshape(-1, C) + if rule and orig_bsz > 1: + x_eval = x[1].unsqueeze(0) + t_eval = timestep[1].unsqueeze(0) if timestep.shape[0] > 1 else timestep + c_eval = cond + else: + x_eval = x + t_eval = timestep + c_eval = context - # 3. 
inflate coords [N, 4] -> [B*N, 4] + B, N, C = x_eval.shape + + if mode in ["shape_generation", "texture_generation"]: + feats_flat = x_eval.reshape(-1, C) + + # inflate coords [N, 4] -> [B*N, 4] coords_list = [] for i in range(B): c = coords.clone() @@ -799,23 +817,27 @@ class Trellis2(nn.Module): coords_list.append(c) batched_coords = torch.cat(coords_list, dim=0) - else: # TODO: texture - # may remove the else if texture doesn't require special handling + else: batched_coords = coords - feats_flat = x - x = SparseTensor(feats=feats_flat, coords=batched_coords.to(torch.int32)) + feats_flat = x_eval + + x_st = SparseTensor(feats=feats_flat, coords=batched_coords.to(torch.int32)) if mode == "shape_generation": # TODO - out = self.img2shape(x, timestep, context) + if is_512_run: + out = self.img2shape_512(x_st, t_eval, c_eval) + else: + out = self.img2shape(x_st, t_eval, c_eval) elif mode == "texture_generation": if self.shape2txt is None: raise ValueError("Checkpoint for Trellis2 doesn't include texture generation!") slat = transformer_options.get("shape_slat") if slat is None: raise ValueError("shape_slat can't be None") - x = sparse_cat([x, slat]) - out = self.shape2txt(x, timestep, context if not txt_rule else cond) + slat.feats = slat.feats.repeat(B, 1) + x_st = sparse_cat([x_st, slat]) + out = self.shape2txt(x_st, t_eval, c_eval) else: # structure #timestep = timestep_reshift(timestep) orig_bsz = x.shape[0] @@ -828,6 +850,8 @@ class Trellis2(nn.Module): if not_struct_mode: out = out.feats - if mode == "shape_generation": + if not_struct_mode: out = out.view(B, N, -1).transpose(1, 2).unsqueeze(-1) + if rule and orig_bsz > 1: + out = out.repeat(orig_bsz, 1, 1, 1) return out diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 409d2d23c..cba6b3241 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -178,6 +178,7 @@ class Trellis2UpsampleCascade(IO.ComfyNode): @classmethod def execute(cls, 
shape_latent_512, vae, target_resolution, max_tokens): device = comfy.model_management.get_torch_device() + comfy.model_management.load_model_gpu(vae.patcher) feats = shape_latent_512["samples"].squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device) coords_512 = shape_latent_512["coords"].to(device) @@ -185,11 +186,9 @@ class Trellis2UpsampleCascade(IO.ComfyNode): slat = shape_norm(feats, coords_512) decoder = vae.first_stage_model.shape_dec - decoder.to(device) slat.feats = slat.feats.to(next(decoder.parameters()).dtype) hr_coords = decoder.upsample(slat, upsample_times=4) - decoder.cpu() lr_resolution = 512 hr_resolution = int(target_resolution) @@ -206,7 +205,7 @@ class Trellis2UpsampleCascade(IO.ComfyNode): break hr_resolution -= 128 - return IO.NodeOutput(final_coords.cpu()) + return IO.NodeOutput(final_coords,) dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1) dino_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1) @@ -341,11 +340,19 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): @classmethod def execute(cls, structure_or_coords, model): # to accept the upscaled coords - if hasattr(structure_or_coords, "data"): + is_512_pass = False + + if hasattr(structure_or_coords, "data") and structure_or_coords.data.ndim == 4: decoded = structure_or_coords.data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() + is_512_pass = True + + elif isinstance(structure_or_coords, torch.Tensor) and structure_or_coords.ndim == 2: + coords = structure_or_coords.int() + is_512_pass = False + else: - coords = structure_or_coords + raise ValueError(f"Invalid input to EmptyShapeLatent: {type(structure_or_coords)}") in_channels = 32 # image like format latent = torch.randn(1, in_channels, coords.shape[0], 1) @@ -357,7 +364,10 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): model.model_options["transformer_options"] = {} model.model_options["transformer_options"]["coords"] = coords - 
model.model_options["transformer_options"]["generation_mode"] = "shape_generation" + if is_512_pass: + model.model_options["transformer_options"]["generation_mode"] = "shape_generation_512" + else: + model.model_options["transformer_options"]["generation_mode"] = "shape_generation" return IO.NodeOutput({"samples": latent, "coords": coords, "type": "trellis2"}, model) class EmptyTextureLatentTrellis2(IO.ComfyNode): From 56e52e5d03f52c407cf529c8b211a1636a3ed221 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Tue, 24 Mar 2026 02:44:03 +0200 Subject: [PATCH 45/59] work on txt gen --- comfy/ldm/trellis2/model.py | 15 ++++++++------- comfy_extras/nodes_trellis2.py | 33 ++++++++++++++++++--------------- 2 files changed, 26 insertions(+), 22 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 8a0c6d8b6..34aeba3e1 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -754,6 +754,7 @@ class Trellis2(nn.Module): "qk_rms_norm": qk_rms_norm, "qk_rms_norm_cross": qk_rms_norm_cross, "device": device, "dtype": dtype, "operations": operations } self.img2shape = SLatFlowModel(resolution=resolution, in_channels=in_channels, **args) + self.shape2txt = None if init_txt_model: self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args) self.img2shape_512 = SLatFlowModel(resolution=32, in_channels=in_channels, **args) @@ -835,11 +836,12 @@ class Trellis2(nn.Module): slat = transformer_options.get("shape_slat") if slat is None: raise ValueError("shape_slat can't be None") - slat.feats = slat.feats.repeat(B, 1) - x_st = sparse_cat([x_st, slat]) + + base_slat_feats = slat.feats[:N] + slat_feats_batched = base_slat_feats.repeat(B, 1).to(x_st.device) + x_st = x_st.replace(feats=torch.cat([x_st.feats, slat_feats_batched], dim=-1)) out = self.shape2txt(x_st, t_eval, c_eval) else: # structure - #timestep = timestep_reshift(timestep) orig_bsz = x.shape[0] 
if shape_rule: x = x[0].unsqueeze(0) @@ -850,8 +852,7 @@ class Trellis2(nn.Module): if not_struct_mode: out = out.feats - if not_struct_mode: - out = out.view(B, N, -1).transpose(1, 2).unsqueeze(-1) - if rule and orig_bsz > 1: - out = out.repeat(orig_bsz, 1, 1, 1) + out = out.view(B, N, -1).transpose(1, 2).unsqueeze(-1) + if rule and orig_bsz > 1: + out = out.repeat(orig_bsz, 1, 1, 1) return out diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index cba6b3241..dcf8dcb98 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -95,7 +95,7 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): IO.AnyType.Input("shape_subs"), ], outputs=[ - IO.Mesh.Output("mesh"), + IO.Voxel.Output("voxel"), ] ) @@ -116,11 +116,9 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): samples = SparseTensor(feats = samples, coords=coords) samples = samples * std + mean - mesh = vae.decode_tex_slat(samples, shape_subs) * 0.5 + 0.5 - faces = torch.stack([m.faces for m in mesh]) - verts = torch.stack([m.vertices for m in mesh]) - mesh = Types.MESH(vertices=verts, faces=faces) - return IO.NodeOutput(mesh) + voxel = vae.decode_tex_slat(samples, shape_subs) * 0.5 + 0.5 + voxel = Types.VOXEL(voxel) + return IO.NodeOutput(voxel) class VaeDecodeStructureTrellis2(IO.ComfyNode): @classmethod @@ -377,7 +375,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): node_id="EmptyTextureLatentTrellis2", category="latent/3d", inputs=[ - IO.Voxel.Input("structure_output"), + IO.Voxel.Input("structure_or_coords"), IO.Latent.Input("shape_latent"), IO.Model.Input("model") ], @@ -388,16 +386,21 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): ) @classmethod - def execute(cls, structure_output, shape_latent, model): - # TODO - decoded = structure_output.data.unsqueeze(1) - coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() - in_channels = 32 + def execute(cls, structure_or_coords, shape_latent, model): + channels = 32 + if hasattr(structure_or_coords, 
"data") and structure_or_coords.data.ndim == 4: + decoded = structure_or_coords.data.unsqueeze(1) + coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() + + elif isinstance(structure_or_coords, torch.Tensor) and structure_or_coords.ndim == 2: + coords = structure_or_coords.int() shape_latent = shape_latent["samples"] + if shape_latent.ndim == 4: + shape_latent = shape_latent.squeeze(-1).transpose(1, 2).reshape(-1, channels) shape_latent = shape_norm(shape_latent, coords) - latent = torch.randn(coords.shape[0], in_channels - structure_output.feats.shape[1]) + latent = torch.randn(1, channels, coords.shape[0], 1) model = model.clone() model.model_options = model.model_options.copy() if "transformer_options" in model.model_options: @@ -406,9 +409,9 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): model.model_options["transformer_options"] = {} model.model_options["transformer_options"]["coords"] = coords - model.model_options["transformer_options"]["generation_mode"] = "shape_generation" + model.model_options["transformer_options"]["generation_mode"] = "texture_generation" model.model_options["transformer_options"]["shape_slat"] = shape_latent - return IO.NodeOutput({"samples": latent, "type": "trellis2"}, model) + return IO.NodeOutput({"samples": latent, "coords": coords, "type": "trellis2"}, model) class EmptyStructureLatentTrellis2(IO.ComfyNode): From fe25190cae5e3a0bbaf20aab8e4f2d5f25dd0538 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Wed, 25 Mar 2026 02:40:15 +0200 Subject: [PATCH 46/59] add color support for save mesh --- comfy_extras/nodes_hunyuan3d.py | 17 +++++++++++-- comfy_extras/nodes_trellis2.py | 43 +++++++++++++++++++++++++++++---- 2 files changed, 53 insertions(+), 7 deletions(-) diff --git a/comfy_extras/nodes_hunyuan3d.py b/comfy_extras/nodes_hunyuan3d.py index df0c3e4b1..692834c2b 100644 --- a/comfy_extras/nodes_hunyuan3d.py +++ b/comfy_extras/nodes_hunyuan3d.py @@ -484,7 +484,7 @@ 
class VoxelToMesh(IO.ComfyNode): decode = execute # TODO: remove -def save_glb(vertices, faces, filepath, metadata=None): +def save_glb(vertices, faces, filepath, metadata=None, colors=None): """ Save PyTorch tensor vertices and faces as a GLB file without external dependencies. @@ -515,6 +515,13 @@ def save_glb(vertices, faces, filepath, metadata=None): indices_byte_length = len(indices_buffer) indices_byte_offset = len(vertices_buffer_padded) + if colors is not None: + colors_np = colors.cpu().numpy().astype(np.float32) + colors_buffer = colors_np.tobytes() + colors_byte_length = len(colors_buffer) + colors_byte_offset = len(buffer_data) + buffer_data += pad_to_4_bytes(colors_buffer) + gltf = { "asset": {"version": "2.0", "generator": "ComfyUI"}, "buffers": [ @@ -580,6 +587,11 @@ def save_glb(vertices, faces, filepath, metadata=None): "scene": 0 } + if colors is not None: + gltf["bufferViews"].append({"buffer": 0, "byteOffset": colors_byte_offset, "byteLength": colors_byte_length, "target": 34962}) + gltf["accessors"].append({"bufferView": 2, "byteOffset": 0, "componentType": 5126, "count": len(colors_np), "type": "VEC3"}) + gltf["meshes"][0]["primitives"][0]["attributes"]["COLOR_0"] = 2 + if metadata is not None: gltf["asset"]["extras"] = metadata @@ -669,7 +681,8 @@ class SaveGLB(IO.ComfyNode): # Handle Mesh input - save vertices and faces as GLB for i in range(mesh.vertices.shape[0]): f = f"{filename}_{counter:05}_.glb" - save_glb(mesh.vertices[i], mesh.faces[i], os.path.join(full_output_folder, f), metadata) + v_colors = mesh.colors[i] if hasattr(mesh, "colors") and mesh.colors is not None else None + save_glb(mesh.vertices[i], mesh.faces[i], os.path.join(full_output_folder, f), metadata, v_colors) results.append({ "filename": f, "subfolder": subfolder, diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index dcf8dcb98..4126fb536 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -45,6 +45,34 @@ def 
shape_norm(shape_latent, coords): samples = samples * std + mean return samples +def paint_mesh_with_voxels(mesh, voxel_coords, voxel_colors, resolution, chunk_size=4096): + """ + Generic function to paint a mesh using nearest-neighbor colors from a sparse voxel field. + Keeps chunking internal to prevent OOM crashes on large matrices. + """ + device = voxel_coords.device + + # Map Voxel Grid to Real 3D Space + origin = torch.tensor([-0.5, -0.5, -0.5], device=device) + voxel_size = 1.0 / resolution + voxel_pos = voxel_coords.float() * voxel_size + origin + + verts = mesh.vertices.to(device).squeeze(0) + v_colors = torch.zeros((verts.shape[0], 3), device=device) + + for i in range(0, verts.shape[0], chunk_size): + v_chunk = verts[i : i + chunk_size] + dists = torch.cdist(v_chunk, voxel_pos) + nearest_idx = torch.argmin(dists, dim=1) + v_colors[i : i + chunk_size] = voxel_colors[nearest_idx] + + final_colors = (v_colors * 0.5 + 0.5).clamp(0, 1).unsqueeze(0) + + out_mesh = copy.deepcopy(mesh) + out_mesh.colors = final_colors + + return out_mesh + class VaeDecodeShapeTrellis(IO.ComfyNode): @classmethod def define_schema(cls): @@ -90,18 +118,20 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): node_id="VaeDecodeTextureTrellis", category="latent/3d", inputs=[ + IO.Mesh.Input("shape_mesh"), IO.Latent.Input("samples"), IO.Vae.Input("vae"), IO.AnyType.Input("shape_subs"), ], outputs=[ - IO.Voxel.Output("voxel"), + IO.Mesh.Output("mesh"), ] ) @classmethod - def execute(cls, samples, vae, shape_subs): + def execute(cls, shape_mesh, samples, vae, shape_subs): + resolution = 1024 patcher = vae.patcher device = comfy.model_management.get_torch_device() comfy.model_management.load_model_gpu(patcher) @@ -116,9 +146,12 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): samples = SparseTensor(feats = samples, coords=coords) samples = samples * std + mean - voxel = vae.decode_tex_slat(samples, shape_subs) * 0.5 + 0.5 - voxel = Types.VOXEL(voxel) - return IO.NodeOutput(voxel) + voxel = 
vae.decode_tex_slat(samples, shape_subs) + color_feats = voxel.feats[:, :3] + voxel_coords = voxel.coords[:, 1:] + + out_mesh = paint_mesh_with_voxels(shape_mesh, voxel_coords, color_feats, resolution=resolution) + return IO.NodeOutput(out_mesh) class VaeDecodeStructureTrellis2(IO.ComfyNode): @classmethod From d2c37c222a9cc869e8aff1783a2add7b333c0c69 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Wed, 25 Mar 2026 17:03:52 +0200 Subject: [PATCH 47/59] pytorch -> scipy --- comfy_extras/nodes_trellis2.py | 33 +++++++++++++++++++++------------ 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 4126fb536..469b460eb 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -6,6 +6,7 @@ import logging from PIL import Image import numpy as np import torch +import scipy import copy shape_slat_normalization = { @@ -45,26 +46,30 @@ def shape_norm(shape_latent, coords): samples = samples * std + mean return samples -def paint_mesh_with_voxels(mesh, voxel_coords, voxel_colors, resolution, chunk_size=4096): +def paint_mesh_with_voxels(mesh, voxel_coords, voxel_colors, resolution): """ Generic function to paint a mesh using nearest-neighbor colors from a sparse voxel field. - Keeps chunking internal to prevent OOM crashes on large matrices. 
""" - device = voxel_coords.device + device = comfy.model_management.vae_offload_device() - # Map Voxel Grid to Real 3D Space origin = torch.tensor([-0.5, -0.5, -0.5], device=device) voxel_size = 1.0 / resolution - voxel_pos = voxel_coords.float() * voxel_size + origin + # map voxels + voxel_pos = voxel_coords.to(device).float() * voxel_size + origin verts = mesh.vertices.to(device).squeeze(0) - v_colors = torch.zeros((verts.shape[0], 3), device=device) + voxel_colors = voxel_colors.to(device) - for i in range(0, verts.shape[0], chunk_size): - v_chunk = verts[i : i + chunk_size] - dists = torch.cdist(v_chunk, voxel_pos) - nearest_idx = torch.argmin(dists, dim=1) - v_colors[i : i + chunk_size] = voxel_colors[nearest_idx] + voxel_pos_np = voxel_pos.numpy() + verts_np = verts.numpy() + + tree = scipy.spatial.cKDTree(voxel_pos_np) + + # nearest neighbour k=1 + _, nearest_idx_np = tree.query(verts_np, k=1, workers=-1) + + nearest_idx = torch.from_numpy(nearest_idx_np).long() + v_colors = voxel_colors[nearest_idx] final_colors = (v_colors * 0.5 + 0.5).clamp(0, 1).unsqueeze(0) @@ -343,7 +348,11 @@ class Trellis2Conditioning(IO.ComfyNode): alpha_float = cropped_np[:, :, 3:4] composite_np = fg * alpha_float + bg_rgb * (1.0 - alpha_float) - cropped_img_tensor = torch.from_numpy(composite_np).movedim(-1, 0).unsqueeze(0).float() + # to match trellis2 code (quantize -> dequantize) + composite_uint8 = (composite_np * 255.0).round().clip(0, 255).astype(np.uint8) + + cropped_img_tensor = torch.from_numpy(composite_uint8).float() / 255.0 + cropped_img_tensor = cropped_img_tensor.movedim(-1, 0).unsqueeze(0) conditioning = run_conditioning(clip_vision_model, cropped_img_tensor, include_1024=True) From 22dcd81fb3b8360f4d602dda97f62187921e97f3 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Wed, 25 Mar 2026 23:51:54 +0200 Subject: [PATCH 48/59] fixed color addition --- comfy/ldm/trellis2/model.py | 3 +-- comfy_extras/nodes_hunyuan3d.py 
| 3 +++ comfy_extras/nodes_trellis2.py | 1 + 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 34aeba3e1..613f7ef50 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -782,7 +782,7 @@ class Trellis2(nn.Module): mode = "structure_generation" not_struct_mode = False - if is_1024 and mode == "shape_generation" and not is_512_run: + if is_1024 and not_struct_mode and not is_512_run: context = embeds sigmas = transformer_options.get("sigmas")[0].item() @@ -825,7 +825,6 @@ class Trellis2(nn.Module): x_st = SparseTensor(feats=feats_flat, coords=batched_coords.to(torch.int32)) if mode == "shape_generation": - # TODO if is_512_run: out = self.img2shape_512(x_st, t_eval, c_eval) else: diff --git a/comfy_extras/nodes_hunyuan3d.py b/comfy_extras/nodes_hunyuan3d.py index 692834c2b..ac91fe0a7 100644 --- a/comfy_extras/nodes_hunyuan3d.py +++ b/comfy_extras/nodes_hunyuan3d.py @@ -591,6 +591,9 @@ def save_glb(vertices, faces, filepath, metadata=None, colors=None): gltf["bufferViews"].append({"buffer": 0, "byteOffset": colors_byte_offset, "byteLength": colors_byte_length, "target": 34962}) gltf["accessors"].append({"bufferView": 2, "byteOffset": 0, "componentType": 5126, "count": len(colors_np), "type": "VEC3"}) gltf["meshes"][0]["primitives"][0]["attributes"]["COLOR_0"] = 2 + # Define a base material so Three.js actually activates vertex coloring + gltf["materials"] =[{"pbrMetallicRoughness": {"baseColorFactor": [1.0, 1.0, 1.0, 1.0]}}] + gltf["meshes"][0]["primitives"][0]["material"] = 0 if metadata is not None: gltf["asset"]["extras"] = metadata diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 469b460eb..ff95e8332 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -53,6 +53,7 @@ def paint_mesh_with_voxels(mesh, voxel_coords, voxel_colors, resolution): device = comfy.model_management.vae_offload_device() origin = 
torch.tensor([-0.5, -0.5, -0.5], device=device) + # TODO: generic independent node? if so: figure how pass the resolution parameter voxel_size = 1.0 / resolution # map voxels From 55595697356038132c643bd6d5a5fa1a95cc9814 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 27 Mar 2026 20:52:23 +0200 Subject: [PATCH 49/59] .. --- comfy/ldm/trellis2/model.py | 4 ++-- comfy_extras/nodes_trellis2.py | 6 +++++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 613f7ef50..40646f369 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -843,8 +843,8 @@ class Trellis2(nn.Module): else: # structure orig_bsz = x.shape[0] if shape_rule: - x = x[0].unsqueeze(0) - timestep = timestep[0].unsqueeze(0) + x = x[1].unsqueeze(0) + timestep = timestep[1].unsqueeze(0) out = self.structure_model(x, timestep, context if not shape_rule else cond) if shape_rule: out = out.repeat(orig_bsz, 1, 1, 1, 1) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index ff95e8332..f58dbf592 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -441,7 +441,11 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): shape_latent = shape_latent["samples"] if shape_latent.ndim == 4: shape_latent = shape_latent.squeeze(-1).transpose(1, 2).reshape(-1, channels) - shape_latent = shape_norm(shape_latent, coords) + + std = shape_slat_normalization["std"].to(shape_latent) + mean = shape_slat_normalization["mean"].to(shape_latent) + shape_latent = SparseTensor(feats = shape_latent, coords=coords) + shape_latent = (shape_latent - mean) / std latent = torch.randn(1, channels, coords.shape[0], 1) model = model.clone() From 72640888ffc6bcdda612432c27c8528cf0e76da8 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Mon, 30 Mar 2026 23:20:46 +0200 Subject: [PATCH 50/59] wrong 
normalization for the texture node --- comfy_extras/nodes_trellis2.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index f58dbf592..77e6a3add 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -442,11 +442,6 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): if shape_latent.ndim == 4: shape_latent = shape_latent.squeeze(-1).transpose(1, 2).reshape(-1, channels) - std = shape_slat_normalization["std"].to(shape_latent) - mean = shape_slat_normalization["mean"].to(shape_latent) - shape_latent = SparseTensor(feats = shape_latent, coords=coords) - shape_latent = (shape_latent - mean) / std - latent = torch.randn(1, channels, coords.shape[0], 1) model = model.clone() model.model_options = model.model_options.copy() From 57b306464ecb2b87b93491198a632c720d0bb85c Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 3 Apr 2026 01:22:38 +0200 Subject: [PATCH 51/59] texture generation works --- comfy/ldm/trellis2/attention.py | 60 +++++++++++---------------------- comfy/ldm/trellis2/model.py | 7 ++-- comfy_extras/nodes_trellis2.py | 8 ++++- 3 files changed, 32 insertions(+), 43 deletions(-) diff --git a/comfy/ldm/trellis2/attention.py b/comfy/ldm/trellis2/attention.py index 681666430..d95b071b5 100644 --- a/comfy/ldm/trellis2/attention.py +++ b/comfy/ldm/trellis2/attention.py @@ -53,57 +53,37 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha ) return out - -# TODO repalce with optimized attention def scaled_dot_product_attention(*args, **kwargs): num_all_args = len(args) + len(kwargs) q = None if num_all_args == 1: - qkv = args[0] if len(args) > 0 else kwargs['qkv'] - + qkv = args[0] if len(args) > 0 else kwargs.get('qkv') elif num_all_args == 2: - q = args[0] if len(args) > 0 else kwargs['q'] - kv = args[1] if len(args) > 1 else kwargs['kv'] - + q = args[0] if len(args) > 0 else 
kwargs.get('q') + kv = args[1] if len(args) > 1 else kwargs.get('kv') elif num_all_args == 3: - q = args[0] if len(args) > 0 else kwargs['q'] - k = args[1] if len(args) > 1 else kwargs['k'] - v = args[2] if len(args) > 2 else kwargs['v'] + q = args[0] if len(args) > 0 else kwargs.get('q') + k = args[1] if len(args) > 1 else kwargs.get('k') + v = args[2] if len(args) > 2 else kwargs.get('v') if q is not None: - heads = q + heads = q.shape[2] else: - heads = qkv - heads = heads.shape[2] + heads = qkv.shape[3] - if optimized_attention.__name__ == 'attention_xformers': - if num_all_args == 1: - q, k, v = qkv.unbind(dim=2) - elif num_all_args == 2: - k, v = kv.unbind(dim=2) - #out = xops.memory_efficient_attention(q, k, v) - out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True) - elif optimized_attention.__name__ == 'attention_flash': - if num_all_args == 2: - k, v = kv.unbind(dim=2) - out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True) - elif optimized_attention.__name__ == 'attention_pytorch': - if num_all_args == 1: - q, k, v = qkv.unbind(dim=2) - elif num_all_args == 2: - k, v = kv.unbind(dim=2) - q = q.permute(0, 2, 1, 3) # [N, H, L, C] - k = k.permute(0, 2, 1, 3) # [N, H, L, C] - v = v.permute(0, 2, 1, 3) # [N, H, L, C] - out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True) - out = out.permute(0, 2, 1, 3) # [N, L, H, C] - elif optimized_attention.__name__ == 'attention_basic': - if num_all_args == 1: - q, k, v = qkv.unbind(dim=2) - elif num_all_args == 2: - k, v = kv.unbind(dim=2) - out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True) + if num_all_args == 1: + q, k, v = qkv.unbind(dim=2) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + + out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True, **kwargs) + + out 
= out.permute(0, 2, 1, 3) return out diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 40646f369..7c6ffdd69 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -788,7 +788,10 @@ class Trellis2(nn.Module): sigmas = transformer_options.get("sigmas")[0].item() if sigmas < 1.00001: timestep *= 1000.0 - cond = context.chunk(2)[1] + if context.size(0) > 1: + cond = context.chunk(2)[1] + else: + cond = context shape_rule = sigmas < self.guidance_interval[0] or sigmas > self.guidance_interval[1] txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1] @@ -836,7 +839,7 @@ class Trellis2(nn.Module): if slat is None: raise ValueError("shape_slat can't be None") - base_slat_feats = slat.feats[:N] + base_slat_feats = slat[:N] slat_feats_batched = base_slat_feats.repeat(B, 1).to(x_st.device) x_st = x_st.replace(feats=torch.cat([x_st.feats, slat_feats_batched], dim=-1)) out = self.shape2txt(x_st, t_eval, c_eval) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 77e6a3add..088cdd3f1 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -72,7 +72,13 @@ def paint_mesh_with_voxels(mesh, voxel_coords, voxel_colors, resolution): nearest_idx = torch.from_numpy(nearest_idx_np).long() v_colors = voxel_colors[nearest_idx] - final_colors = (v_colors * 0.5 + 0.5).clamp(0, 1).unsqueeze(0) + # to [0, 1] + srgb_colors = (v_colors * 0.5 + 0.5).clamp(0, 1) + + # to Linear RGB (required for GLTF) + linear_colors = torch.pow(srgb_colors, 2.2) + + final_colors = linear_colors.unsqueeze(0) out_mesh = copy.deepcopy(mesh) out_mesh.colors = final_colors From 2cb06431e8e481b41c2c4da2aa92ba30ea07d66c Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Tue, 7 Apr 2026 23:02:33 +0200 Subject: [PATCH 52/59] fix for conditioning --- comfy/ldm/trellis2/model.py | 2 +- comfy_extras/nodes_trellis2.py | 29 
++++++++++++++--------------- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 7c6ffdd69..ea7ada9f8 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -671,7 +671,7 @@ class SparseStructureFlowModel(nn.Module): coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution] * 3], indexing='ij') coords = torch.stack(coords, dim=-1).reshape(-1, 3) rope_phases = pos_embedder(coords) - self.register_buffer("rope_phases", rope_phases) + self.register_buffer("rope_phases", rope_phases, persistent=False) if pe_mode != "rope": self.rope_phases = None diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 088cdd3f1..d3f5e4940 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -2,7 +2,6 @@ from typing_extensions import override from comfy_api.latest import ComfyExtension, IO, Types from comfy.ldm.trellis2.vae import SparseTensor import comfy.model_management -import logging from PIL import Image import numpy as np import torch @@ -250,28 +249,28 @@ class Trellis2UpsampleCascade(IO.ComfyNode): return IO.NodeOutput(final_coords,) -dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1) -dino_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1) +dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) +dino_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) def run_conditioning(model, cropped_img_tensor, include_1024=True): model_internal = model.model device = comfy.model_management.intermediate_device() torch_device = comfy.model_management.get_torch_device() - img_t = cropped_img_tensor.to(torch_device) - - def prepare_tensor(img, size): - resized = torch.nn.functional.interpolate(img, size=(size, size), mode='bicubic', align_corners=False).clamp(0.0, 1.0) - return (resized - dino_mean.to(torch_device)) / dino_std.to(torch_device) + def prepare_tensor(pil_img, 
size): + resized_pil = pil_img.resize((size, size), Image.Resampling.LANCZOS) + img_np = np.array(resized_pil).astype(np.float32) / 255.0 + img_t = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(torch_device) + return (img_t - dino_mean.to(torch_device)) / dino_std.to(torch_device) model_internal.image_size = 512 - input_512 = prepare_tensor(img_t, 512) + input_512 = prepare_tensor(cropped_img_tensor, 512) cond_512 = model_internal(input_512)[0] cond_1024 = None if include_1024: model_internal.image_size = 1024 - input_1024 = prepare_tensor(img_t, 1024) + input_1024 = prepare_tensor(cropped_img_tensor, 1024) cond_1024 = model_internal(input_1024)[0] conditioning = { @@ -341,14 +340,15 @@ class Trellis2Conditioning(IO.ComfyNode): crop_x2 = int(center_x + size // 2) crop_y2 = int(center_y + size // 2) - rgba_pil = Image.fromarray(rgba_np, 'RGBA') + rgba_pil = Image.fromarray(rgba_np) cropped_rgba = rgba_pil.crop((crop_x1, crop_y1, crop_x2, crop_y2)) cropped_np = np.array(cropped_rgba).astype(np.float32) / 255.0 else: + import logging logging.warning("Mask for the image is empty. 
Trellis2 requires an image with a mask for the best mesh quality.") cropped_np = rgba_np.astype(np.float32) / 255.0 - bg_colors = {"black": [0.0, 0.0, 0.0], "gray":[0.5, 0.5, 0.5], "white":[1.0, 1.0, 1.0]} + bg_colors = {"black":[0.0, 0.0, 0.0], "gray":[0.5, 0.5, 0.5], "white":[1.0, 1.0, 1.0]} bg_rgb = np.array(bg_colors.get(background_color, [0.0, 0.0, 0.0]), dtype=np.float32) fg = cropped_np[:, :, :3] @@ -358,10 +358,9 @@ class Trellis2Conditioning(IO.ComfyNode): # to match trellis2 code (quantize -> dequantize) composite_uint8 = (composite_np * 255.0).round().clip(0, 255).astype(np.uint8) - cropped_img_tensor = torch.from_numpy(composite_uint8).float() / 255.0 - cropped_img_tensor = cropped_img_tensor.movedim(-1, 0).unsqueeze(0) + cropped_pil = Image.fromarray(composite_uint8) - conditioning = run_conditioning(clip_vision_model, cropped_img_tensor, include_1024=True) + conditioning = run_conditioning(clip_vision_model, cropped_pil, include_1024=True) embeds = conditioning["cond_1024"] positive = [[conditioning["cond_512"], {"embeds": embeds}]] From 0ebeac98a78885d4c13c09691e027fb141d9e581 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Wed, 8 Apr 2026 19:08:26 +0200 Subject: [PATCH 53/59] removed unnecessary vae float32 upcast --- comfy/ldm/trellis2/vae.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/comfy/ldm/trellis2/vae.py b/comfy/ldm/trellis2/vae.py index 2a18c496a..c42ad8d2f 100644 --- a/comfy/ldm/trellis2/vae.py +++ b/comfy/ldm/trellis2/vae.py @@ -2,7 +2,6 @@ import math import torch import numpy as np import torch.nn as nn -import comfy.model_management import torch.nn.functional as F from fractions import Fraction from dataclasses import dataclass @@ -78,7 +77,10 @@ class LayerNorm32(nn.LayerNorm): def forward(self, x: torch.Tensor) -> torch.Tensor: x_dtype = x.dtype x = x.to(torch.float32) - o = super().forward(x) + w = self.weight.to(torch.float32) + b = 
self.bias.to(torch.float32) if self.bias is not None else None + + o = F.layer_norm(x, self.normalized_shape, w, b, self.eps) return o.to(x_dtype) class SparseConvNeXtBlock3d(nn.Module): @@ -102,8 +104,7 @@ class SparseConvNeXtBlock3d(nn.Module): def _forward(self, x): h = self.conv(x) - norm = self.norm.to(torch.float32) - h = h.replace(norm(h.feats)) + h = h.replace(self.norm(h.feats)) h = h.replace(self.mlp(h.feats)) return h + x @@ -213,15 +214,13 @@ class SparseResBlockC2S3d(nn.Module): dtype = next(self.to_subdiv.parameters()).dtype x = x.to(dtype) subdiv = self.to_subdiv(x) - norm1 = self.norm1.to(torch.float32) - norm2 = self.norm2.to(torch.float32) - h = x.replace(norm1(x.feats)) + h = x.replace(self.norm1(x.feats)) h = h.replace(F.silu(h.feats)) h = self.conv1(h) subdiv_binarized = subdiv.replace(subdiv.feats > 0) if subdiv is not None else None h = self.updown(h, subdiv_binarized) x = self.updown(x, subdiv_binarized) - h = h.replace(norm2(h.feats)) + h = h.replace(self.norm2(h.feats)) h = h.replace(F.silu(h.feats)) h = self.conv2(h) h = h + self.skip_connection(x) @@ -1300,8 +1299,6 @@ class ResBlock3d(nn.Module): self.skip_connection = nn.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity() def forward(self, x: torch.Tensor) -> torch.Tensor: - self.norm1 = self.norm1.to(torch.float32) - self.norm2 = self.norm2.to(torch.float32) h = self.norm1(x) h = F.silu(h) dtype = next(self.conv1.parameters()).dtype @@ -1381,8 +1378,7 @@ class SparseStructureDecoder(nn.Module): for block in self.blocks: h = block(h) - h = h.to(torch.float32) - self.out_layer = self.out_layer.to(torch.float32) + h = h.type(x.dtype) h = self.out_layer(h) return h From ea255543e648000ed987576494dbc227320b61d9 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 10 Apr 2026 14:24:07 +0200 Subject: [PATCH 54/59] structure generation works --- comfy/image_encoders/dino3.py | 8 ++++++-- 
comfy/ldm/trellis2/vae.py | 2 +- comfy_extras/nodes_trellis2.py | 4 ++-- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/comfy/image_encoders/dino3.py b/comfy/image_encoders/dino3.py index ff17d78d6..145bd5490 100644 --- a/comfy/image_encoders/dino3.py +++ b/comfy/image_encoders/dino3.py @@ -1,6 +1,7 @@ import math import torch import torch.nn as nn +import torch.nn.functional as F import comfy.model_management from comfy.ldm.modules.attention import optimized_attention_for_device @@ -274,8 +275,11 @@ class DINOv3ViTModel(nn.Module): position_embeddings=position_embeddings, ) - norm = self.norm.to(hidden_states.device) - sequence_output = norm(hidden_states) + if kwargs.get("skip_norm_elementwise", False): + sequence_output= F.layer_norm(hidden_states, hidden_states.shape[-1:]) + else: + norm = self.norm.to(hidden_states.device) + sequence_output = norm(hidden_states) pooled_output = sequence_output[:, 0, :] return sequence_output, None, pooled_output, None diff --git a/comfy/ldm/trellis2/vae.py b/comfy/ldm/trellis2/vae.py index c42ad8d2f..cd37ccd30 100644 --- a/comfy/ldm/trellis2/vae.py +++ b/comfy/ldm/trellis2/vae.py @@ -77,7 +77,7 @@ class LayerNorm32(nn.LayerNorm): def forward(self, x: torch.Tensor) -> torch.Tensor: x_dtype = x.dtype x = x.to(torch.float32) - w = self.weight.to(torch.float32) + w = self.weight.to(torch.float32) if self.weight is not None else None b = self.bias.to(torch.float32) if self.bias is not None else None o = F.layer_norm(x, self.normalized_shape, w, b, self.eps) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index d3f5e4940..56ce4f5ea 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -265,13 +265,13 @@ def run_conditioning(model, cropped_img_tensor, include_1024=True): model_internal.image_size = 512 input_512 = prepare_tensor(cropped_img_tensor, 512) - cond_512 = model_internal(input_512)[0] + cond_512 = model_internal(input_512, 
skip_norm_elementwise=True)[0] cond_1024 = None if include_1024: model_internal.image_size = 1024 input_1024 = prepare_tensor(cropped_img_tensor, 1024) - cond_1024 = model_internal(input_1024)[0] + cond_1024 = model_internal(input_1024, skip_norm_elementwise=True)[0] conditioning = { 'cond_512': cond_512.to(device), From 4e14d42da1c26f48af7836c3c2eff8aa8cc8d4f5 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 10 Apr 2026 16:12:23 +0200 Subject: [PATCH 55/59] comfy ops + color support in postprocess --- comfy/ldm/trellis2/model.py | 105 +++++++++++++++++++-------------- comfy/ldm/trellis2/vae.py | 14 +++-- comfy_extras/nodes_trellis2.py | 57 +++++++++--------- 3 files changed, 98 insertions(+), 78 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index ea7ada9f8..a613fb325 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -14,12 +14,12 @@ class SparseGELU(nn.GELU): return input.replace(super().forward(input.feats)) class SparseFeedForwardNet(nn.Module): - def __init__(self, channels: int, mlp_ratio: float = 4.0): + def __init__(self, channels: int, mlp_ratio: float = 4.0, device=None, dtype=None, operations=None): super().__init__() self.mlp = nn.Sequential( - SparseLinear(channels, int(channels * mlp_ratio)), + SparseLinear(channels, int(channels * mlp_ratio), device=device, dtype=dtype, operations=operations), SparseGELU(approximate="tanh"), - SparseLinear(int(channels * mlp_ratio), channels), + SparseLinear(int(channels * mlp_ratio), channels, device=device, dtype=dtype, operations=operations), ) def forward(self, x: VarLenTensor) -> VarLenTensor: @@ -37,10 +37,10 @@ class LayerNorm32(nn.LayerNorm): class SparseMultiHeadRMSNorm(nn.Module): - def __init__(self, dim: int, heads: int): + def __init__(self, dim: int, heads: int, device, dtype): super().__init__() self.scale = dim ** 0.5 - self.gamma = nn.Parameter(torch.ones(heads, dim)) + self.gamma = 
nn.Parameter(torch.ones(heads, dim, device=device, dtype=dtype)) def forward(self, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]: x_type = x.dtype @@ -56,14 +56,15 @@ class SparseRotaryPositionEmbedder(nn.Module): self, head_dim: int, dim: int = 3, - rope_freq: Tuple[float, float] = (1.0, 10000.0) + rope_freq: Tuple[float, float] = (1.0, 10000.0), + device=None ): super().__init__() self.head_dim = head_dim self.dim = dim self.rope_freq = rope_freq self.freq_dim = head_dim // 2 // dim - self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim + self.freqs = torch.arange(self.freq_dim, dtype=torch.float32, device=device) / self.freq_dim self.freqs = rope_freq[0] / (rope_freq[1] ** (self.freqs)) def _get_freqs_cis(self, coords: torch.Tensor) -> torch.Tensor: @@ -148,6 +149,7 @@ class SparseMultiHeadAttention(nn.Module): use_rope: bool = False, rope_freq: Tuple[int, int] = (1.0, 10000.0), qk_rms_norm: bool = False, + device=None, dtype=None, operations=None ): super().__init__() @@ -163,19 +165,19 @@ class SparseMultiHeadAttention(nn.Module): self.qk_rms_norm = qk_rms_norm if self._type == "self": - self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias) + self.to_qkv = operations.Linear(channels, channels * 3, bias=qkv_bias, device=device, dtype=dtype) else: - self.to_q = nn.Linear(channels, channels, bias=qkv_bias) - self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias) + self.to_q = operations.Linear(channels, channels, bias=qkv_bias, device=device, dtype=dtype) + self.to_kv = operations.Linear(self.ctx_channels, channels * 2, bias=qkv_bias, device=device, dtype=dtype) if self.qk_rms_norm: - self.q_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads) - self.k_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads) + self.q_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype) + self.k_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads, 
device=device, dtype=dtype) - self.to_out = nn.Linear(channels, channels) + self.to_out = operations.Linear(channels, channels, device=device, dtype=dtype) if use_rope: - self.rope = SparseRotaryPositionEmbedder(self.head_dim, rope_freq=rope_freq) + self.rope = SparseRotaryPositionEmbedder(self.head_dim, rope_freq=rope_freq, device=device) @staticmethod def _linear(module: nn.Linear, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]: @@ -267,14 +269,14 @@ class ModulatedSparseTransformerCrossBlock(nn.Module): qk_rms_norm_cross: bool = False, qkv_bias: bool = True, share_mod: bool = False, - + device=None, dtype=None, operations=None ): super().__init__() self.use_checkpoint = use_checkpoint self.share_mod = share_mod - self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) - self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) - self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device) + self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6, device=device) + self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device) self.self_attn = SparseMultiHeadAttention( channels, num_heads=num_heads, @@ -286,6 +288,7 @@ class ModulatedSparseTransformerCrossBlock(nn.Module): use_rope=use_rope, rope_freq=rope_freq, qk_rms_norm=qk_rms_norm, + device=device, dtype=dtype, operations=operations ) self.cross_attn = SparseMultiHeadAttention( channels, @@ -295,18 +298,20 @@ class ModulatedSparseTransformerCrossBlock(nn.Module): attn_mode="full", qkv_bias=qkv_bias, qk_rms_norm=qk_rms_norm_cross, + device=device, dtype=dtype, operations=operations ) self.mlp = SparseFeedForwardNet( channels, mlp_ratio=mlp_ratio, + device=device, dtype=dtype, operations=operations ) if not share_mod: self.adaLN_modulation = nn.Sequential( nn.SiLU(), - nn.Linear(channels, 6 * channels, bias=True) + 
operations.Linear(channels, 6 * channels, bias=True, device=device, dtype=dtype) ) else: - self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5) + self.modulation = nn.Parameter(torch.randn(6 * channels, device=device, dtype=dtype) / channels ** 0.5) def _forward(self, x: SparseTensor, mod: torch.Tensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor: if self.share_mod: @@ -376,10 +381,10 @@ class SLatFlowModel(nn.Module): if share_mod: self.adaLN_modulation = nn.Sequential( nn.SiLU(), - nn.Linear(model_channels, 6 * model_channels, bias=True) + operations.Linear(model_channels, 6 * model_channels, bias=True, device=device, dtype=dtype) ) - self.input_layer = SparseLinear(in_channels, model_channels) + self.input_layer = SparseLinear(in_channels, model_channels, device=device, dtype=dtype, operations=operations) self.blocks = nn.ModuleList([ ModulatedSparseTransformerCrossBlock( @@ -394,11 +399,12 @@ class SLatFlowModel(nn.Module): share_mod=self.share_mod, qk_rms_norm=self.qk_rms_norm, qk_rms_norm_cross=self.qk_rms_norm_cross, + device=device, dtype=dtype, operations=operations ) for _ in range(num_blocks) ]) - self.out_layer = SparseLinear(model_channels, out_channels) + self.out_layer = SparseLinear(model_channels, out_channels, device=device, dtype=dtype, operations=operations) @property def device(self) -> torch.device: @@ -438,22 +444,22 @@ class SLatFlowModel(nn.Module): return h class FeedForwardNet(nn.Module): - def __init__(self, channels: int, mlp_ratio: float = 4.0): + def __init__(self, channels: int, mlp_ratio: float = 4.0, device=None, dtype=None, operations=None): super().__init__() self.mlp = nn.Sequential( - nn.Linear(channels, int(channels * mlp_ratio)), + operations.Linear(channels, int(channels * mlp_ratio), device=device, dtype=dtype), nn.GELU(approximate="tanh"), - nn.Linear(int(channels * mlp_ratio), channels), + operations.Linear(int(channels * mlp_ratio), channels, device=device, dtype=dtype), ) def 
forward(self, x: torch.Tensor) -> torch.Tensor: return self.mlp(x) class MultiHeadRMSNorm(nn.Module): - def __init__(self, dim: int, heads: int): + def __init__(self, dim: int, heads: int, device=None, dtype=None): super().__init__() self.scale = dim ** 0.5 - self.gamma = nn.Parameter(torch.ones(heads, dim)) + self.gamma = nn.Parameter(torch.ones(heads, dim, device=device, dtype=dtype)) def forward(self, x: torch.Tensor) -> torch.Tensor: return (F.normalize(x.float(), dim = -1) * self.gamma * self.scale).to(x.dtype) @@ -473,6 +479,7 @@ class MultiHeadAttention(nn.Module): use_rope: bool = False, rope_freq: Tuple[float, float] = (1.0, 10000.0), qk_rms_norm: bool = False, + device=None, dtype=None, operations=None ): super().__init__() @@ -488,16 +495,16 @@ class MultiHeadAttention(nn.Module): self.qk_rms_norm = qk_rms_norm if self._type == "self": - self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias) + self.to_qkv = operations.Linear(channels, channels * 3, bias=qkv_bias, dtype=dtype, device=device) else: - self.to_q = nn.Linear(channels, channels, bias=qkv_bias) - self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias) + self.to_q = operations.Linear(channels, channels, bias=qkv_bias, device=device, dtype=dtype) + self.to_kv = operations.Linear(self.ctx_channels, channels * 2, bias=qkv_bias, device=device, dtype=dtype) if self.qk_rms_norm: - self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads) - self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads) + self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype) + self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype) - self.to_out = nn.Linear(channels, channels) + self.to_out = operations.Linear(channels, channels, device=device, dtype=dtype) def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, phases: Optional[torch.Tensor] = None) -> torch.Tensor: B, L, C = x.shape @@ -554,13 +561,14 @@ 
class ModulatedTransformerCrossBlock(nn.Module): qk_rms_norm_cross: bool = False, qkv_bias: bool = True, share_mod: bool = False, + device=None, dtype=None, operations=None ): super().__init__() self.use_checkpoint = use_checkpoint self.share_mod = share_mod - self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) - self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) - self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device) + self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6, device=device) + self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device) self.self_attn = MultiHeadAttention( channels, num_heads=num_heads, @@ -572,6 +580,7 @@ class ModulatedTransformerCrossBlock(nn.Module): use_rope=use_rope, rope_freq=rope_freq, qk_rms_norm=qk_rms_norm, + device=device, dtype=dtype, operations=operations ) self.cross_attn = MultiHeadAttention( channels, @@ -581,18 +590,20 @@ class ModulatedTransformerCrossBlock(nn.Module): attn_mode="full", qkv_bias=qkv_bias, qk_rms_norm=qk_rms_norm_cross, + device=device, dtype=dtype, operations=operations ) self.mlp = FeedForwardNet( channels, mlp_ratio=mlp_ratio, + device=device, dtype=dtype, operations=operations ) if not share_mod: self.adaLN_modulation = nn.Sequential( nn.SiLU(), - nn.Linear(channels, 6 * channels, bias=True) + operations.Linear(channels, 6 * channels, bias=True, dtype=dtype, device=device) ) else: - self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5) + self.modulation = nn.Parameter(torch.randn(6 * channels, device=device, dtype=dtype) / channels ** 0.5) def _forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor: if self.share_mod: @@ -659,16 +670,17 @@ class SparseStructureFlowModel(nn.Module): self.qk_rms_norm = qk_rms_norm 
self.qk_rms_norm_cross = qk_rms_norm_cross self.dtype = dtype + self.device = device - self.t_embedder = TimestepEmbedder(model_channels, operations=operations) + self.t_embedder = TimestepEmbedder(model_channels, dtype=dtype, device=device, operations=operations) if share_mod: self.adaLN_modulation = nn.Sequential( nn.SiLU(), - nn.Linear(model_channels, 6 * model_channels, bias=True) + operations.Linear(model_channels, 6 * model_channels, bias=True, device=device, dtype=dtype) ) - pos_embedder = RotaryPositionEmbedder(self.model_channels // self.num_heads, 3) - coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution] * 3], indexing='ij') + pos_embedder = RotaryPositionEmbedder(self.model_channels // self.num_heads, 3, device=device) + coords = torch.meshgrid(*[torch.arange(res, device=self.device, dtype=dtype) for res in [resolution] * 3], indexing='ij') coords = torch.stack(coords, dim=-1).reshape(-1, 3) rope_phases = pos_embedder(coords) self.register_buffer("rope_phases", rope_phases, persistent=False) @@ -676,7 +688,7 @@ class SparseStructureFlowModel(nn.Module): if pe_mode != "rope": self.rope_phases = None - self.input_layer = nn.Linear(in_channels, model_channels) + self.input_layer = operations.Linear(in_channels, model_channels, device=device, dtype=dtype) self.blocks = nn.ModuleList([ ModulatedTransformerCrossBlock( @@ -691,11 +703,12 @@ class SparseStructureFlowModel(nn.Module): share_mod=share_mod, qk_rms_norm=self.qk_rms_norm, qk_rms_norm_cross=self.qk_rms_norm_cross, + device=device, dtype=dtype, operations=operations ) for _ in range(num_blocks) ]) - self.out_layer = nn.Linear(model_channels, out_channels) + self.out_layer = operations.Linear(model_channels, out_channels, device=device, dtype=dtype) def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor: x = x.view(x.shape[0], self.in_channels, *[self.resolution] * 3) @@ -745,6 +758,7 @@ class Trellis2(nn.Module): super().__init__() 
self.dtype = dtype + operations = operations or nn # for some reason it passes num_heads = -1 if num_heads == -1: num_heads = 12 @@ -772,6 +786,7 @@ class Trellis2(nn.Module): coords = transformer_options.get("coords", None) mode = transformer_options.get("generation_mode", "structure_generation") is_512_run = False + timestep = timestep.to(self.dtype) if mode == "shape_generation_512": is_512_run = True mode = "shape_generation" diff --git a/comfy/ldm/trellis2/vae.py b/comfy/ldm/trellis2/vae.py index cd37ccd30..30f902868 100644 --- a/comfy/ldm/trellis2/vae.py +++ b/comfy/ldm/trellis2/vae.py @@ -962,13 +962,17 @@ def sparse_unbind(input: SparseTensor, dim: int) -> List[SparseTensor]: feats = input.feats.unbind(dim) return [input.replace(f) for f in feats] -class SparseLinear(nn.Linear): - def __init__(self, in_features, out_features, bias=True): - super(SparseLinear, self).__init__(in_features, out_features, bias) +# allow operations.Linear inheritance +class SparseLinear: + def __new__(cls, in_features, out_features, bias=True, device=None, dtype=None, operations=nn, *args, **kwargs): + class _SparseLinear(operations.Linear): + def __init__(self, in_features, out_features, bias=True, device=None, dtype=None): + super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype) - def forward(self, input: VarLenTensor) -> VarLenTensor: - return input.replace(super().forward(input.feats)) + def forward(self, input: VarLenTensor) -> VarLenTensor: + return input.replace(super().forward(input.feats)) + return _SparseLinear(in_features, out_features, bias=bias, device=device, dtype=dtype, *args, **kwargs) MIX_PRECISION_MODULES = ( nn.Conv1d, diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 56ce4f5ea..1bf7c55b8 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -481,21 +481,25 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode): latent = torch.randn(batch_size, in_channels, resolution, 
resolution, resolution) return IO.NodeOutput({"samples": latent, "type": "trellis2"}) -def simplify_fn(vertices, faces, target=100000): - is_batched = vertices.ndim == 3 - if is_batched: - v_list, f_list = [], [] +def simplify_fn(vertices, faces, colors=None, target=100000): + if vertices.ndim == 3: + v_list, f_list, c_list = [], [], [] for i in range(vertices.shape[0]): - v_i, f_i = simplify_fn(vertices[i], faces[i], target) + c_in = colors[i] if colors is not None else None + v_i, f_i, c_i = simplify_fn(vertices[i], faces[i], c_in, target) v_list.append(v_i) f_list.append(f_i) - return torch.stack(v_list), torch.stack(f_list) + if c_i is not None: + c_list.append(c_i) + + c_out = torch.stack(c_list) if len(c_list) > 0 else None + return torch.stack(v_list), torch.stack(f_list), c_out if faces.shape[0] <= target: - return vertices, faces + return vertices, faces, colors device = vertices.device - target_v = target / 2.0 + target_v = max(target / 4.0, 1.0) min_v = vertices.min(dim=0)[0] max_v = vertices.max(dim=0)[0] @@ -510,14 +514,17 @@ def simplify_fn(vertices, faces, target=100000): new_vertices = torch.zeros((num_cells, 3), dtype=vertices.dtype, device=device) counts = torch.zeros((num_cells, 1), dtype=vertices.dtype, device=device) - new_vertices.scatter_add_(0, inverse_indices.unsqueeze(1).expand(-1, 3), vertices) counts.scatter_add_(0, inverse_indices.unsqueeze(1), torch.ones_like(vertices[:, :1])) - new_vertices = new_vertices / counts.clamp(min=1) - new_faces = inverse_indices[faces] + new_colors = None + if colors is not None: + new_colors = torch.zeros((num_cells, colors.shape[1]), dtype=colors.dtype, device=device) + new_colors.scatter_add_(0, inverse_indices.unsqueeze(1).expand(-1, colors.shape[1]), colors) + new_colors = new_colors / counts.clamp(min=1) + new_faces = inverse_indices[faces] valid_mask = (new_faces[:, 0] != new_faces[:, 1]) & \ (new_faces[:, 1] != new_faces[:, 2]) & \ (new_faces[:, 2] != new_faces[:, 0]) @@ -527,7 +534,10 @@ def 
simplify_fn(vertices, faces, target=100000): final_vertices = new_vertices[unique_face_indices] final_faces = inv_face.reshape(-1, 3) - return final_vertices, final_faces + # assign colors + final_colors = new_colors[unique_face_indices] if new_colors is not None else None + + return final_vertices, final_faces, final_colors def fill_holes_fn(vertices, faces, max_perimeter=0.03): is_batched = vertices.ndim == 3 @@ -610,19 +620,6 @@ def fill_holes_fn(vertices, faces, max_perimeter=0.03): return v, f -def make_double_sided(vertices, faces): - is_batched = vertices.ndim == 3 - if is_batched: - f_list =[] - for i in range(faces.shape[0]): - f_inv = faces[i][:,[0, 2, 1]] - f_list.append(torch.cat([faces[i], f_inv], dim=0)) - return vertices, torch.stack(f_list) - - faces_inv = faces[:, [0, 2, 1]] - faces_double = torch.cat([faces, faces_inv], dim=0) - return vertices, faces_double - class PostProcessMesh(IO.ComfyNode): @classmethod def define_schema(cls): @@ -641,19 +638,23 @@ class PostProcessMesh(IO.ComfyNode): @classmethod def execute(cls, mesh, simplify, fill_holes_perimeter): + # TODO: batched mode may break mesh = copy.deepcopy(mesh) verts, faces = mesh.vertices, mesh.faces + colors = None + if hasattr(mesh, "colors"): + colors = mesh.colors if fill_holes_perimeter > 0: verts, faces = fill_holes_fn(verts, faces, max_perimeter=fill_holes_perimeter) if simplify > 0 and faces.shape[0] > simplify: - verts, faces = simplify_fn(verts, faces, target=simplify) - - verts, faces = make_double_sided(verts, faces) + verts, faces, colors = simplify_fn(verts, faces, target=simplify, colors=colors) mesh.vertices = verts mesh.faces = faces + if colors is not None: + mesh.colors = None return IO.NodeOutput(mesh) class Trellis2Extension(ComfyExtension): From 937faafe21c70036df67659bef321619adb99eaa Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 10 Apr 2026 16:33:19 +0200 Subject: [PATCH 56/59] corrected simplification logic 
--- comfy_extras/nodes_trellis2.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 1bf7c55b8..4aaa2c5f4 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -639,18 +639,19 @@ class PostProcessMesh(IO.ComfyNode): @classmethod def execute(cls, mesh, simplify, fill_holes_perimeter): # TODO: batched mode may break - mesh = copy.deepcopy(mesh) verts, faces = mesh.vertices, mesh.faces colors = None if hasattr(mesh, "colors"): colors = mesh.colors + actual_face_count = faces.shape[1] if faces.ndim == 3 else faces.shape[0] if fill_holes_perimeter > 0: verts, faces = fill_holes_fn(verts, faces, max_perimeter=fill_holes_perimeter) - if simplify > 0 and faces.shape[0] > simplify: + if simplify > 0 and actual_face_count > simplify: verts, faces, colors = simplify_fn(verts, faces, target=simplify, colors=colors) + mesh = type(mesh)(vertices=verts, faces=faces) mesh.vertices = verts mesh.faces = faces if colors is not None: From 243691c258803ebbf22c9f7fc38d7baa3e4239f2 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 10 Apr 2026 19:18:14 +0200 Subject: [PATCH 57/59] texture fixes --- comfy/supported_models.py | 2 ++ comfy_extras/nodes_trellis2.py | 17 ++++++++++++++++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 20a60194b..f445edf66 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1270,6 +1270,8 @@ class Trellis2(supported_models_base.BASE): latent_format = latent_formats.Trellis2 vae_key_prefix = ["vae."] clip_vision_prefix = "conditioner.main_image_encoder.model." 
+ # this is only needed for the texture model + supported_inference_dtypes = [torch.bfloat16, torch.float32] def get_model(self, state_dict, prefix="", device=None): return model_base.Trellis2(self, device=device) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 4aaa2c5f4..42fe7f707 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -620,6 +620,19 @@ def fill_holes_fn(vertices, faces, max_perimeter=0.03): return v, f + +def make_double_sided(vertices, faces): + is_batched = vertices.ndim == 3 + if is_batched: + f_list = [] + for i in range(faces.shape[0]): + f_inv = faces[i][:, [0, 2, 1]] + f_list.append(torch.cat([faces[i], f_inv], dim=0)) + return vertices, torch.stack(f_list) + + faces_inv = faces[:, [0, 2, 1]] + return vertices, torch.cat([faces, faces_inv], dim=0) + class PostProcessMesh(IO.ComfyNode): @classmethod def define_schema(cls): @@ -651,11 +664,13 @@ class PostProcessMesh(IO.ComfyNode): if simplify > 0 and actual_face_count > simplify: verts, faces, colors = simplify_fn(verts, faces, target=simplify, colors=colors) + verts, faces = make_double_sided(verts, faces) + mesh = type(mesh)(vertices=verts, faces=faces) mesh.vertices = verts mesh.faces = faces if colors is not None: - mesh.colors = None + mesh.colors = colors return IO.NodeOutput(mesh) class Trellis2Extension(ComfyExtension): From a1364a7b0076343eb9ae10e9c85c602470b2147b Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Sun, 12 Apr 2026 20:14:47 +0200 Subject: [PATCH 58/59] small final change --- comfy_extras/nodes_trellis2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 42fe7f707..3479d5410 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -72,7 +72,7 @@ def paint_mesh_with_voxels(mesh, voxel_coords, voxel_colors, resolution): v_colors = 
voxel_colors[nearest_idx] # to [0, 1] - srgb_colors = (v_colors * 0.5 + 0.5).clamp(0, 1) + srgb_colors = v_colors.clamp(0, 1)#(v_colors * 0.5 + 0.5).clamp(0, 1) # to Linear RGB (required for GLTF) linear_colors = torch.pow(srgb_colors, 2.2) From f4ae7b8391b6c977f17312fb0eb3446251ac0d35 Mon Sep 17 00:00:00 2001 From: "Yousef R. Gamaleldin" <81116377+yousef-rafat@users.noreply.github.com> Date: Sun, 12 Apr 2026 20:21:19 +0200 Subject: [PATCH 59/59] Fix return statement --- comfy/model_base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/model_base.py b/comfy/model_base.py index 157bdd929..5f258178f 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1589,6 +1589,7 @@ class WAN21_SCAIL(WAN21): pose_latents = kwargs.get("pose_video_latent", None) if pose_latents is not None: out['pose_latents'] = [pose_latents.shape[0], 20, *pose_latents.shape[2:]] + return out class Hunyuan3Dv2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None):