diff --git a/comfy/ldm/trellis2/cumesh.py b/comfy/ldm/trellis2/cumesh.py index ea069a465..047e785ff 100644 --- a/comfy/ldm/trellis2/cumesh.py +++ b/comfy/ldm/trellis2/cumesh.py @@ -2,7 +2,7 @@ import math import torch -from typing import Dict, Callable +from typing import Callable import logging NO_TRITON = False @@ -201,13 +201,13 @@ class TorchHashMap: def __init__(self, keys: torch.Tensor, values: torch.Tensor, default_value: int): device = keys.device # use long for searchsorted - self.sorted_keys, order = torch.sort(keys.long()) - self.sorted_vals = values.long()[order] + self.sorted_keys, order = torch.sort(keys.to(torch.long)) + self.sorted_vals = values.to(torch.long)[order] self.default_value = torch.tensor(default_value, dtype=torch.long, device=device) self._n = self.sorted_keys.numel() def lookup_flat(self, flat_keys: torch.Tensor) -> torch.Tensor: - flat = flat_keys.long() + flat = flat_keys.to(torch.long) if self._n == 0: return torch.full((flat.shape[0],), self.default_value, device=flat.device, dtype=self.sorted_vals.dtype) idx = torch.searchsorted(self.sorted_keys, flat) @@ -225,44 +225,35 @@ def neighbor_map_post_process_for_masked_implicit_gemm_1(neighbor_map): device = neighbor_map.device N, V = neighbor_map.shape + sentinel = UINT32_SENTINEL - neigh = neighbor_map.to(torch.long) - sentinel = torch.tensor(UINT32_SENTINEL, dtype=torch.long, device=device) - - - neigh_map_T = neigh.t().reshape(-1) - + neigh_map_T = neighbor_map.t().reshape(-1) neigh_mask_T = (neigh_map_T != sentinel).to(torch.int32) - mask = (neigh != sentinel).to(torch.long) + mask = (neighbor_map != sentinel).to(torch.long) + gray_code = torch.zeros(N, dtype=torch.long, device=device) - powers = (1 << torch.arange(V, dtype=torch.long, device=device)) + for v in range(V): + gray_code |= (mask[:, v] << v) - gray_long = (mask * powers).sum(dim=1) - - gray_code = gray_long.to(torch.int32) - - binary_long = gray_long.clone() + binary_code = gray_code.clone() for v in range(1, V): - binary_long ^= (gray_long >> v) - binary_code = binary_long.to(torch.int32) + binary_code ^= (gray_code >> v) sorted_idx = torch.argsort(binary_code) - prefix_sum_neighbor_mask = torch.cumsum(neigh_mask_T.to(torch.int32), dim=0) # (V*N,) + prefix_sum_neighbor_mask = torch.cumsum(neigh_mask_T, dim=0) total_valid_signal = int(prefix_sum_neighbor_mask[-1].item()) if prefix_sum_neighbor_mask.numel() > 0 else 0 if total_valid_signal > 0: + pos = torch.nonzero(neigh_mask_T, as_tuple=True)[0] + to = (prefix_sum_neighbor_mask[pos] - 1).long() + valid_signal_i = torch.empty((total_valid_signal,), dtype=torch.long, device=device) valid_signal_o = torch.empty((total_valid_signal,), dtype=torch.long, device=device) - pos = torch.nonzero(neigh_mask_T, as_tuple=True)[0] - - to = (prefix_sum_neighbor_mask[pos] - 1).to(torch.long) - valid_signal_i[to] = (pos % N).to(torch.long) - valid_signal_o[to] = neigh_map_T[pos].to(torch.long) else: valid_signal_i = torch.empty((0,), dtype=torch.long, device=device) @@ -272,9 +263,7 @@ def neighbor_map_post_process_for_masked_implicit_gemm_1(neighbor_map): seg[0] = 0 if V > 0: idxs = (torch.arange(1, V + 1, device=device, dtype=torch.long) * N) - 1 - seg[1:] = prefix_sum_neighbor_mask[idxs].to(torch.long) - else: - pass + seg[1:] = prefix_sum_neighbor_mask[idxs] return gray_code, sorted_idx, valid_signal_i, valid_signal_o, seg @@ -295,40 +284,41 @@ def _popcount_int32_tensor(x: torch.Tensor) -> torch.Tensor: def neighbor_map_post_process_for_masked_implicit_gemm_2( - gray_code: torch.Tensor, # [N], int32-like (non-negative) - sorted_idx: torch.Tensor, # [N], long (indexing into gray_code) + gray_code: torch.Tensor, + sorted_idx: torch.Tensor, block_size: int ): device = gray_code.device N = gray_code.numel() - - # num of blocks (same as CUDA) num_blocks = (N + block_size - 1) // block_size - # Ensure dtypes - gray_long = gray_code.to(torch.int64) # safer to OR in 64-bit then cast - sorted_idx = sorted_idx.to(torch.long) - - # 1) Group gray_code by blocks and compute OR across each block - # pad the last block with zeros if necessary so we can reshape pad = num_blocks * block_size - N if pad > 0: - pad_vals = torch.zeros((pad,), dtype=torch.int64, device=device) - gray_padded = torch.cat([gray_long[sorted_idx], pad_vals], dim=0) + pad_vals = torch.zeros((pad,), dtype=torch.int32, device=device) + gray_padded = torch.cat([gray_code[sorted_idx], pad_vals], dim=0) else: - gray_padded = gray_long[sorted_idx] + gray_padded = gray_code[sorted_idx] - # reshape to (num_blocks, block_size) and compute bitwise_or across dim=1 - gray_blocks = gray_padded.view(num_blocks, block_size) # each row = block entries - # reduce with bitwise_or - reduced_code = gray_blocks[:, 0].clone() - for i in range(1, block_size): - reduced_code |= gray_blocks[:, i] - reduced_code = reduced_code.to(torch.int32) # match CUDA int32 + gray_blocks = gray_padded.view(num_blocks, block_size) + + reduced_code = gray_blocks + while reduced_code.shape[1] > 1: + half = reduced_code.shape[1] // 2 + remainder = reduced_code.shape[1] % 2 + + left = reduced_code[:, :half * 2:2] + right = reduced_code[:, 1:half * 2:2] + merged = left | right + + if remainder: + reduced_code = torch.cat([merged, reduced_code[:, -1:]], dim=1) + else: + reduced_code = merged + + reduced_code = reduced_code.squeeze(1) + + seglen_counts = _popcount_int32_tensor(reduced_code).to(torch.int32) - # 2) compute seglen (popcount per reduced_code) and seg (prefix sum) - seglen_counts = _popcount_int32_tensor(reduced_code.to(torch.int64)).to(torch.int32) # [num_blocks] - # seg: length num_blocks+1, seg[0]=0, seg[i+1]=cumsum(seglen_counts) up to i seg = torch.empty((num_blocks + 1,), dtype=torch.int32, device=device) seg[0] = 0 if num_blocks > 0: @@ -336,30 +326,20 @@ def neighbor_map_post_process_for_masked_implicit_gemm_2( total = int(seg[-1].item()) - # 3) scatter — produce valid_kernel_idx as concatenated ascending set-bit positions for each reduced_code row if total == 0: - valid_kernel_idx = torch.empty((0,), dtype=torch.int32, device=device) - return valid_kernel_idx, seg + return torch.empty((0,), dtype=torch.int32, device=device), seg - max_val = int(reduced_code.max().item()) - V = max_val.bit_length() if max_val > 0 else 0 - # If you know V externally, pass it instead or set here explicitly. + V = int(reduced_code.max().item()).bit_length() if reduced_code.max() > 0 else 0 if V == 0: - # no bits set anywhere - valid_kernel_idx = torch.empty((0,), dtype=torch.int32, device=device) - return valid_kernel_idx, seg + return torch.empty((0,), dtype=torch.int32, device=device), seg - # build mask of shape (num_blocks, V): True where bit is set - bit_pos = torch.arange(0, V, dtype=torch.int64, device=device) # [V] - # shifted = reduced_code[:, None] >> bit_pos[None, :] - shifted = reduced_code.to(torch.int64).unsqueeze(1) >> bit_pos.unsqueeze(0) - bits = (shifted & 1).to(torch.bool) # (num_blocks, V) + bit_pos = torch.arange(0, V, dtype=torch.int32, device=device) + shifted = reduced_code.unsqueeze(1) >> bit_pos.unsqueeze(0) + bits = (shifted & 1).to(torch.bool) positions = bit_pos.unsqueeze(0).expand(num_blocks, V) - - valid_positions = positions[bits] - valid_kernel_idx = valid_positions.to(torch.int32).contiguous() + valid_kernel_idx = positions[bits].to(torch.int32).contiguous() return valid_kernel_idx, seg @@ -425,35 +405,6 @@ def sparse_submanifold_conv3d(feats, coords, shape, weight, bias, neighbor_cache return out, neighbor -class Voxel: - def __init__( - self, - origin: list, - voxel_size: float, - coords: torch.Tensor = None, - attrs: torch.Tensor = None, - layout = None, - device: torch.device = None - ): - if layout is None: - layout = {} - self.origin = torch.tensor(origin, dtype=torch.float32, device=device) - self.voxel_size = voxel_size - self.coords = coords - self.attrs = attrs - self.layout = layout - self.device = device - - @property - def position(self): - return (self.coords + 0.5) * self.voxel_size + self.origin[None, :] - - def split_attrs(self): - return { - k: self.attrs[:, self.layout[k]] - for k in self.layout - } - class Mesh: def __init__(self, vertices, @@ -480,35 +431,3 @@ class Mesh: def cpu(self): return self.to('cpu') - -class MeshWithVoxel(Mesh, Voxel): - def __init__(self, - vertices: torch.Tensor, - faces: torch.Tensor, - origin: list, - voxel_size: float, - coords: torch.Tensor, - attrs: torch.Tensor, - voxel_shape: torch.Size, - layout: Dict = {}, - ): - self.vertices = vertices.float() - self.faces = faces.int() - self.origin = torch.tensor(origin, dtype=torch.float32, device=self.device) - self.voxel_size = voxel_size - self.coords = coords - self.attrs = attrs - self.voxel_shape = voxel_shape - self.layout = layout - - def to(self, device, non_blocking=False): - return MeshWithVoxel( - self.vertices.to(device, non_blocking=non_blocking), - self.faces.to(device, non_blocking=non_blocking), - self.origin.tolist(), - self.voxel_size, - self.coords.to(device, non_blocking=non_blocking), - self.attrs.to(device, non_blocking=non_blocking), - self.voxel_shape, - self.layout, - ) diff --git a/comfy/ldm/trellis2/vae.py b/comfy/ldm/trellis2/vae.py index 0b1975092..2a18c496a 100644 --- a/comfy/ldm/trellis2/vae.py +++ b/comfy/ldm/trellis2/vae.py @@ -7,7 +7,7 @@ import torch.nn.functional as F from fractions import Fraction from dataclasses import dataclass from typing import List, Any, Dict, Optional, overload, Union, Tuple -from comfy.ldm.trellis2.cumesh import TorchHashMap, Mesh, MeshWithVoxel, sparse_submanifold_conv3d +from comfy.ldm.trellis2.cumesh import TorchHashMap, Mesh, sparse_submanifold_conv3d def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor: @@ -210,6 +210,8 @@ class SparseResBlockC2S3d(nn.Module): def forward(self, x, subdiv = None): if self.pred_subdiv: + dtype = next(self.to_subdiv.parameters()).dtype + x = x.to(dtype) subdiv = self.to_subdiv(x) norm1 = self.norm1.to(torch.float32) norm2 = self.norm2.to(torch.float32) @@ -987,114 +989,7 @@ def convert_module_to_f16(l): for p in l.parameters(): p.data = p.data.half() - - -class SparseUnetVaeEncoder(nn.Module): - """ - Sparse Swin Transformer Unet VAE model. - """ - def __init__( - self, - in_channels: int, - model_channels: List[int], - latent_channels: int, - num_blocks: List[int], - block_type: List[str], - down_block_type: List[str], - block_args: List[Dict[str, Any]], - use_fp16: bool = False, - ): - super().__init__() - self.in_channels = in_channels - self.model_channels = model_channels - self.num_blocks = num_blocks - self.dtype = torch.float16 if use_fp16 else torch.float32 - - self.input_layer = SparseLinear(in_channels, model_channels[0]) - self.to_latent = SparseLinear(model_channels[-1], 2 * latent_channels) - - self.blocks = nn.ModuleList([]) - for i in range(len(num_blocks)): - self.blocks.append(nn.ModuleList([])) - for j in range(num_blocks[i]): - self.blocks[-1].append( - globals()[block_type[i]]( - model_channels[i], - **block_args[i], - ) - ) - if i < len(num_blocks) - 1: - self.blocks[-1].append( - globals()[down_block_type[i]]( - model_channels[i], - model_channels[i+1], - **block_args[i], - ) - ) - - @property - def device(self) -> torch.device: - return next(self.parameters()).device - - def forward(self, x: SparseTensor, sample_posterior=False, return_raw=False): - h = self.input_layer(x) - h = h.type(self.dtype) - for i, res in enumerate(self.blocks): - for j, block in enumerate(res): - h = block(h) - h = h.type(x.dtype) - h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) - h = self.to_latent(h) - - # Sample from the posterior distribution - mean, logvar = h.feats.chunk(2, dim=-1) - if sample_posterior: - std = torch.exp(0.5 * logvar) - z = mean + std * torch.randn_like(std) - else: - z = mean - z = h.replace(z) - - if return_raw: - return z, mean, logvar - else: - return z - - - -class FlexiDualGridVaeEncoder(SparseUnetVaeEncoder): - def __init__( - self, - model_channels: List[int], - latent_channels: int, - num_blocks: List[int], - block_type: List[str], - down_block_type: List[str], - block_args: List[Dict[str, Any]], - use_fp16: bool = False, - ): - super().__init__( - 6, - model_channels, - latent_channels, - num_blocks, - block_type, - down_block_type, - block_args, - use_fp16, - ) - - def forward(self, vertices: SparseTensor, intersected: SparseTensor, sample_posterior=False, return_raw=False): - x = vertices.replace(torch.cat([ - vertices.feats - 0.5, - intersected.feats.float() - 0.5, - ], dim=1)) - return super().forward(x, sample_posterior, return_raw) - class SparseUnetVaeDecoder(nn.Module): - """ - Sparse Swin Transformer Unet VAE model. - """ def __init__( self, out_channels: int, @@ -1218,10 +1113,10 @@ class FlexiDualGridVaeDecoder(SparseUnetVaeDecoder): N = coords.shape[0] # compute flat keys for all coords (prepend batch 0 same as original code) b = torch.zeros((N,), dtype=torch.long, device=device) - x, y, z = coords[:, 0].long(), coords[:, 1].long(), coords[:, 2].long() + x, y, z = coords[:, 0].to(torch.int32), coords[:, 1].to(torch.int32), coords[:, 2].to(torch.int32) W, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item()) flat_keys = b * (W * H * D) + x * (H * D) + y * D + z - values = torch.arange(N, dtype=torch.long, device=device) + values = torch.arange(N, dtype=torch.int32, device=device) DEFAULT_VAL = 0xffffffff # sentinel used in original code return TorchHashMap(flat_keys, values, DEFAULT_VAL) @@ -1295,13 +1190,12 @@ def flexible_dual_grid_to_mesh( # Extract mesh N = dual_vertices.shape[0] - mesh_vertices = (coords.float() + dual_vertices) / (2 * N) - 0.5 if hashmap_builder is None: # build local TorchHashMap device = coords.device b = torch.zeros((N,), dtype=torch.long, device=device) - x, y, z = coords[:, 0].long(), coords[:, 1].long(), coords[:, 2].long() + x, y, z = coords[:, 0].to(torch.int32), coords[:, 1].to(torch.int32), coords[:, 2].to(torch.int32) W, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item()) flat_keys = b * (W * H * D) + x * (H * D) + y * D + z values = torch.arange(N, dtype=torch.long, device=device) @@ -1316,9 +1210,9 @@ def flexible_dual_grid_to_mesh( M = connected_voxel.shape[0] # flatten connected voxel coords and lookup conn_flat_b = torch.zeros((M * 4,), dtype=torch.long, device=coords.device) - conn_x = connected_voxel.reshape(-1, 3)[:, 0].long() - conn_y = connected_voxel.reshape(-1, 3)[:, 1].long() - conn_z = connected_voxel.reshape(-1, 3)[:, 2].long() + conn_x = connected_voxel.reshape(-1, 3)[:, 0].to(torch.int32) + conn_y = connected_voxel.reshape(-1, 3)[:, 1].to(torch.int32) + conn_z = connected_voxel.reshape(-1, 3)[:, 2].to(torch.int32) W, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item()) conn_flat = conn_flat_b * (W * H * D) + conn_x * (H * D) + conn_y * D + conn_z @@ -1526,17 +1420,18 @@ class Vae(nn.Module): channels=[512, 128, 32], ) + @torch.no_grad() def decode_shape_slat(self, slat, resolution: int): self.shape_dec.set_resolution(resolution) - device = comfy.model_management.get_torch_device() - self.shape_dec = self.shape_dec.to(device) return self.shape_dec(slat, return_subs=True) + @torch.no_grad() def decode_tex_slat(self, slat, subs): if self.txt_dec is None: raise ValueError("Checkpoint doesn't include texture model") return self.txt_dec(slat, guide_subs=subs) * 0.5 + 0.5 + # shouldn't be called (placeholder) @torch.no_grad() def decode( self, @@ -1546,17 +1441,4 @@ class Vae(nn.Module): ): meshes, subs = self.decode_shape_slat(shape_slat, resolution) tex_voxels = self.decode_tex_slat(tex_slat, subs) - out_mesh = [] - for m, v in zip(meshes, tex_voxels): - out_mesh.append( - MeshWithVoxel( - m.vertices, m.faces, - origin = [-0.5, -0.5, -0.5], - voxel_size = 1 / resolution, - coords = v.coords[:, 1:], - attrs = v.feats, - voxel_shape = torch.Size([*v.shape, *v.spatial_shape]), - layout=self.pbr_attr_layout - ) - ) - return out_mesh + return tex_voxels diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 96510e916..e781d35e3 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -1,7 +1,7 @@ from typing_extensions import override from comfy_api.latest import ComfyExtension, IO, Types from comfy.ldm.trellis2.vae import SparseTensor -from comfy.utils import ProgressBar +from comfy.utils import ProgressBar, lanczos import torch.nn.functional as TF import comfy.model_management from PIL import Image @@ -102,9 +102,7 @@ def run_conditioning(model, image, mask, include_1024 = True, background_color = cropped_img = smart_crop_square(img_t, mask_t, bg_color=bg_rgb) def prepare_tensor(img, size): - resized = torch.nn.functional.interpolate( - img.unsqueeze(0), size=(size, size), mode='bicubic', align_corners=False - ) + resized = lanczos(img.unsqueeze(0), size, size) return (resized - dino_mean.to(torch_device)) / dino_std.to(torch_device) model_internal.image_size = 512 @@ -148,10 +146,16 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): @classmethod def execute(cls, samples, structure_output, vae, resolution): + + patcher = vae.patcher + device = comfy.model_management.get_torch_device() + comfy.model_management.load_model_gpu(patcher) + vae = vae.first_stage_model decoded = structure_output.data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() samples = samples["samples"] + samples = samples.squeeze(-1).transpose(1, 2).to(device) std = shape_slat_normalization["std"].to(samples) mean = shape_slat_normalization["mean"].to(samples) samples = SparseTensor(feats = samples, coords=coords) @@ -179,10 +183,16 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): @classmethod def execute(cls, samples, structure_output, vae, shape_subs): + + patcher = vae.patcher + device = comfy.model_management.get_torch_device() + comfy.model_management.load_model_gpu(patcher) + vae = vae.first_stage_model decoded = structure_output.data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() samples = samples["samples"] + samples = samples.squeeze(-1).transpose(1, 2).to(device) std = tex_slat_normalization["std"].to(samples) mean = tex_slat_normalization["mean"].to(samples) samples = SparseTensor(feats = samples, coords=coords)