From 4e14d42da1c26f48af7836c3c2eff8aa8cc8d4f5 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 10 Apr 2026 16:12:23 +0200 Subject: [PATCH] comfy ops + color support in postprocess --- comfy/ldm/trellis2/model.py | 105 +++++++++++++++++++-------------- comfy/ldm/trellis2/vae.py | 14 +++-- comfy_extras/nodes_trellis2.py | 57 +++++++++--------- 3 files changed, 98 insertions(+), 78 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index ea7ada9f8..a613fb325 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -14,12 +14,12 @@ class SparseGELU(nn.GELU): return input.replace(super().forward(input.feats)) class SparseFeedForwardNet(nn.Module): - def __init__(self, channels: int, mlp_ratio: float = 4.0): + def __init__(self, channels: int, mlp_ratio: float = 4.0, device=None, dtype=None, operations=None): super().__init__() self.mlp = nn.Sequential( - SparseLinear(channels, int(channels * mlp_ratio)), + SparseLinear(channels, int(channels * mlp_ratio), device=device, dtype=dtype, operations=operations), SparseGELU(approximate="tanh"), - SparseLinear(int(channels * mlp_ratio), channels), + SparseLinear(int(channels * mlp_ratio), channels, device=device, dtype=dtype, operations=operations), ) def forward(self, x: VarLenTensor) -> VarLenTensor: @@ -37,10 +37,10 @@ class LayerNorm32(nn.LayerNorm): class SparseMultiHeadRMSNorm(nn.Module): - def __init__(self, dim: int, heads: int): + def __init__(self, dim: int, heads: int, device, dtype): super().__init__() self.scale = dim ** 0.5 - self.gamma = nn.Parameter(torch.ones(heads, dim)) + self.gamma = nn.Parameter(torch.ones(heads, dim, device=device, dtype=dtype)) def forward(self, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]: x_type = x.dtype @@ -56,14 +56,15 @@ class SparseRotaryPositionEmbedder(nn.Module): self, head_dim: int, dim: int = 3, - rope_freq: Tuple[float, float] = (1.0, 
10000.0) + rope_freq: Tuple[float, float] = (1.0, 10000.0), + device=None ): super().__init__() self.head_dim = head_dim self.dim = dim self.rope_freq = rope_freq self.freq_dim = head_dim // 2 // dim - self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim + self.freqs = torch.arange(self.freq_dim, dtype=torch.float32, device=device) / self.freq_dim self.freqs = rope_freq[0] / (rope_freq[1] ** (self.freqs)) def _get_freqs_cis(self, coords: torch.Tensor) -> torch.Tensor: @@ -148,6 +149,7 @@ class SparseMultiHeadAttention(nn.Module): use_rope: bool = False, rope_freq: Tuple[int, int] = (1.0, 10000.0), qk_rms_norm: bool = False, + device=None, dtype=None, operations=None ): super().__init__() @@ -163,19 +165,19 @@ class SparseMultiHeadAttention(nn.Module): self.qk_rms_norm = qk_rms_norm if self._type == "self": - self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias) + self.to_qkv = operations.Linear(channels, channels * 3, bias=qkv_bias, device=device, dtype=dtype) else: - self.to_q = nn.Linear(channels, channels, bias=qkv_bias) - self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias) + self.to_q = operations.Linear(channels, channels, bias=qkv_bias, device=device, dtype=dtype) + self.to_kv = operations.Linear(self.ctx_channels, channels * 2, bias=qkv_bias, device=device, dtype=dtype) if self.qk_rms_norm: - self.q_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads) - self.k_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads) + self.q_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype) + self.k_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype) - self.to_out = nn.Linear(channels, channels) + self.to_out = operations.Linear(channels, channels, device=device, dtype=dtype) if use_rope: - self.rope = SparseRotaryPositionEmbedder(self.head_dim, rope_freq=rope_freq) + self.rope = SparseRotaryPositionEmbedder(self.head_dim, 
rope_freq=rope_freq, device=device) @staticmethod def _linear(module: nn.Linear, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]: @@ -267,14 +269,14 @@ class ModulatedSparseTransformerCrossBlock(nn.Module): qk_rms_norm_cross: bool = False, qkv_bias: bool = True, share_mod: bool = False, - + device=None, dtype=None, operations=None ): super().__init__() self.use_checkpoint = use_checkpoint self.share_mod = share_mod - self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) - self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) - self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device) + self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6, device=device) + self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device) self.self_attn = SparseMultiHeadAttention( channels, num_heads=num_heads, @@ -286,6 +288,7 @@ class ModulatedSparseTransformerCrossBlock(nn.Module): use_rope=use_rope, rope_freq=rope_freq, qk_rms_norm=qk_rms_norm, + device=device, dtype=dtype, operations=operations ) self.cross_attn = SparseMultiHeadAttention( channels, @@ -295,18 +298,20 @@ class ModulatedSparseTransformerCrossBlock(nn.Module): attn_mode="full", qkv_bias=qkv_bias, qk_rms_norm=qk_rms_norm_cross, + device=device, dtype=dtype, operations=operations ) self.mlp = SparseFeedForwardNet( channels, mlp_ratio=mlp_ratio, + device=device, dtype=dtype, operations=operations ) if not share_mod: self.adaLN_modulation = nn.Sequential( nn.SiLU(), - nn.Linear(channels, 6 * channels, bias=True) + operations.Linear(channels, 6 * channels, bias=True, device=device, dtype=dtype) ) else: - self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5) + self.modulation = nn.Parameter(torch.randn(6 * channels, device=device, dtype=dtype) / channels ** 0.5) def _forward(self, x: SparseTensor, mod: 
torch.Tensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor: if self.share_mod: @@ -376,10 +381,10 @@ class SLatFlowModel(nn.Module): if share_mod: self.adaLN_modulation = nn.Sequential( nn.SiLU(), - nn.Linear(model_channels, 6 * model_channels, bias=True) + operations.Linear(model_channels, 6 * model_channels, bias=True, device=device, dtype=dtype) ) - self.input_layer = SparseLinear(in_channels, model_channels) + self.input_layer = SparseLinear(in_channels, model_channels, device=device, dtype=dtype, operations=operations) self.blocks = nn.ModuleList([ ModulatedSparseTransformerCrossBlock( @@ -394,11 +399,12 @@ class SLatFlowModel(nn.Module): share_mod=self.share_mod, qk_rms_norm=self.qk_rms_norm, qk_rms_norm_cross=self.qk_rms_norm_cross, + device=device, dtype=dtype, operations=operations ) for _ in range(num_blocks) ]) - self.out_layer = SparseLinear(model_channels, out_channels) + self.out_layer = SparseLinear(model_channels, out_channels, device=device, dtype=dtype, operations=operations) @property def device(self) -> torch.device: @@ -438,22 +444,22 @@ class SLatFlowModel(nn.Module): return h class FeedForwardNet(nn.Module): - def __init__(self, channels: int, mlp_ratio: float = 4.0): + def __init__(self, channels: int, mlp_ratio: float = 4.0, device=None, dtype=None, operations=None): super().__init__() self.mlp = nn.Sequential( - nn.Linear(channels, int(channels * mlp_ratio)), + operations.Linear(channels, int(channels * mlp_ratio), device=device, dtype=dtype), nn.GELU(approximate="tanh"), - nn.Linear(int(channels * mlp_ratio), channels), + operations.Linear(int(channels * mlp_ratio), channels, device=device, dtype=dtype), ) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.mlp(x) class MultiHeadRMSNorm(nn.Module): - def __init__(self, dim: int, heads: int): + def __init__(self, dim: int, heads: int, device=None, dtype=None): super().__init__() self.scale = dim ** 0.5 - self.gamma = nn.Parameter(torch.ones(heads, dim)) + 
self.gamma = nn.Parameter(torch.ones(heads, dim, device=device, dtype=dtype)) def forward(self, x: torch.Tensor) -> torch.Tensor: return (F.normalize(x.float(), dim = -1) * self.gamma * self.scale).to(x.dtype) @@ -473,6 +479,7 @@ class MultiHeadAttention(nn.Module): use_rope: bool = False, rope_freq: Tuple[float, float] = (1.0, 10000.0), qk_rms_norm: bool = False, + device=None, dtype=None, operations=None ): super().__init__() @@ -488,16 +495,16 @@ class MultiHeadAttention(nn.Module): self.qk_rms_norm = qk_rms_norm if self._type == "self": - self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias) + self.to_qkv = operations.Linear(channels, channels * 3, bias=qkv_bias, dtype=dtype, device=device) else: - self.to_q = nn.Linear(channels, channels, bias=qkv_bias) - self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias) + self.to_q = operations.Linear(channels, channels, bias=qkv_bias, device=device, dtype=dtype) + self.to_kv = operations.Linear(self.ctx_channels, channels * 2, bias=qkv_bias, device=device, dtype=dtype) if self.qk_rms_norm: - self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads) - self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads) + self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype) + self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype) - self.to_out = nn.Linear(channels, channels) + self.to_out = operations.Linear(channels, channels, device=device, dtype=dtype) def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, phases: Optional[torch.Tensor] = None) -> torch.Tensor: B, L, C = x.shape @@ -554,13 +561,14 @@ class ModulatedTransformerCrossBlock(nn.Module): qk_rms_norm_cross: bool = False, qkv_bias: bool = True, share_mod: bool = False, + device=None, dtype=None, operations=None ): super().__init__() self.use_checkpoint = use_checkpoint self.share_mod = share_mod - self.norm1 = LayerNorm32(channels, 
elementwise_affine=False, eps=1e-6) - self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) - self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device) + self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6, device=device) + self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device) self.self_attn = MultiHeadAttention( channels, num_heads=num_heads, @@ -572,6 +580,7 @@ class ModulatedTransformerCrossBlock(nn.Module): use_rope=use_rope, rope_freq=rope_freq, qk_rms_norm=qk_rms_norm, + device=device, dtype=dtype, operations=operations ) self.cross_attn = MultiHeadAttention( channels, @@ -581,18 +590,20 @@ class ModulatedTransformerCrossBlock(nn.Module): attn_mode="full", qkv_bias=qkv_bias, qk_rms_norm=qk_rms_norm_cross, + device=device, dtype=dtype, operations=operations ) self.mlp = FeedForwardNet( channels, mlp_ratio=mlp_ratio, + device=device, dtype=dtype, operations=operations ) if not share_mod: self.adaLN_modulation = nn.Sequential( nn.SiLU(), - nn.Linear(channels, 6 * channels, bias=True) + operations.Linear(channels, 6 * channels, bias=True, dtype=dtype, device=device) ) else: - self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5) + self.modulation = nn.Parameter(torch.randn(6 * channels, device=device, dtype=dtype) / channels ** 0.5) def _forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor: if self.share_mod: @@ -659,16 +670,17 @@ class SparseStructureFlowModel(nn.Module): self.qk_rms_norm = qk_rms_norm self.qk_rms_norm_cross = qk_rms_norm_cross self.dtype = dtype + self.device = device - self.t_embedder = TimestepEmbedder(model_channels, operations=operations) + self.t_embedder = TimestepEmbedder(model_channels, dtype=dtype, device=device, operations=operations) if share_mod: self.adaLN_modulation = 
nn.Sequential( nn.SiLU(), - nn.Linear(model_channels, 6 * model_channels, bias=True) + operations.Linear(model_channels, 6 * model_channels, bias=True, device=device, dtype=dtype) ) - pos_embedder = RotaryPositionEmbedder(self.model_channels // self.num_heads, 3) - coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution] * 3], indexing='ij') + pos_embedder = RotaryPositionEmbedder(self.model_channels // self.num_heads, 3, device=device) + coords = torch.meshgrid(*[torch.arange(res, device=self.device, dtype=dtype) for res in [resolution] * 3], indexing='ij') coords = torch.stack(coords, dim=-1).reshape(-1, 3) rope_phases = pos_embedder(coords) self.register_buffer("rope_phases", rope_phases, persistent=False) @@ -676,7 +688,7 @@ class SparseStructureFlowModel(nn.Module): if pe_mode != "rope": self.rope_phases = None - self.input_layer = nn.Linear(in_channels, model_channels) + self.input_layer = operations.Linear(in_channels, model_channels, device=device, dtype=dtype) self.blocks = nn.ModuleList([ ModulatedTransformerCrossBlock( @@ -691,11 +703,12 @@ class SparseStructureFlowModel(nn.Module): share_mod=share_mod, qk_rms_norm=self.qk_rms_norm, qk_rms_norm_cross=self.qk_rms_norm_cross, + device=device, dtype=dtype, operations=operations ) for _ in range(num_blocks) ]) - self.out_layer = nn.Linear(model_channels, out_channels) + self.out_layer = operations.Linear(model_channels, out_channels, device=device, dtype=dtype) def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor: x = x.view(x.shape[0], self.in_channels, *[self.resolution] * 3) @@ -745,6 +758,7 @@ class Trellis2(nn.Module): super().__init__() self.dtype = dtype + operations = operations or nn # for some reason it passes num_heads = -1 if num_heads == -1: num_heads = 12 @@ -772,6 +786,7 @@ class Trellis2(nn.Module): coords = transformer_options.get("coords", None) mode = transformer_options.get("generation_mode", "structure_generation") 
is_512_run = False + timestep = timestep.to(self.dtype) if mode == "shape_generation_512": is_512_run = True mode = "shape_generation" diff --git a/comfy/ldm/trellis2/vae.py b/comfy/ldm/trellis2/vae.py index cd37ccd30..30f902868 100644 --- a/comfy/ldm/trellis2/vae.py +++ b/comfy/ldm/trellis2/vae.py @@ -962,13 +962,17 @@ def sparse_unbind(input: SparseTensor, dim: int) -> List[SparseTensor]: feats = input.feats.unbind(dim) return [input.replace(f) for f in feats] -class SparseLinear(nn.Linear): - def __init__(self, in_features, out_features, bias=True): - super(SparseLinear, self).__init__(in_features, out_features, bias) +# allow operations.Linear inheritance +class SparseLinear: + def __new__(cls, in_features, out_features, bias=True, device=None, dtype=None, operations=nn, *args, **kwargs): + class _SparseLinear(operations.Linear): + def __init__(self, in_features, out_features, bias=True, device=None, dtype=None): + super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype) - def forward(self, input: VarLenTensor) -> VarLenTensor: - return input.replace(super().forward(input.feats)) + def forward(self, input: VarLenTensor) -> VarLenTensor: + return input.replace(super().forward(input.feats)) + return _SparseLinear(in_features, out_features, bias=bias, device=device, dtype=dtype, *args, **kwargs) MIX_PRECISION_MODULES = ( nn.Conv1d, diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 56ce4f5ea..1bf7c55b8 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -481,21 +481,25 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode): latent = torch.randn(batch_size, in_channels, resolution, resolution, resolution) return IO.NodeOutput({"samples": latent, "type": "trellis2"}) -def simplify_fn(vertices, faces, target=100000): - is_batched = vertices.ndim == 3 - if is_batched: - v_list, f_list = [], [] +def simplify_fn(vertices, faces, colors=None, target=100000): + if vertices.ndim == 3: + 
v_list, f_list, c_list = [], [], [] for i in range(vertices.shape[0]): - v_i, f_i = simplify_fn(vertices[i], faces[i], target) + c_in = colors[i] if colors is not None else None + v_i, f_i, c_i = simplify_fn(vertices[i], faces[i], c_in, target) v_list.append(v_i) f_list.append(f_i) - return torch.stack(v_list), torch.stack(f_list) + if c_i is not None: + c_list.append(c_i) + + c_out = torch.stack(c_list) if len(c_list) > 0 else None + return torch.stack(v_list), torch.stack(f_list), c_out if faces.shape[0] <= target: - return vertices, faces + return vertices, faces, colors device = vertices.device - target_v = target / 2.0 + target_v = max(target / 4.0, 1.0) min_v = vertices.min(dim=0)[0] max_v = vertices.max(dim=0)[0] @@ -510,14 +514,17 @@ def simplify_fn(vertices, faces, target=100000): new_vertices = torch.zeros((num_cells, 3), dtype=vertices.dtype, device=device) counts = torch.zeros((num_cells, 1), dtype=vertices.dtype, device=device) - new_vertices.scatter_add_(0, inverse_indices.unsqueeze(1).expand(-1, 3), vertices) counts.scatter_add_(0, inverse_indices.unsqueeze(1), torch.ones_like(vertices[:, :1])) - new_vertices = new_vertices / counts.clamp(min=1) - new_faces = inverse_indices[faces] + new_colors = None + if colors is not None: + new_colors = torch.zeros((num_cells, colors.shape[1]), dtype=colors.dtype, device=device) + new_colors.scatter_add_(0, inverse_indices.unsqueeze(1).expand(-1, colors.shape[1]), colors) + new_colors = new_colors / counts.clamp(min=1) + new_faces = inverse_indices[faces] valid_mask = (new_faces[:, 0] != new_faces[:, 1]) & \ (new_faces[:, 1] != new_faces[:, 2]) & \ (new_faces[:, 2] != new_faces[:, 0]) @@ -527,7 +534,10 @@ def simplify_fn(vertices, faces, target=100000): final_vertices = new_vertices[unique_face_indices] final_faces = inv_face.reshape(-1, 3) - return final_vertices, final_faces + # assign colors + final_colors = new_colors[unique_face_indices] if new_colors is not None else None + + return final_vertices, 
final_faces, final_colors def fill_holes_fn(vertices, faces, max_perimeter=0.03): is_batched = vertices.ndim == 3 @@ -610,19 +620,6 @@ def fill_holes_fn(vertices, faces, max_perimeter=0.03): return v, f -def make_double_sided(vertices, faces): - is_batched = vertices.ndim == 3 - if is_batched: - f_list =[] - for i in range(faces.shape[0]): - f_inv = faces[i][:,[0, 2, 1]] - f_list.append(torch.cat([faces[i], f_inv], dim=0)) - return vertices, torch.stack(f_list) - - faces_inv = faces[:, [0, 2, 1]] - faces_double = torch.cat([faces, faces_inv], dim=0) - return vertices, faces_double - class PostProcessMesh(IO.ComfyNode): @classmethod def define_schema(cls): @@ -641,19 +638,23 @@ class PostProcessMesh(IO.ComfyNode): @classmethod def execute(cls, mesh, simplify, fill_holes_perimeter): + # TODO: batched mode may break mesh = copy.deepcopy(mesh) verts, faces = mesh.vertices, mesh.faces + colors = None + if hasattr(mesh, "colors"): + colors = mesh.colors if fill_holes_perimeter > 0: verts, faces = fill_holes_fn(verts, faces, max_perimeter=fill_holes_perimeter) if simplify > 0 and faces.shape[0] > simplify: - verts, faces = simplify_fn(verts, faces, target=simplify) - - verts, faces = make_double_sided(verts, faces) + verts, faces, colors = simplify_fn(verts, faces, target=simplify, colors=colors) mesh.vertices = verts mesh.faces = faces + if colors is not None: + mesh.colors = colors return IO.NodeOutput(mesh) class Trellis2Extension(ComfyExtension):