removed unnecessary code + optimizations + progress

This commit is contained in:
Yousef Rafat 2026-02-25 23:54:03 +02:00
parent f31c2e1d1d
commit 39270fdca9
3 changed files with 75 additions and 264 deletions

View File

@ -2,7 +2,7 @@
import math
import torch
from typing import Dict, Callable
from typing import Callable
import logging
NO_TRITON = False
@ -201,13 +201,13 @@ class TorchHashMap:
def __init__(self, keys: torch.Tensor, values: torch.Tensor, default_value: int):
device = keys.device
# use long for searchsorted
self.sorted_keys, order = torch.sort(keys.long())
self.sorted_vals = values.long()[order]
self.sorted_keys, order = torch.sort(keys.to(torch.long))
self.sorted_vals = values.to(torch.long)[order]
self.default_value = torch.tensor(default_value, dtype=torch.long, device=device)
self._n = self.sorted_keys.numel()
def lookup_flat(self, flat_keys: torch.Tensor) -> torch.Tensor:
flat = flat_keys.long()
flat = flat_keys.to(torch.long)
if self._n == 0:
return torch.full((flat.shape[0],), self.default_value, device=flat.device, dtype=self.sorted_vals.dtype)
idx = torch.searchsorted(self.sorted_keys, flat)
@ -225,44 +225,35 @@ def neighbor_map_post_process_for_masked_implicit_gemm_1(neighbor_map):
device = neighbor_map.device
N, V = neighbor_map.shape
sentinel = UINT32_SENTINEL
neigh = neighbor_map.to(torch.long)
sentinel = torch.tensor(UINT32_SENTINEL, dtype=torch.long, device=device)
neigh_map_T = neigh.t().reshape(-1)
neigh_map_T = neighbor_map.t().reshape(-1)
neigh_mask_T = (neigh_map_T != sentinel).to(torch.int32)
mask = (neigh != sentinel).to(torch.long)
mask = (neighbor_map != sentinel).to(torch.long)
gray_code = torch.zeros(N, dtype=torch.long, device=device)
powers = (1 << torch.arange(V, dtype=torch.long, device=device))
for v in range(V):
gray_code |= (mask[:, v] << v)
gray_long = (mask * powers).sum(dim=1)
gray_code = gray_long.to(torch.int32)
binary_long = gray_long.clone()
binary_code = gray_code.clone()
for v in range(1, V):
binary_long ^= (gray_long >> v)
binary_code = binary_long.to(torch.int32)
binary_code ^= (gray_code >> v)
sorted_idx = torch.argsort(binary_code)
prefix_sum_neighbor_mask = torch.cumsum(neigh_mask_T.to(torch.int32), dim=0) # (V*N,)
prefix_sum_neighbor_mask = torch.cumsum(neigh_mask_T, dim=0)
total_valid_signal = int(prefix_sum_neighbor_mask[-1].item()) if prefix_sum_neighbor_mask.numel() > 0 else 0
if total_valid_signal > 0:
pos = torch.nonzero(neigh_mask_T, as_tuple=True)[0]
to = (prefix_sum_neighbor_mask[pos] - 1).long()
valid_signal_i = torch.empty((total_valid_signal,), dtype=torch.long, device=device)
valid_signal_o = torch.empty((total_valid_signal,), dtype=torch.long, device=device)
pos = torch.nonzero(neigh_mask_T, as_tuple=True)[0]
to = (prefix_sum_neighbor_mask[pos] - 1).to(torch.long)
valid_signal_i[to] = (pos % N).to(torch.long)
valid_signal_o[to] = neigh_map_T[pos].to(torch.long)
else:
valid_signal_i = torch.empty((0,), dtype=torch.long, device=device)
@ -272,9 +263,7 @@ def neighbor_map_post_process_for_masked_implicit_gemm_1(neighbor_map):
seg[0] = 0
if V > 0:
idxs = (torch.arange(1, V + 1, device=device, dtype=torch.long) * N) - 1
seg[1:] = prefix_sum_neighbor_mask[idxs].to(torch.long)
else:
pass
seg[1:] = prefix_sum_neighbor_mask[idxs]
return gray_code, sorted_idx, valid_signal_i, valid_signal_o, seg
@ -295,40 +284,41 @@ def _popcount_int32_tensor(x: torch.Tensor) -> torch.Tensor:
def neighbor_map_post_process_for_masked_implicit_gemm_2(
gray_code: torch.Tensor, # [N], int32-like (non-negative)
sorted_idx: torch.Tensor, # [N], long (indexing into gray_code)
gray_code: torch.Tensor,
sorted_idx: torch.Tensor,
block_size: int
):
device = gray_code.device
N = gray_code.numel()
# num of blocks (same as CUDA)
num_blocks = (N + block_size - 1) // block_size
# Ensure dtypes
gray_long = gray_code.to(torch.int64) # safer to OR in 64-bit then cast
sorted_idx = sorted_idx.to(torch.long)
# 1) Group gray_code by blocks and compute OR across each block
# pad the last block with zeros if necessary so we can reshape
pad = num_blocks * block_size - N
if pad > 0:
pad_vals = torch.zeros((pad,), dtype=torch.int64, device=device)
gray_padded = torch.cat([gray_long[sorted_idx], pad_vals], dim=0)
pad_vals = torch.zeros((pad,), dtype=torch.int32, device=device)
gray_padded = torch.cat([gray_code[sorted_idx], pad_vals], dim=0)
else:
gray_padded = gray_long[sorted_idx]
gray_padded = gray_code[sorted_idx]
# reshape to (num_blocks, block_size) and compute bitwise_or across dim=1
gray_blocks = gray_padded.view(num_blocks, block_size) # each row = block entries
# reduce with bitwise_or
reduced_code = gray_blocks[:, 0].clone()
for i in range(1, block_size):
reduced_code |= gray_blocks[:, i]
reduced_code = reduced_code.to(torch.int32) # match CUDA int32
gray_blocks = gray_padded.view(num_blocks, block_size)
reduced_code = gray_blocks
while reduced_code.shape[1] > 1:
half = reduced_code.shape[1] // 2
remainder = reduced_code.shape[1] % 2
left = reduced_code[:, :half * 2:2]
right = reduced_code[:, 1:half * 2:2]
merged = left | right
if remainder:
reduced_code = torch.cat([merged, reduced_code[:, -1:]], dim=1)
else:
reduced_code = merged
reduced_code = reduced_code.squeeze(1)
seglen_counts = _popcount_int32_tensor(reduced_code).to(torch.int32)
# 2) compute seglen (popcount per reduced_code) and seg (prefix sum)
seglen_counts = _popcount_int32_tensor(reduced_code.to(torch.int64)).to(torch.int32) # [num_blocks]
# seg: length num_blocks+1, seg[0]=0, seg[i+1]=cumsum(seglen_counts) up to i
seg = torch.empty((num_blocks + 1,), dtype=torch.int32, device=device)
seg[0] = 0
if num_blocks > 0:
@ -336,30 +326,20 @@ def neighbor_map_post_process_for_masked_implicit_gemm_2(
total = int(seg[-1].item())
# 3) scatter — produce valid_kernel_idx as concatenated ascending set-bit positions for each reduced_code row
if total == 0:
valid_kernel_idx = torch.empty((0,), dtype=torch.int32, device=device)
return valid_kernel_idx, seg
return torch.empty((0,), dtype=torch.int32, device=device), seg
max_val = int(reduced_code.max().item())
V = max_val.bit_length() if max_val > 0 else 0
# If you know V externally, pass it instead or set here explicitly.
V = int(reduced_code.max().item()).bit_length() if reduced_code.max() > 0 else 0
if V == 0:
# no bits set anywhere
valid_kernel_idx = torch.empty((0,), dtype=torch.int32, device=device)
return valid_kernel_idx, seg
return torch.empty((0,), dtype=torch.int32, device=device), seg
# build mask of shape (num_blocks, V): True where bit is set
bit_pos = torch.arange(0, V, dtype=torch.int64, device=device) # [V]
# shifted = reduced_code[:, None] >> bit_pos[None, :]
shifted = reduced_code.to(torch.int64).unsqueeze(1) >> bit_pos.unsqueeze(0)
bits = (shifted & 1).to(torch.bool) # (num_blocks, V)
bit_pos = torch.arange(0, V, dtype=torch.int32, device=device)
shifted = reduced_code.unsqueeze(1) >> bit_pos.unsqueeze(0)
bits = (shifted & 1).to(torch.bool)
positions = bit_pos.unsqueeze(0).expand(num_blocks, V)
valid_positions = positions[bits]
valid_kernel_idx = valid_positions.to(torch.int32).contiguous()
valid_kernel_idx = positions[bits].to(torch.int32).contiguous()
return valid_kernel_idx, seg
@ -425,35 +405,6 @@ def sparse_submanifold_conv3d(feats, coords, shape, weight, bias, neighbor_cache
return out, neighbor
class Voxel:
def __init__(
self,
origin: list,
voxel_size: float,
coords: torch.Tensor = None,
attrs: torch.Tensor = None,
layout = None,
device: torch.device = None
):
if layout is None:
layout = {}
self.origin = torch.tensor(origin, dtype=torch.float32, device=device)
self.voxel_size = voxel_size
self.coords = coords
self.attrs = attrs
self.layout = layout
self.device = device
@property
def position(self):
return (self.coords + 0.5) * self.voxel_size + self.origin[None, :]
def split_attrs(self):
return {
k: self.attrs[:, self.layout[k]]
for k in self.layout
}
class Mesh:
def __init__(self,
vertices,
@ -480,35 +431,3 @@ class Mesh:
def cpu(self):
return self.to('cpu')
class MeshWithVoxel(Mesh, Voxel):
def __init__(self,
vertices: torch.Tensor,
faces: torch.Tensor,
origin: list,
voxel_size: float,
coords: torch.Tensor,
attrs: torch.Tensor,
voxel_shape: torch.Size,
layout: Dict = {},
):
self.vertices = vertices.float()
self.faces = faces.int()
self.origin = torch.tensor(origin, dtype=torch.float32, device=self.device)
self.voxel_size = voxel_size
self.coords = coords
self.attrs = attrs
self.voxel_shape = voxel_shape
self.layout = layout
def to(self, device, non_blocking=False):
return MeshWithVoxel(
self.vertices.to(device, non_blocking=non_blocking),
self.faces.to(device, non_blocking=non_blocking),
self.origin.tolist(),
self.voxel_size,
self.coords.to(device, non_blocking=non_blocking),
self.attrs.to(device, non_blocking=non_blocking),
self.voxel_shape,
self.layout,
)

View File

@ -7,7 +7,7 @@ import torch.nn.functional as F
from fractions import Fraction
from dataclasses import dataclass
from typing import List, Any, Dict, Optional, overload, Union, Tuple
from comfy.ldm.trellis2.cumesh import TorchHashMap, Mesh, MeshWithVoxel, sparse_submanifold_conv3d
from comfy.ldm.trellis2.cumesh import TorchHashMap, Mesh, sparse_submanifold_conv3d
def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor:
@ -210,6 +210,8 @@ class SparseResBlockC2S3d(nn.Module):
def forward(self, x, subdiv = None):
if self.pred_subdiv:
dtype = next(self.to_subdiv.parameters()).dtype
x = x.to(dtype)
subdiv = self.to_subdiv(x)
norm1 = self.norm1.to(torch.float32)
norm2 = self.norm2.to(torch.float32)
@ -987,114 +989,7 @@ def convert_module_to_f16(l):
for p in l.parameters():
p.data = p.data.half()
class SparseUnetVaeEncoder(nn.Module):
"""
Sparse Swin Transformer Unet VAE model.
"""
def __init__(
self,
in_channels: int,
model_channels: List[int],
latent_channels: int,
num_blocks: List[int],
block_type: List[str],
down_block_type: List[str],
block_args: List[Dict[str, Any]],
use_fp16: bool = False,
):
super().__init__()
self.in_channels = in_channels
self.model_channels = model_channels
self.num_blocks = num_blocks
self.dtype = torch.float16 if use_fp16 else torch.float32
self.input_layer = SparseLinear(in_channels, model_channels[0])
self.to_latent = SparseLinear(model_channels[-1], 2 * latent_channels)
self.blocks = nn.ModuleList([])
for i in range(len(num_blocks)):
self.blocks.append(nn.ModuleList([]))
for j in range(num_blocks[i]):
self.blocks[-1].append(
globals()[block_type[i]](
model_channels[i],
**block_args[i],
)
)
if i < len(num_blocks) - 1:
self.blocks[-1].append(
globals()[down_block_type[i]](
model_channels[i],
model_channels[i+1],
**block_args[i],
)
)
@property
def device(self) -> torch.device:
return next(self.parameters()).device
def forward(self, x: SparseTensor, sample_posterior=False, return_raw=False):
h = self.input_layer(x)
h = h.type(self.dtype)
for i, res in enumerate(self.blocks):
for j, block in enumerate(res):
h = block(h)
h = h.type(x.dtype)
h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
h = self.to_latent(h)
# Sample from the posterior distribution
mean, logvar = h.feats.chunk(2, dim=-1)
if sample_posterior:
std = torch.exp(0.5 * logvar)
z = mean + std * torch.randn_like(std)
else:
z = mean
z = h.replace(z)
if return_raw:
return z, mean, logvar
else:
return z
class FlexiDualGridVaeEncoder(SparseUnetVaeEncoder):
def __init__(
self,
model_channels: List[int],
latent_channels: int,
num_blocks: List[int],
block_type: List[str],
down_block_type: List[str],
block_args: List[Dict[str, Any]],
use_fp16: bool = False,
):
super().__init__(
6,
model_channels,
latent_channels,
num_blocks,
block_type,
down_block_type,
block_args,
use_fp16,
)
def forward(self, vertices: SparseTensor, intersected: SparseTensor, sample_posterior=False, return_raw=False):
x = vertices.replace(torch.cat([
vertices.feats - 0.5,
intersected.feats.float() - 0.5,
], dim=1))
return super().forward(x, sample_posterior, return_raw)
class SparseUnetVaeDecoder(nn.Module):
"""
Sparse Swin Transformer Unet VAE model.
"""
def __init__(
self,
out_channels: int,
@ -1218,10 +1113,10 @@ class FlexiDualGridVaeDecoder(SparseUnetVaeDecoder):
N = coords.shape[0]
# compute flat keys for all coords (prepend batch 0 same as original code)
b = torch.zeros((N,), dtype=torch.long, device=device)
x, y, z = coords[:, 0].long(), coords[:, 1].long(), coords[:, 2].long()
x, y, z = coords[:, 0].to(torch.int32), coords[:, 1].to(torch.int32), coords[:, 2].to(torch.int32)
W, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item())
flat_keys = b * (W * H * D) + x * (H * D) + y * D + z
values = torch.arange(N, dtype=torch.long, device=device)
values = torch.arange(N, dtype=torch.int32, device=device)
DEFAULT_VAL = 0xffffffff # sentinel used in original code
return TorchHashMap(flat_keys, values, DEFAULT_VAL)
@ -1295,13 +1190,12 @@ def flexible_dual_grid_to_mesh(
# Extract mesh
N = dual_vertices.shape[0]
mesh_vertices = (coords.float() + dual_vertices) / (2 * N) - 0.5
if hashmap_builder is None:
# build local TorchHashMap
device = coords.device
b = torch.zeros((N,), dtype=torch.long, device=device)
x, y, z = coords[:, 0].long(), coords[:, 1].long(), coords[:, 2].long()
x, y, z = coords[:, 0].to(torch.int32), coords[:, 1].to(torch.int32), coords[:, 2].to(torch.int32)
W, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item())
flat_keys = b * (W * H * D) + x * (H * D) + y * D + z
values = torch.arange(N, dtype=torch.long, device=device)
@ -1316,9 +1210,9 @@ def flexible_dual_grid_to_mesh(
M = connected_voxel.shape[0]
# flatten connected voxel coords and lookup
conn_flat_b = torch.zeros((M * 4,), dtype=torch.long, device=coords.device)
conn_x = connected_voxel.reshape(-1, 3)[:, 0].long()
conn_y = connected_voxel.reshape(-1, 3)[:, 1].long()
conn_z = connected_voxel.reshape(-1, 3)[:, 2].long()
conn_x = connected_voxel.reshape(-1, 3)[:, 0].to(torch.int32)
conn_y = connected_voxel.reshape(-1, 3)[:, 1].to(torch.int32)
conn_z = connected_voxel.reshape(-1, 3)[:, 2].to(torch.int32)
W, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item())
conn_flat = conn_flat_b * (W * H * D) + conn_x * (H * D) + conn_y * D + conn_z
@ -1526,17 +1420,18 @@ class Vae(nn.Module):
channels=[512, 128, 32],
)
@torch.no_grad()
def decode_shape_slat(self, slat, resolution: int):
self.shape_dec.set_resolution(resolution)
device = comfy.model_management.get_torch_device()
self.shape_dec = self.shape_dec.to(device)
return self.shape_dec(slat, return_subs=True)
@torch.no_grad()
def decode_tex_slat(self, slat, subs):
if self.txt_dec is None:
raise ValueError("Checkpoint doesn't include texture model")
return self.txt_dec(slat, guide_subs=subs) * 0.5 + 0.5
# shouldn't be called (placeholder)
@torch.no_grad()
def decode(
self,
@ -1546,17 +1441,4 @@ class Vae(nn.Module):
):
meshes, subs = self.decode_shape_slat(shape_slat, resolution)
tex_voxels = self.decode_tex_slat(tex_slat, subs)
out_mesh = []
for m, v in zip(meshes, tex_voxels):
out_mesh.append(
MeshWithVoxel(
m.vertices, m.faces,
origin = [-0.5, -0.5, -0.5],
voxel_size = 1 / resolution,
coords = v.coords[:, 1:],
attrs = v.feats,
voxel_shape = torch.Size([*v.shape, *v.spatial_shape]),
layout=self.pbr_attr_layout
)
)
return out_mesh
return tex_voxels

View File

@ -1,7 +1,7 @@
from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO, Types
from comfy.ldm.trellis2.vae import SparseTensor
from comfy.utils import ProgressBar
from comfy.utils import ProgressBar, lanczos
import torch.nn.functional as TF
import comfy.model_management
from PIL import Image
@ -102,9 +102,7 @@ def run_conditioning(model, image, mask, include_1024 = True, background_color =
cropped_img = smart_crop_square(img_t, mask_t, bg_color=bg_rgb)
def prepare_tensor(img, size):
resized = torch.nn.functional.interpolate(
img.unsqueeze(0), size=(size, size), mode='bicubic', align_corners=False
)
resized = lanczos(img.unsqueeze(0), size, size)
return (resized - dino_mean.to(torch_device)) / dino_std.to(torch_device)
model_internal.image_size = 512
@ -148,10 +146,16 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
@classmethod
def execute(cls, samples, structure_output, vae, resolution):
patcher = vae.patcher
device = comfy.model_management.get_torch_device()
comfy.model_management.load_model_gpu(patcher)
vae = vae.first_stage_model
decoded = structure_output.data.unsqueeze(1)
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
samples = samples["samples"]
samples = samples.squeeze(-1).transpose(1, 2).to(device)
std = shape_slat_normalization["std"].to(samples)
mean = shape_slat_normalization["mean"].to(samples)
samples = SparseTensor(feats = samples, coords=coords)
@ -179,10 +183,16 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
@classmethod
def execute(cls, samples, structure_output, vae, shape_subs):
patcher = vae.patcher
device = comfy.model_management.get_torch_device()
comfy.model_management.load_model_gpu(patcher)
vae = vae.first_stage_model
decoded = structure_output.data.unsqueeze(1)
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
samples = samples["samples"]
samples = samples.squeeze(-1).transpose(1, 2).to(device)
std = tex_slat_normalization["std"].to(samples)
mean = tex_slat_normalization["mean"].to(samples)
samples = SparseTensor(feats = samples, coords=coords)