From d6573fd26d63e6fe00515e903d4c2535b02fbeea Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Tue, 3 Feb 2026 21:10:20 +0200 Subject: [PATCH] model init working --- comfy/latent_formats.py | 2 + comfy/ldm/trellis2/cumesh.py | 387 ++++++++++++++++++++++++++++++++++- comfy/ldm/trellis2/model.py | 20 +- comfy/ldm/trellis2/vae.py | 221 +++++++++++++++++++- comfy/model_base.py | 8 + comfy/supported_models.py | 13 +- 6 files changed, 639 insertions(+), 12 deletions(-) diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 4b3a3798c..fc4c4e6d3 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -746,6 +746,8 @@ class Hunyuan3Dv2_1(LatentFormat): latent_channels = 64 latent_dimensions = 1 +class Trellis2(LatentFormat): # TODO + latent_channels = 32 class Hunyuan3Dv2mini(LatentFormat): latent_channels = 64 latent_dimensions = 1 diff --git a/comfy/ldm/trellis2/cumesh.py b/comfy/ldm/trellis2/cumesh.py index be8200341..41ac35db9 100644 --- a/comfy/ldm/trellis2/cumesh.py +++ b/comfy/ldm/trellis2/cumesh.py @@ -1,8 +1,192 @@ # will contain every cuda -> pytorch operation +import math import torch -from typing import Dict +from typing import Dict, Callable +NO_TRITION = False +try: + import triton + import triton.language as tl + heuristics = { + 'valid_kernel': lambda args: args['valid_kernel'](args['B1']), + 'valid_kernel_seg': lambda args: args['valid_kernel_seg'](args['B1']), + } + + #@triton_autotune( + # configs=config.autotune_config, + # key=['LOGN', 'Ci', 'Co', 'V', 'allow_tf32'], + #) + @triton.heuristics(heuristics) + @triton.jit + def sparse_submanifold_conv_fwd_masked_implicit_gemm_kernel( + input, + weight, + bias, + neighbor, + sorted_idx, + output, + # Tensor dimensions + N, LOGN, Ci, Co, V: tl.constexpr, + # Meta-parameters + B1: tl.constexpr, # Block size for N dimension + B2: tl.constexpr, # Block size for Co dimension + BK: tl.constexpr, # Block size for K dimension (V * Ci) + allow_tf32: tl.constexpr, # Allow TF32 precision for matmuls + # Huristic parameters + valid_kernel, + valid_kernel_seg, + ): + + block_id = tl.program_id(axis=0) + block_dim_co = tl.cdiv(Co, B2) + block_id_co = block_id % block_dim_co + block_id_n = block_id // block_dim_co + + # Create pointers for submatrices of A and B. + num_k = tl.cdiv(Ci, BK) # Number of blocks in K dimension + valid_kernel_start = tl.load(valid_kernel_seg + block_id_n) + valid_kernel_seglen = tl.load(valid_kernel_seg + block_id_n + 1) - valid_kernel_start + offset_n = block_id_n * B1 + tl.arange(0, B1) + n_mask = offset_n < N + offset_sorted_n = tl.load(sorted_idx + offset_n, mask=n_mask, other=0) # (B1,) + offset_co = (block_id_co * B2 + tl.arange(0, B2)) % Co # (B2,) + offset_k = tl.arange(0, BK) # (BK,) + + # Create a block of the output matrix C. + accumulator = tl.zeros((B1, B2), dtype=tl.float32) + + # Iterate along V*Ci dimension. + for k in range(num_k * valid_kernel_seglen): + v = k // num_k + bk = k % num_k + v = tl.load(valid_kernel + valid_kernel_start + v) + # Calculate pointers to input matrix. + neighbor_offset_n = tl.load(neighbor + offset_sorted_n * V + v) # (B1,) + input_ptr = input + bk * BK + (neighbor_offset_n[:, None].to(tl.int64) * Ci + offset_k[None, :]) # (B1, BK) + # Calculate pointers to weight matrix. + weight_ptr = weight + v * Ci + bk * BK + (offset_co[None, :] * V * Ci + offset_k[:, None]) # (BK, B2) + # Load the next block of input and weight. + neigh_mask = neighbor_offset_n != 0xffffffff + k_mask = offset_k < Ci - bk * BK + input_block = tl.load(input_ptr, mask=neigh_mask[:, None] & k_mask[None, :], other=0.0) + weight_block = tl.load(weight_ptr, mask=k_mask[:, None], other=0.0) + # Accumulate along the K dimension. + accumulator = tl.dot(input_block, weight_block, accumulator, + input_precision='tf32' if allow_tf32 else 'ieee') # (B1, B2) + c = accumulator.to(input.type.element_ty) + + # add bias + if bias is not None: + bias_block = tl.load(bias + offset_co) + c += bias_block[None, :] + + # Write back the block of the output matrix with masks. + out_offset_n = offset_sorted_n + out_offset_co = block_id_co * B2 + tl.arange(0, B2) + out_ptr = output + (out_offset_n[:, None] * Co + out_offset_co[None, :]) + out_mask = n_mask[:, None] & (out_offset_co[None, :] < Co) + tl.store(out_ptr, c, mask=out_mask) + def sparse_submanifold_conv_fwd_masked_implicit_gemm_splitk( + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + neighbor: torch.Tensor, + sorted_idx: torch.Tensor, + valid_kernel: Callable[[int], torch.Tensor], + valid_kernel_seg: Callable[[int], torch.Tensor], + ) -> torch.Tensor: + N, Ci, Co, V = neighbor.shape[0], input.shape[1], weight.shape[0], weight.shape[1] + LOGN = int(math.log2(N)) + output = torch.empty((N, Co), device=input.device, dtype=input.dtype) + grid = lambda META: (triton.cdiv(Co, META['B2']) * triton.cdiv(N, META['B1']),) + sparse_submanifold_conv_fwd_masked_implicit_gemm_kernel[grid]( + input, weight, bias, neighbor, sorted_idx, output, + N, LOGN, Ci, Co, V, # + valid_kernel=valid_kernel, + valid_kernel_seg=valid_kernel_seg, + allow_tf32=torch.cuda.is_tf32_supported(), + ) + return output +except: + NO_TRITION = True + +def compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device): + # offsets in same order as CUDA kernel + offsets = [] + for vx in range(Kw): + for vy in range(Kh): + for vz in range(Kd): + offsets.append(( + vx * Dw, + vy * Dh, + vz * Dd + )) + return torch.tensor(offsets, device=device) + +def build_submanifold_neighbor_map( + hashmap, + coords: torch.Tensor, + W, H, D, + Kw, Kh, Kd, + Dw, Dh, Dd, +): + device = coords.device + M = coords.shape[0] + V = Kw * Kh * Kd + half_V = V // 2 + 1 + + INVALID = hashmap.default_value + + neighbor = torch.full((M, V), INVALID, device=device, dtype=torch.long) + + b = coords[:, 0] + x = coords[:, 1] + y = coords[:, 2] + z = coords[:, 3] + + offsets = compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device) + + ox = x[:, None] - (Kw // 2) * Dw + oy = y[:, None] - (Kh // 2) * Dh + oz = z[:, None] - (Kd // 2) * Dd + + for v in range(half_V): + if v == half_V - 1: + neighbor[:, v] = torch.arange(M, device=device) + continue + + dx, dy, dz = offsets[v] + + kx = ox[:, v] + dx + ky = oy[:, v] + dy + kz = oz[:, v] + dz + + valid = ( + (kx >= 0) & (kx < W) & + (ky >= 0) & (ky < H) & + (kz >= 0) & (kz < D) + ) + + flat = ( + b * (W * H * D) + + kx * (H * D) + + ky * D + + kz + ) + + flat = flat[valid] + idx = torch.nonzero(valid, as_tuple=False).squeeze(1) + + found = hashmap.lookup_flat(flat) + + neighbor[idx, v] = found + + # symmetric write + valid_found = found != INVALID + neighbor[found[valid_found], V - 1 - v] = idx[valid_found] + + return neighbor class TorchHashMap: def __init__(self, keys: torch.Tensor, values: torch.Tensor, default_value: int): @@ -22,6 +206,207 @@ class TorchHashMap: out[found] = self.sorted_vals[idx[found]] return out + +UINT32_SENTINEL = 0xFFFFFFFF + +def neighbor_map_post_process_for_masked_implicit_gemm_1(neighbor_map): + device = neighbor_map.device + N, V = neighbor_map.shape + + + neigh = neighbor_map.to(torch.long) + sentinel = torch.tensor(UINT32_SENTINEL, dtype=torch.long, device=device) + + + neigh_map_T = neigh.t().reshape(-1) + + neigh_mask_T = (neigh_map_T != sentinel).to(torch.int32) + + mask = (neigh != sentinel).to(torch.long) + + powers = (1 << torch.arange(V, dtype=torch.long, device=device)) + + gray_long = (mask * powers).sum(dim=1) + + gray_code = gray_long.to(torch.int32) + + binary_long = gray_long.clone() + for v in range(1, V): + binary_long ^= (gray_long >> v) + binary_code = binary_long.to(torch.int32) + + sorted_idx = torch.argsort(binary_code) + + prefix_sum_neighbor_mask = torch.cumsum(neigh_mask_T.to(torch.int32), dim=0) # (V*N,) + + total_valid_signal = int(prefix_sum_neighbor_mask[-1].item()) if prefix_sum_neighbor_mask.numel() > 0 else 0 + + if total_valid_signal > 0: + valid_signal_i = torch.empty((total_valid_signal,), dtype=torch.long, device=device) + valid_signal_o = torch.empty((total_valid_signal,), dtype=torch.long, device=device) + + pos = torch.nonzero(neigh_mask_T, as_tuple=True)[0] + + to = (prefix_sum_neighbor_mask[pos] - 1).to(torch.long) + + valid_signal_i[to] = (pos % N).to(torch.long) + + valid_signal_o[to] = neigh_map_T[pos].to(torch.long) + else: + valid_signal_i = torch.empty((0,), dtype=torch.long, device=device) + valid_signal_o = torch.empty((0,), dtype=torch.long, device=device) + + seg = torch.empty((V + 1,), dtype=torch.long, device=device) + seg[0] = 0 + if V > 0: + idxs = (torch.arange(1, V + 1, device=device, dtype=torch.long) * N) - 1 + seg[1:] = prefix_sum_neighbor_mask[idxs].to(torch.long) + else: + pass + + return gray_code, sorted_idx, valid_signal_i, valid_signal_o, seg + +def _popcount_int32_tensor(x: torch.Tensor) -> torch.Tensor: + + x = x.to(torch.int64) + + m1 = torch.tensor(0x5555555555555555, dtype=torch.int64, device=x.device) + m2 = torch.tensor(0x3333333333333333, dtype=torch.int64, device=x.device) + m4 = torch.tensor(0x0F0F0F0F0F0F0F0F, dtype=torch.int64, device=x.device) + h01 = torch.tensor(0x0101010101010101, dtype=torch.int64, device=x.device) + + x = x - ((x >> 1) & m1) + x = (x & m2) + ((x >> 2) & m2) + x = (x + (x >> 4)) & m4 + x = (x * h01) >> 56 + return x.to(torch.int32) + + +def neighbor_map_post_process_for_masked_implicit_gemm_2( + gray_code: torch.Tensor, # [N], int32-like (non-negative) + sorted_idx: torch.Tensor, # [N], long (indexing into gray_code) + block_size: int +): + device = gray_code.device + N = gray_code.numel() + + # num of blocks (same as CUDA) + num_blocks = (N + block_size - 1) // block_size + + # Ensure dtypes + gray_long = gray_code.to(torch.int64) # safer to OR in 64-bit then cast + sorted_idx = sorted_idx.to(torch.long) + + # 1) Group gray_code by blocks and compute OR across each block + # pad the last block with zeros if necessary so we can reshape + pad = num_blocks * block_size - N + if pad > 0: + pad_vals = torch.zeros((pad,), dtype=torch.int64, device=device) + gray_padded = torch.cat([gray_long[sorted_idx], pad_vals], dim=0) + else: + gray_padded = gray_long[sorted_idx] + + # reshape to (num_blocks, block_size) and compute bitwise_or across dim=1 + gray_blocks = gray_padded.view(num_blocks, block_size) # each row = block entries + # reduce with bitwise_or + reduced_code = gray_blocks[:, 0].clone() + for i in range(1, block_size): + reduced_code |= gray_blocks[:, i] + reduced_code = reduced_code.to(torch.int32) # match CUDA int32 + + # 2) compute seglen (popcount per reduced_code) and seg (prefix sum) + seglen_counts = _popcount_int32_tensor(reduced_code.to(torch.int64)).to(torch.int32) # [num_blocks] + # seg: length num_blocks+1, seg[0]=0, seg[i+1]=cumsum(seglen_counts) up to i + seg = torch.empty((num_blocks + 1,), dtype=torch.int32, device=device) + seg[0] = 0 + if num_blocks > 0: + seg[1:] = torch.cumsum(seglen_counts, dim=0) + + total = int(seg[-1].item()) + + # 3) scatter — produce valid_kernel_idx as concatenated ascending set-bit positions for each reduced_code row + if total == 0: + valid_kernel_idx = torch.empty((0,), dtype=torch.int32, device=device) + return valid_kernel_idx, seg + + max_val = int(reduced_code.max().item()) + V = max_val.bit_length() if max_val > 0 else 0 + # If you know V externally, pass it instead or set here explicitly. + + if V == 0: + # no bits set anywhere + valid_kernel_idx = torch.empty((0,), dtype=torch.int32, device=device) + return valid_kernel_idx, seg + + # build mask of shape (num_blocks, V): True where bit is set + bit_pos = torch.arange(0, V, dtype=torch.int64, device=device) # [V] + # shifted = reduced_code[:, None] >> bit_pos[None, :] + shifted = reduced_code.to(torch.int64).unsqueeze(1) >> bit_pos.unsqueeze(0) + bits = (shifted & 1).to(torch.bool) # (num_blocks, V) + + positions = bit_pos.unsqueeze(0).expand(num_blocks, V) + + valid_positions = positions[bits] + valid_kernel_idx = valid_positions.to(torch.int32).contiguous() + + return valid_kernel_idx, seg + + +def sparse_submanifold_conv3d(feats, coords, shape, weight, bias, neighbor_cache, dilation): + if len(shape) == 5: + N, C, W, H, D = shape + else: + W, H, D = shape + + Co, Kw, Kh, Kd, Ci = weight.shape + + b_stride = W * H * D + x_stride = H * D + y_stride = D + z_stride = 1 + + flat_keys = (coords[:, 0].long() * b_stride + + coords[:, 1].long() * x_stride + + coords[:, 2].long() * y_stride + + coords[:, 3].long() * z_stride) + + vals = torch.arange(coords.shape[0], dtype=torch.int32, device=coords.device) + + hashmap = TorchHashMap(flat_keys, vals, 0xFFFFFFFF) + + if neighbor_cache is None: + neighbor = build_submanifold_neighbor_map( + hashmap, coords, W, H, D, Kw, Kh, Kd, + dilation[0], dilation[1], dilation[2] + ) + else: + neighbor = neighbor_cache + + block_size = 128 + + gray_code, sorted_idx, valid_signal_i, valid_signal_o, valid_signal_seg = \ + neighbor_map_post_process_for_masked_implicit_gemm_1(neighbor) + + valid_kernel, valid_kernel_seg = \ + neighbor_map_post_process_for_masked_implicit_gemm_2(gray_code, sorted_idx, block_size) + + valid_kernel_fn = lambda b_size: valid_kernel + valid_kernel_seg_fn = lambda b_size: valid_kernel_seg + + weight_flat = weight.contiguous().view(Co, -1, Ci) + + out = sparse_submanifold_conv_fwd_masked_implicit_gemm_splitk( + feats, + weight_flat, + bias, + neighbor, + sorted_idx, + valid_kernel_fn, + valid_kernel_seg_fn + ) + + return out, neighbor + class Voxel: def __init__( self, diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index cdbfbf6fc..9d1a8fdb4 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -408,7 +408,7 @@ class SLatFlowModel(nn.Module): self.qk_rms_norm_cross = qk_rms_norm_cross self.dtype = dtype - self.t_embedder = TimestepEmbedder(model_channels) + self.t_embedder = TimestepEmbedder(model_channels, device=device, dtype=dtype, operations=operations) if share_mod: self.adaLN_modulation = nn.Sequential( nn.SiLU(), @@ -485,15 +485,25 @@ class Trellis2(nn.Module): qk_rms_norm = True, qk_rms_norm_cross = True, dtype=None, device=None, operations=None): + + super().__init__() args = { "out_channels":out_channels, "num_blocks":num_blocks, "cond_channels" :cond_channels, "model_channels":model_channels, "num_heads":num_heads, "mlp_ratio": mlp_ratio, "share_mod": share_mod, "qk_rms_norm": qk_rms_norm, "qk_rms_norm_cross": qk_rms_norm_cross, "device": device, "dtype": dtype, "operations": operations } # TODO: update the names/checkpoints - self.img2shape = SLatFlowModel(resolution, in_channels=in_channels, *args) - self.shape2txt = SLatFlowModel(resolution, in_channels=in_channels*2, *args) + self.img2shape = SLatFlowModel(resolution=resolution, in_channels=in_channels, **args) + self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args) self.shape_generation = True - def forward(self, x, timestep, context): - pass + def forward(self, x, timestep, context, **kwargs): + # TODO add mode + mode = kwargs.get("mode", "shape_generation") + mode = "texture_generation" if mode == 1 else "shape_generation" + if mode == "shape_generation": + out = self.img2shape(x, timestep, context) + if mode == "texture_generation": + out = self.shape2txt(x, timestep, context) + + return out diff --git a/comfy/ldm/trellis2/vae.py b/comfy/ldm/trellis2/vae.py index 1d564bca2..5dabf5246 100644 --- a/comfy/ldm/trellis2/vae.py +++ b/comfy/ldm/trellis2/vae.py @@ -1,3 +1,4 @@ +import math import torch import torch.nn as nn from typing import List, Any, Dict, Optional, overload, Union, Tuple @@ -5,12 +6,219 @@ from fractions import Fraction import torch.nn.functional as F from dataclasses import dataclass import numpy as np -from cumesh import TorchHashMap, Mesh, MeshWithVoxel +from cumesh import TorchHashMap, Mesh, MeshWithVoxel, sparse_submanifold_conv3d + + +class SparseConv3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None): + super(SparseConv3d, self).__init__() + sparse_conv3d_init(self, in_channels, out_channels, kernel_size, stride, dilation, padding, bias, indice_key) + + def forward(self, x): + return sparse_conv3d_forward(self, x) + + +def sparse_conv3d_init(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None): + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = tuple(kernel_size) if isinstance(kernel_size, (list, tuple)) else (kernel_size, ) * 3 + self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, ) * 3 + self.dilation = tuple(dilation) if isinstance(dilation, (list, tuple)) else (dilation, ) * 3 + + self.weight = nn.Parameter(torch.empty((out_channels, in_channels, *self.kernel_size))) + if bias: + self.bias = nn.Parameter(torch.empty(out_channels)) + else: + self.register_parameter("bias", None) + + if self.bias is not None: + fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight) + if fan_in != 0: + bound = 1 / math.sqrt(fan_in) + torch.nn.init.uniform_(self.bias, -bound, bound) + + # Permute weight (Co, Ci, Kd, Kh, Kw) -> (Co, Kd, Kh, Kw, Ci) + self.weight = nn.Parameter(self.weight.permute(0, 2, 3, 4, 1).contiguous()) + + +def sparse_conv3d_forward(self, x): + # check if neighbor map is already computed + Co, Kd, Kh, Kw, Ci = self.weight.shape + neighbor_cache_key = f'SubMConv3d_neighbor_cache_{Kw}x{Kh}x{Kd}_dilation{self.dilation}' + neighbor_cache = x.get_spatial_cache(neighbor_cache_key) + + out, neighbor_cache_ = sparse_submanifold_conv3d( + x.feats, + x.coords, + torch.Size([*x.shape, *x.spatial_shape]), + self.weight, + self.bias, + neighbor_cache, + self.dilation + ) + + if neighbor_cache is None: + x.register_spatial_cache(neighbor_cache_key, neighbor_cache_) + + out = x.replace(out) + return out + +class LayerNorm32(nn.LayerNorm): + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_dtype = x.dtype + x = x.to(torch.float32) + o = super().forward(x) + return o.to(x_dtype) + +class SparseConvNeXtBlock3d(nn.Module): + def __init__( + self, + channels: int, + mlp_ratio: float = 4.0, + use_checkpoint: bool = False, + ): + super().__init__() + self.channels = channels + self.use_checkpoint = use_checkpoint + + self.norm = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.conv = SparseConv3d(channels, channels, 3) + self.mlp = nn.Sequential( + nn.Linear(channels, int(channels * mlp_ratio)), + nn.SiLU(), + nn.Linear(int(channels * mlp_ratio), channels), + ) + + def _forward(self, x): + h = self.conv(x) + h = h.replace(self.norm(h.feats)) + h = h.replace(self.mlp(h.feats)) + return h + x + + def forward(self, x): + return self._forward(x) + +class SparseSpatial2Channel(nn.Module): + def __init__(self, factor: int = 2): + super(SparseSpatial2Channel, self).__init__() + self.factor = factor + + def forward(self, x): + DIM = x.coords.shape[-1] - 1 + cache = x.get_spatial_cache(f'spatial2channel_{self.factor}') + if cache is None: + coord = list(x.coords.unbind(dim=-1)) + for i in range(DIM): + coord[i+1] = coord[i+1] // self.factor + subidx = x.coords[:, 1:] % self.factor + subidx = sum([subidx[..., i] * self.factor ** i for i in range(DIM)]) + + MAX = [(s + self.factor - 1) // self.factor for s in x.spatial_shape] + OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1] + code = sum([c * o for c, o in zip(coord, OFFSET)]) + code, idx = code.unique(return_inverse=True) + + new_coords = torch.stack( + [code // OFFSET[0]] + + [(code // OFFSET[i+1]) % MAX[i] for i in range(DIM)], + dim=-1 + ) + else: + new_coords, idx, subidx = cache + + new_feats = torch.zeros(new_coords.shape[0] * self.factor ** DIM, x.feats.shape[1], device=x.feats.device, dtype=x.feats.dtype) + new_feats[idx * self.factor ** DIM + subidx] = x.feats + + out = SparseTensor(new_feats.reshape(new_coords.shape[0], -1), new_coords, None if x._shape is None else torch.Size([x._shape[0], x._shape[1] * self.factor ** DIM])) + out._scale = tuple([s * self.factor for s in x._scale]) + out._spatial_cache = x._spatial_cache + + if cache is None: + x.register_spatial_cache(f'spatial2channel_{self.factor}', (new_coords, idx, subidx)) + out.register_spatial_cache(f'channel2spatial_{self.factor}', (x.coords, idx, subidx)) + out.register_spatial_cache('shape', torch.Size(MAX)) + + return out + +class SparseChannel2Spatial(nn.Module): + def __init__(self, factor: int = 2): + super(SparseChannel2Spatial, self).__init__() + self.factor = factor + + def forward(self, x, subdivision = None): + DIM = x.coords.shape[-1] - 1 + + cache = x.get_spatial_cache(f'channel2spatial_{self.factor}') + if cache is None: + if subdivision is None: + raise ValueError('Cache not found. Provide subdivision tensor or pair SparseChannel2Spatial with SparseSpatial2Channel.') + else: + sub = subdivision.feats # [N, self.factor ** DIM] + N_leaf = sub.sum(dim=-1) # [N] + subidx = sub.nonzero()[:, -1] + new_coords = x.coords.clone().detach() + new_coords[:, 1:] *= self.factor + new_coords = torch.repeat_interleave(new_coords, N_leaf, dim=0, output_size=subidx.shape[0]) + for i in range(DIM): + new_coords[:, i+1] += subidx // self.factor ** i % self.factor + idx = torch.repeat_interleave(torch.arange(x.coords.shape[0], device=x.device), N_leaf, dim=0, output_size=subidx.shape[0]) + else: + new_coords, idx, subidx = cache + + x_feats = x.feats.reshape(x.feats.shape[0] * self.factor ** DIM, -1) + new_feats = x_feats[idx * self.factor ** DIM + subidx] + out = SparseTensor(new_feats, new_coords, None if x._shape is None else torch.Size([x._shape[0], x._shape[1] // self.factor ** DIM])) + out._scale = tuple([s / self.factor for s in x._scale]) + if cache is not None: # only keep cache when subdiv following it + out._spatial_cache = x._spatial_cache + return out + +class SparseResBlockC2S3d(nn.Module): + def __init__( + self, + channels: int, + out_channels: Optional[int] = None, + use_checkpoint: bool = False, + pred_subdiv: bool = True, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_checkpoint = use_checkpoint + self.pred_subdiv = pred_subdiv + + self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6) + self.conv1 = SparseConv3d(channels, self.out_channels * 8, 3) + self.conv2 = SparseConv3d(self.out_channels, self.out_channels, 3) + self.skip_connection = lambda x: x.replace(x.feats.repeat_interleave(out_channels // (channels // 8), dim=1)) + if pred_subdiv: + self.to_subdiv = SparseLinear(channels, 8) + self.updown = SparseChannel2Spatial(2) + + def _forward(self, x, subdiv = None): + if self.pred_subdiv: + subdiv = self.to_subdiv(x) + h = x.replace(self.norm1(x.feats)) + h = h.replace(F.silu(h.feats)) + h = self.conv1(h) + subdiv_binarized = subdiv.replace(subdiv.feats > 0) if subdiv is not None else None + h = self.updown(h, subdiv_binarized) + x = self.updown(x, subdiv_binarized) + h = h.replace(self.norm2(h.feats)) + h = h.replace(F.silu(h.feats)) + h = self.conv2(h) + h = h + self.skip_connection(x) + if self.pred_subdiv: + return h, subdiv + else: + return h -# TODO: determine which conv they actually use @dataclass class config: - CONV = "none" + CONV = "flexgemm" + FLEX_GEMM_HASHMAP_RATIO = 2.0 # TODO post processing def simplify(self, target_num_faces: int, verbose: bool=False, options: dict={}): @@ -1131,6 +1339,7 @@ def flexible_dual_grid_to_mesh( class Vae(nn.Module): def __init__(self, config, operations=None): + super().__init__() operations = operations or torch.nn self.txt_dec = SparseUnetVaeDecoder( @@ -1139,7 +1348,8 @@ class Vae(nn.Module): latent_channels=32, num_blocks=[4, 16, 8, 4, 0], block_type=["SparseConvNeXtBlock3d"] * 5, - up_block_type=["SparseResBlockS2C3d"] * 4, + up_block_type=["SparseResBlockC2S3d"] * 4, + block_args=[{}, {}, {}, {}, {}], pred_subdiv=False ) @@ -1149,7 +1359,8 @@ class Vae(nn.Module): latent_channels=32, num_blocks=[4, 16, 8, 4, 0], block_type=["SparseConvNeXtBlock3d"] * 5, - up_block_type=["SparseResBlockS2C3d"] * 4, + up_block_type=["SparseResBlockC2S3d"] * 4, + block_args=[{}, {}, {}, {}, {}], ) def decode_shape_slat(self, slat, resolution: int): diff --git a/comfy/model_base.py b/comfy/model_base.py index 85acdb66a..a5fc81c4d 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -50,6 +50,7 @@ import comfy.ldm.omnigen.omnigen2 import comfy.ldm.qwen_image.model import comfy.ldm.kandinsky5.model import comfy.ldm.anima.model +import comfy.ldm.trellis2.model import comfy.model_management import comfy.patcher_extension @@ -1455,6 +1456,13 @@ class WAN22(WAN21): def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): return latent_image +class Trellis2(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None, unet_model=comfy.ldm.trellis2.model.Trellis2): + super().__init__(model_config, model_type, device, unet_model) + + def extra_conds(self, **kwargs): + return super().extra_conds(**kwargs) + class Hunyuan3Dv2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index d25271d6e..6c2725b9f 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1242,6 +1242,17 @@ class WAN22_T2V(WAN21_T2V): out = model_base.WAN22(self, image_to_video=True, device=device) return out +class Trellis2(supported_models_base.BASE): + unet_config = { + "image_model": "trellis2" + } + + latent_format = latent_formats.Trellis2 + vae_key_prefix = ["vae."] + + def get_model(self, state_dict, prefix="", device=None): + return model_base.Trellis2(self, device=device) + class Hunyuan3Dv2(supported_models_base.BASE): unet_config = { "image_model": "hunyuan3d2", @@ -1596,6 +1607,6 @@ class Kandinsky5Image(Kandinsky5): return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, Trellis2] models += [SVD_img2vid]