Cleanup VAE

This commit is contained in:
kijai 2026-05-23 02:43:08 +03:00
parent 4585a731c1
commit 3edbf7c4a7
4 changed files with 160 additions and 535 deletions

View File

@ -1,8 +1,8 @@
# will contain every cuda -> pytorch operation
from typing import Optional, Tuple
import torch
import comfy.model_management
UINT32_SENTINEL = 0xFFFFFFFF
@ -26,9 +26,7 @@ class TorchHashMap:
self.default_value = torch.tensor(default_value, dtype=torch.long, device=device)
self._n = self.sorted_keys.numel()
# Chunk size for lookup_flat. At ~530M flat keys (large mesh extraction),
# the unchunked path allocates ~5 full-size int64 temporaries (4 GB each) +
# bool masks + the int32 output. Chunking caps each transient to ~CHUNK rows.
# Chunk size for lookup_flat, caps each transient to ~CHUNK rows.
_LOOKUP_CHUNK = 1 << 23 # 8M rows ≈ 64 MB per int64 temp
def lookup_flat(self, flat_keys: torch.Tensor) -> torch.Tensor:
@ -119,57 +117,13 @@ def build_submanifold_neighbor_map(
def get_recommended_chunk_mem(
device=None,
safety_fraction: float = 0.4,
safety_fraction: float = 0.2,
min_gb: float = 0.25,
max_gb: float = 8.0,
max_gb: float = 2.0,
):
if device is None:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
else:
device = torch.device(device)
if device.type == 'cuda':
try:
idx = device.index if device.index is not None else 0
free_bytes, total_bytes = torch.cuda.mem_get_info(idx)
free_gb = free_bytes / (1024 ** 3)
total_gb = total_bytes / (1024 ** 3)
recommended = free_gb * safety_fraction
result = max(min_gb, min(recommended, max_gb))
return result
except Exception:
try:
idx = device.index if device.index is not None else 0
total_gb = torch.cuda.get_device_properties(idx).total_memory / (1024 ** 3)
except Exception:
total_gb = 16.0
if total_gb < 12:
result = 0.5
elif total_gb < 16:
result = 0.75
elif total_gb < 24:
result = 1.0
elif total_gb < 32:
result = 2.0
elif total_gb < 48:
result = 4.0
else:
result = 6.0
return result
else:
try:
import psutil
avail_gb = psutil.virtual_memory().available / (1024 ** 3)
recommended = avail_gb * safety_fraction
result = max(min_gb, min(recommended, max_gb))
return result
except ImportError:
return min_gb
"""Pick a chunk-memory budget (in GB) for sparse conv batching."""
free_gb = comfy.model_management.get_free_memory(device) / (1024 ** 3)
return max(min_gb, min(free_gb * safety_fraction, max_gb))
def sparse_submanifold_conv3d(
feats: torch.Tensor,
@ -179,24 +133,16 @@ def sparse_submanifold_conv3d(
bias: Optional[torch.Tensor],
neighbor_cache: Optional[torch.Tensor],
dilation: tuple,
max_chunk_mem_gb: float = 6.0,
accumulate_f32: bool = True,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
if feats.shape[0] == 0:
Co = weight.shape[0]
return torch.empty((0, Co), device=feats.device, dtype=feats.dtype), None
if len(shape) == 5:
_, _, W, H, D = shape
else:
W, H, D = shape
W, H, D = shape
Co, Kw, Kh, Kd, Ci = weight.shape
V = Kw * Kh * Kd
device = feats.device
sentinel = -1
max_chunk_mem_gb = get_recommended_chunk_mem(device)
if neighbor_cache is None:
b_stride = W * H * D
@ -219,91 +165,37 @@ def sparse_submanifold_conv3d(
neighbor = neighbor_cache
N_pts = feats.shape[0]
sentinel = -1
if accumulate_f32:
weight_T = weight.view(Co, V * Ci).to(torch.float32).T.contiguous()
else:
weight_T = weight.view(Co, V * Ci).to(feats.dtype).T.contiguous()
weight_T = weight.view(Co, V * Ci).T
output = torch.empty(N_pts, Co, device=device, dtype=feats.dtype)
# ------------------------------------------------------------------
# Chunk size from memory budget
# ------------------------------------------------------------------
bytes_per_elem = 4 if accumulate_f32 else feats.element_size()
mem_per_row = V * Ci * bytes_per_elem
# Chunk size from memory budget. The dominant peak is `gathered`, of shape (chunk, V, Ci) in feats.dtype.
max_chunk_mem_gb = get_recommended_chunk_mem(device)
mem_per_row = V * Ci * feats.element_size()
max_chunk_mem = max_chunk_mem_gb * (1024 ** 3)
chunk_size = max(1, int(max_chunk_mem / mem_per_row))
chunk_size = min(chunk_size, N_pts)
# fp32 matmul scratch — sized to the largest chunk, reused each iteration.
chunk_buf = torch.empty(chunk_size, Co, device=device, dtype=torch.float32) if accumulate_f32 else None
# ------------------------------------------------------------------
# Chunked forward pass
# Each iteration:
# 1. gather (chunk, V, Ci) memory bound
# 2. mask zero invalids in-place, no extra alloc
# 3. reshape (chunk, V*Ci)
# 4. GEMM (chunk, V*Ci) @ (V*Ci, Co) → (chunk, Co) cuBLAS
# written into the scratch buf (fp32) or output slice (fp16) via out=
# 5. (fp32 path) cast scratch chunk to fp16 and copy into output slice
# ------------------------------------------------------------------
for start in range(0, N_pts, chunk_size):
end = min(start + chunk_size, N_pts)
actual_chunk = end - start
# (chunk, V) int32
chunk_neighbor = neighbor[start:end]
chunk_valid = chunk_neighbor != sentinel
# clamp(-1 -> 0) keeps invalid indices in-range so the gather is safe
chunk_idx = chunk_neighbor.clamp(min=0)
# Clamp sentinel -1 → 0 for safe indexing. No clone of the full map.
chunk_idx = chunk_neighbor.clamp(min=0).long()
# Gather: (chunk, V, Ci). Memory-bound, single index_select.
# (chunk, V, Ci) gather, then in-place zero of invalid neighbors.
gathered = feats[chunk_idx]
# Zero invalid neighbours in-place. gathered is a fresh tensor from
# advanced indexing, so in-place mutation is safe.
gathered.mul_(chunk_valid.unsqueeze(-1))
# Reshape to (chunk, V*Ci)
# GEMM (chunk, V*Ci) @ (V*Ci, Co) -> (chunk, Co), written to output[start:end].
gathered_flat = gathered.view(actual_chunk, V * Ci)
if accumulate_f32:
gathered_flat = gathered_flat.to(torch.float32)
torch.matmul(gathered_flat, weight_T, out=chunk_buf[:actual_chunk])
output[start:end] = chunk_buf[:actual_chunk].to(feats.dtype)
else:
torch.matmul(gathered_flat, weight_T, out=output[start:end])
torch.matmul(gathered_flat, weight_T, out=output[start:end])
if bias is not None:
output += bias.unsqueeze(0).to(output.dtype)
return output, neighbor
class Mesh:
def __init__(self,
vertices,
faces,
vertex_attrs=None
):
self.vertices = vertices.float()
self.faces = faces.int()
self.vertex_attrs = vertex_attrs
@property
def device(self):
return self.vertices.device
def to(self, device, non_blocking=False):
return Mesh(
self.vertices.to(device, non_blocking=non_blocking),
self.faces.to(device, non_blocking=non_blocking),
self.vertex_attrs.to(device, non_blocking=non_blocking) if self.vertex_attrs is not None else None,
)
def cuda(self, non_blocking=False):
return self.to('cuda', non_blocking=non_blocking)
def cpu(self):
return self.to('cpu')

View File

@ -1,12 +1,13 @@
import math
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from fractions import Fraction
from dataclasses import dataclass
from typing import List, Any, Dict, Optional, overload, Union, Tuple
from comfy.ldm.trellis2.flexgemm import TorchHashMap, Mesh, sparse_submanifold_conv3d
from typing import List, Any, Dict, Optional, overload, Union
import comfy.ops
from comfy.ldm.trellis2.flexgemm import TorchHashMap, sparse_submanifold_conv3d
ops = comfy.ops.disable_weight_init
def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor:
@ -73,10 +74,10 @@ def sparse_conv3d_forward(self, x):
out = x.replace(out)
return out
class LayerNorm32(nn.LayerNorm):
class LayerNorm32(ops.LayerNorm):
def forward(self, x: torch.Tensor) -> torch.Tensor:
w = self.weight.to(x.dtype) if self.weight is not None else None
b = self.bias.to(x.dtype) if self.bias is not None else None
w = self.weight.to(x) if self.weight is not None else None
b = self.bias.to(x) if self.bias is not None else None
return F.layer_norm(x, self.normalized_shape, w, b, self.eps)
class SparseConvNeXtBlock3d(nn.Module):
@ -93,12 +94,13 @@ class SparseConvNeXtBlock3d(nn.Module):
self.norm = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
self.conv = SparseConv3d(channels, channels, 3)
self.mlp = nn.Sequential(
nn.Linear(channels, int(channels * mlp_ratio)),
ops.Linear(channels, int(channels * mlp_ratio)),
nn.SiLU(inplace=True),
nn.Linear(int(channels * mlp_ratio), channels),
ops.Linear(int(channels * mlp_ratio), channels),
)
def _forward(self, x):
x = x.to(dtype=self.conv.weight.dtype, device=self.conv.weight.device)
h = self.conv(x)
h = h.replace(self.norm(h.feats))
h = h.replace(self.mlp(h.feats))
@ -141,7 +143,7 @@ class SparseSpatial2Channel(nn.Module):
out = SparseTensor(new_feats.reshape(new_coords.shape[0], -1), new_coords, None if x._shape is None else torch.Size([x._shape[0], x._shape[1] * self.factor ** DIM]))
out._scale = tuple([s * self.factor for s in x._scale])
out._spatial_cache = x._spatial_cache
out._spatial_cache = dict(x._spatial_cache)
if cache is None:
x.register_spatial_cache(f'spatial2channel_{self.factor}', (new_coords, idx, subidx))
@ -180,7 +182,7 @@ class SparseChannel2Spatial(nn.Module):
out = SparseTensor(new_feats, new_coords, None if x._shape is None else torch.Size([x._shape[0], x._shape[1] // self.factor ** DIM]))
out._scale = tuple([s / self.factor for s in x._scale])
if cache is not None: # only keep cache when subdiv following it
out._spatial_cache = x._spatial_cache
out._spatial_cache = dict(x._spatial_cache)
return out
class SparseResBlockC2S3d(nn.Module):
@ -226,10 +228,6 @@ class SparseResBlockC2S3d(nn.Module):
else:
return h
@dataclass
class config:
CONV = "flexgemm"
FLEX_GEMM_HASHMAP_RATIO = 2.0
class VarLenTensor:
@ -238,18 +236,6 @@ class VarLenTensor:
self.layout = layout if layout is not None else [slice(0, feats.shape[0])]
self._cache = {}
@staticmethod
def layout_from_seqlen(seqlen: list) -> List[slice]:
"""
Create a layout from a tensor of sequence lengths.
"""
layout = []
start = 0
for l in seqlen:
layout.append(slice(start, start + l))
start += l
return layout
@staticmethod
def from_tensor_list(tensor_list: List[torch.Tensor]) -> 'VarLenTensor':
"""
@ -376,29 +362,22 @@ class VarLenTensor:
feats=feats,
layout=self.layout,
)
new_tensor._cache = self._cache
# Shallow-copy so derived tensors don't share-by-reference the cache
# dict — see SparseTensor.replace for rationale.
new_tensor._cache = dict(self._cache)
return new_tensor
def to_dense(self, max_length=None) -> torch.Tensor:
N = len(self)
L = max_length or self.seqlen.max().item()
spatial = self.feats.shape[1:]
idx = torch.arange(L, device=self.device).unsqueeze(0).expand(N, L)
mask = (idx < self.seqlen.unsqueeze(1))
mapping = mask.reshape(-1).cumsum(dim=0) - 1
dense = self.feats[mapping]
dense = dense.reshape(N, L, *spatial)
return dense, mask
def __neg__(self) -> 'VarLenTensor':
return self.replace(-self.feats)
def __elemwise__(self, other: Union[torch.Tensor, 'VarLenTensor'], op: callable) -> 'VarLenTensor':
if isinstance(other, torch.Tensor):
# Try per-batch [B, C] -> per-token [T, C] broadcast. RuntimeError
# fires for incompatible shapes; fall through and let op() handle.
try:
other = torch.broadcast_to(other, self.shape)
other = other[self.batch_boardcast_map]
except:
except RuntimeError:
pass
if isinstance(other, VarLenTensor):
other = other.feats
@ -459,40 +438,6 @@ class VarLenTensor:
new_tensor = VarLenTensor(feats=new_feats, layout=new_layout)
return new_tensor
def reduce(self, op: str, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor:
if isinstance(dim, int):
dim = (dim,)
if op =='mean':
red = self.feats.mean(dim=dim, keepdim=keepdim)
elif op =='sum':
red = self.feats.sum(dim=dim, keepdim=keepdim)
elif op == 'prod':
red = self.feats.prod(dim=dim, keepdim=keepdim)
else:
raise ValueError(f"Unsupported reduce operation: {op}")
if dim is None or 0 in dim:
return red
red = torch.segment_reduce(red, reduce=op, lengths=self.seqlen)
return red
def mean(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor:
return self.reduce(op='mean', dim=dim, keepdim=keepdim)
def sum(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor:
return self.reduce(op='sum', dim=dim, keepdim=keepdim)
def prod(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor:
return self.reduce(op='prod', dim=dim, keepdim=keepdim)
def std(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor:
mean = self.mean(dim=dim, keepdim=True)
mean2 = self.replace(self.feats ** 2).mean(dim=dim, keepdim=True)
std = (mean2 - mean ** 2).sqrt()
return std
def __repr__(self) -> str:
return f"VarLenTensor(shape={self.shape}, dtype={self.dtype}, device={self.device})"
@ -507,8 +452,6 @@ def varlen_unbind(input: VarLenTensor, dim: int) -> Union[List[VarLenTensor]]:
class SparseTensor(VarLenTensor):
SparseTensorData = None
@overload
def __init__(self, feats: torch.Tensor, coords: torch.Tensor, shape: Optional[torch.Size] = None, **kwargs): ...
@ -516,14 +459,6 @@ class SparseTensor(VarLenTensor):
def __init__(self, data, shape: Optional[torch.Size] = None, **kwargs): ...
def __init__(self, *args, **kwargs):
# Lazy import of sparse tensor backend
if self.SparseTensorData is None:
import importlib
if config.CONV == 'torchsparse':
self.SparseTensorData = importlib.import_module('torchsparse').SparseTensor
elif config.CONV == 'spconv':
self.SparseTensorData = importlib.import_module('spconv.pytorch').SparseConvTensor
method_id = 0
if len(args) != 0:
method_id = 0 if isinstance(args[0], torch.Tensor) else 1
@ -542,17 +477,10 @@ class SparseTensor(VarLenTensor):
shape = kwargs['shape']
del kwargs['shape']
if config.CONV == 'torchsparse':
self.data = self.SparseTensorData(feats, coords, **kwargs)
elif config.CONV == 'spconv':
spatial_shape = list(coords.max(0)[0] + 1)
self.data = self.SparseTensorData(feats.reshape(feats.shape[0], -1), coords, spatial_shape[1:], spatial_shape[0], **kwargs)
self.data._features = feats
else:
self.data = {
'feats': feats,
'coords': coords,
}
self.data = {
'feats': feats,
'coords': coords,
}
elif method_id == 1:
data, shape = args + (None,) * (2 - len(args))
if 'data' in kwargs:
@ -581,17 +509,6 @@ class SparseTensor(VarLenTensor):
coords = torch.cat(coords, dim=0)
return SparseTensor(feats, coords)
def to_tensor_list(self) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
"""
Convert a SparseTensor to list of tensors.
"""
feats_list = []
coords_list = []
for s in self.layout:
feats_list.append(self.feats[s])
coords_list.append(self.coords[s])
return feats_list, coords_list
def __len__(self) -> int:
return len(self.layout)
@ -634,39 +551,19 @@ class SparseTensor(VarLenTensor):
@property
def feats(self) -> torch.Tensor:
if config.CONV == 'torchsparse':
return self.data.F
elif config.CONV == 'spconv':
return self.data.features
else:
return self.data['feats']
return self.data['feats']
@feats.setter
def feats(self, value: torch.Tensor):
if config.CONV == 'torchsparse':
self.data.F = value
elif config.CONV == 'spconv':
self.data.features = value
else:
self.data['feats'] = value
self.data['feats'] = value
@property
def coords(self) -> torch.Tensor:
if config.CONV == 'torchsparse':
return self.data.C
elif config.CONV == 'spconv':
return self.data.indices
else:
return self.data['coords']
return self.data['coords']
@coords.setter
def coords(self, value: torch.Tensor):
if config.CONV == 'torchsparse':
self.data.C = value
elif config.CONV == 'spconv':
self.data.indices = value
else:
self.data['coords'] = value
self.data['coords'] = value
@property
def dtype(self):
@ -773,71 +670,20 @@ class SparseTensor(VarLenTensor):
return sparse_unbind(self, dim)
def replace(self, feats: torch.Tensor, coords: Optional[torch.Tensor] = None) -> 'SparseTensor':
if config.CONV == 'torchsparse':
new_data = self.SparseTensorData(
feats=feats,
coords=self.data.coords if coords is None else coords,
stride=self.data.stride,
spatial_range=self.data.spatial_range,
)
new_data._caches = self.data._caches
elif config.CONV == 'spconv':
new_data = self.SparseTensorData(
self.data.features.reshape(self.data.features.shape[0], -1),
self.data.indices,
self.data.spatial_shape,
self.data.batch_size,
self.data.grid,
self.data.voxel_num,
self.data.indice_dict
)
new_data._features = feats
new_data.benchmark = self.data.benchmark
new_data.benchmark_record = self.data.benchmark_record
new_data.thrust_allocator = self.data.thrust_allocator
new_data._timer = self.data._timer
new_data.force_algo = self.data.force_algo
new_data.int8_scale = self.data.int8_scale
if coords is not None:
new_data.indices = coords
else:
new_data = {
'feats': feats,
'coords': self.data['coords'] if coords is None else coords,
}
new_tensor = SparseTensor(
new_data = {
'feats': feats,
'coords': self.data['coords'] if coords is None else coords,
}
return SparseTensor(
new_data,
shape=torch.Size([self._shape[0]] + list(feats.shape[1:])) if self._shape is not None else None,
scale=self._scale,
spatial_cache=self._spatial_cache
# Shallow-copy the cache: each derived tensor gets its own dict, so
# adding/overwriting entries on one doesn't leak to siblings.
# Cached tensors themselves are still shared by reference (safe
# because they're read-only after populate).
spatial_cache=dict(self._spatial_cache),
)
return new_tensor
def to_dense(self) -> torch.Tensor:
if config.CONV == 'torchsparse':
return self.data.dense()
elif config.CONV == 'spconv':
return self.data.dense()
else:
spatial_shape = self.spatial_shape
ret = torch.zeros(*self.shape, *spatial_shape, dtype=self.dtype, device=self.device)
idx = [self.coords[:, 0], slice(None)] + self.coords[:, 1:].unbind(1)
ret[tuple(idx)] = self.feats
return ret
@staticmethod
def full(aabb, dim, value, dtype=torch.float32, device=None) -> 'SparseTensor':
N, C = dim
x = torch.arange(aabb[0], aabb[3] + 1)
y = torch.arange(aabb[1], aabb[4] + 1)
z = torch.arange(aabb[2], aabb[5] + 1)
coords = torch.stack(torch.meshgrid(x, y, z, indexing='ij'), dim=-1).reshape(-1, 3)
coords = torch.cat([
torch.arange(N).view(-1, 1).repeat(1, coords.shape[0]).view(-1, 1),
coords.repeat(N, 1),
], dim=1).to(dtype=torch.int32, device=device)
feats = torch.full((coords.shape[0], C), value, dtype=dtype, device=device)
return SparseTensor(feats=feats, coords=coords)
def __merge_sparse_cache(self, other: 'SparseTensor') -> dict:
new_cache = {}
@ -853,10 +699,12 @@ class SparseTensor(VarLenTensor):
def __elemwise__(self, other: Union[torch.Tensor, VarLenTensor], op: callable) -> 'SparseTensor':
if isinstance(other, torch.Tensor):
# Try per-batch [B, C] -> per-voxel [N, C] broadcast. RuntimeError
# fires for incompatible shapes; fall through and let op() handle.
try:
other = torch.broadcast_to(other, self.shape)
other = other[self.batch_boardcast_map]
except:
except RuntimeError:
pass
if isinstance(other, VarLenTensor):
other = other.feats
@ -901,12 +749,6 @@ class SparseTensor(VarLenTensor):
new_tensor.register_spatial_cache('layout', new_layout)
return new_tensor
def clear_spatial_cache(self) -> None:
"""
Clear all spatial caches.
"""
self._spatial_cache = {}
def register_spatial_cache(self, key, value) -> None:
"""
Register a spatial cache.
@ -961,7 +803,7 @@ def sparse_unbind(input: SparseTensor, dim: int) -> List[SparseTensor]:
# allow operations.Linear inheritance
class SparseLinear:
def __new__(cls, in_features, out_features, bias=True, device=None, dtype=None, operations=nn, *args, **kwargs):
def __new__(cls, in_features, out_features, bias=True, device=None, dtype=None, operations=ops, *args, **kwargs):
class _SparseLinear(operations.Linear):
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
@ -971,24 +813,6 @@ class SparseLinear:
return _SparseLinear(in_features, out_features, bias=bias, device=device, dtype=dtype, *args, **kwargs)
MIX_PRECISION_MODULES = (
nn.Conv1d,
nn.Conv2d,
nn.Conv3d,
nn.ConvTranspose1d,
nn.ConvTranspose2d,
nn.ConvTranspose3d,
nn.Linear,
SparseConv3d,
SparseLinear,
)
def convert_module_to_f16(l):
if isinstance(l, MIX_PRECISION_MODULES):
for p in l.parameters():
p.data = p.data.half()
class SparseUnetVaeDecoder(nn.Module):
def __init__(
self,
@ -999,17 +823,13 @@ class SparseUnetVaeDecoder(nn.Module):
block_type: List[str],
up_block_type: List[str],
block_args: List[Dict[str, Any]],
use_fp16: bool = False,
pred_subdiv: bool = True,
):
super().__init__()
self.out_channels = out_channels
self.model_channels = model_channels
self.num_blocks = num_blocks
self.use_fp16 = use_fp16
self.pred_subdiv = pred_subdiv
self.dtype = torch.float16 if use_fp16 else torch.float32
self.low_vram = False
self.output_layer = SparseLinear(model_channels[-1], out_channels)
self.from_latent = SparseLinear(latent_channels, model_channels[0])
@ -1033,17 +853,9 @@ class SparseUnetVaeDecoder(nn.Module):
**block_args[i],
)
)
@property
def device(self) -> torch.device:
return next(self.parameters()).device
def forward(self, x: SparseTensor, guide_subs: Optional[List[SparseTensor]] = None, return_subs: bool = False) -> SparseTensor:
dtype = next(self.from_latent.parameters()).dtype
device = next(self.from_latent.parameters()).device
x.feats = x.feats.to(dtype).to(device)
h = self.from_latent(x)
h = h.type(self.dtype)
subs = []
for i, res in enumerate(self.blocks):
for j, block in enumerate(res):
@ -1055,7 +867,6 @@ class SparseUnetVaeDecoder(nn.Module):
h = block(h, subdiv=guide_subs[i] if guide_subs is not None else None)
else:
h = block(h)
h = h.type(x.feats.dtype)
h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
h = self.output_layer(h)
if return_subs:
@ -1064,9 +875,7 @@ class SparseUnetVaeDecoder(nn.Module):
return h
def upsample(self, x: SparseTensor, upsample_times: int) -> torch.Tensor:
h = self.from_latent(x)
h = h.type(self.dtype)
for i, res in enumerate(self.blocks):
if i == upsample_times:
return h.coords
@ -1087,13 +896,9 @@ class FlexiDualGridVaeDecoder(SparseUnetVaeDecoder):
up_block_type: List[str],
block_args: List[Dict[str, Any]],
voxel_margin: float = 0.5,
use_fp16: bool = False,
):
self.resolution = resolution
self.voxel_margin = voxel_margin
# cache for a TorchHashMap instance
self._torch_hashmap_cache = None
super().__init__(
7,
model_channels,
@ -1102,22 +907,11 @@ class FlexiDualGridVaeDecoder(SparseUnetVaeDecoder):
block_type,
up_block_type,
block_args,
use_fp16,
)
def set_resolution(self, resolution: int) -> None:
self.resolution = resolution
def _build_or_get_hashmap(self, coords: torch.Tensor, grid_size: torch.Tensor):
device = coords.device
N = coords.shape[0]
_, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item())
flat_keys = coords[:, 0].long() * (H * D)
flat_keys.add_(coords[:, 1].long() * D)
flat_keys.add_(coords[:, 2].long())
values = torch.arange(N, dtype=torch.int32, device=device)
return TorchHashMap(flat_keys, values, 0xffffffff)
def forward(self, x: SparseTensor, gt_intersected: SparseTensor = None, **kwargs):
decoded = super().forward(x, **kwargs)
out_list = list(decoded) if isinstance(decoded, tuple) else [decoded]
@ -1125,13 +919,11 @@ class FlexiDualGridVaeDecoder(SparseUnetVaeDecoder):
vertices = h.replace((1 + 2 * self.voxel_margin) * F.sigmoid(h.feats[..., 0:3]) - self.voxel_margin)
intersected = h.replace(h.feats[..., 3:6] > 0)
quad_lerp = h.replace(F.softplus(h.feats[..., 6:7]))
mesh = [Mesh(*flexible_dual_grid_to_mesh(
mesh = [flexible_dual_grid_to_mesh(
v.coords[:, 1:], v.feats, i.feats, q.feats,
aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
grid_size=self.resolution,
train=False,
hashmap_builder=self._build_or_get_hashmap,
)) for v, i, q in zip(vertices, intersected, quad_lerp)]
) for v, i, q in zip(vertices, intersected, quad_lerp)]
out_list[0] = mesh
return out_list[0] if len(out_list) == 1 else tuple(out_list)
@ -1140,11 +932,9 @@ def flexible_dual_grid_to_mesh(
dual_vertices: torch.Tensor,
intersected_flag: torch.Tensor,
split_weight: Union[torch.Tensor, None],
aabb: Union[list, tuple, np.ndarray, torch.Tensor],
voxel_size: Union[float, list, tuple, np.ndarray, torch.Tensor] = None,
grid_size: Union[int, list, tuple, np.ndarray, torch.Tensor] = None,
train: bool = False,
hashmap_builder=None, # optional callable for building/caching a TorchHashMap
aabb: Union[list, tuple, torch.Tensor],
voxel_size: Union[float, list, tuple, torch.Tensor] = None,
grid_size: Union[int, list, tuple, torch.Tensor] = None,
):
device = coords.device
@ -1159,46 +949,28 @@ def flexible_dual_grid_to_mesh(
flexible_dual_grid_to_mesh.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=device, requires_grad=False)
if not hasattr(flexible_dual_grid_to_mesh, "quad_split_2") or flexible_dual_grid_to_mesh.quad_split_2.device != device:
flexible_dual_grid_to_mesh.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=device, requires_grad=False)
if not hasattr(flexible_dual_grid_to_mesh, "quad_split_train") or flexible_dual_grid_to_mesh.quad_split_train.device != device:
flexible_dual_grid_to_mesh.quad_split_train = torch.tensor([0, 1, 4, 1, 2, 4, 2, 3, 4, 3, 0, 4], dtype=torch.long, device=device, requires_grad=False)
# AABB
if isinstance(aabb, (list, tuple)):
aabb = np.array(aabb)
if isinstance(aabb, np.ndarray):
aabb = torch.tensor(aabb, dtype=torch.float32, device=device)
aabb = torch.tensor(aabb, dtype=torch.float32, device=device)
# Voxel size
if voxel_size is not None:
if isinstance(voxel_size, float):
voxel_size = [voxel_size, voxel_size, voxel_size]
if isinstance(voxel_size, (list, tuple)):
voxel_size = np.array(voxel_size)
if isinstance(voxel_size, np.ndarray):
voxel_size = torch.tensor(voxel_size, dtype=torch.float32, device=coords.device)
voxel_size = [voxel_size] * 3
voxel_size = torch.tensor(voxel_size, dtype=torch.float32, device=device)
grid_size = ((aabb[1] - aabb[0]) / voxel_size).round().int()
else:
if isinstance(grid_size, int):
grid_size = [grid_size, grid_size, grid_size]
if isinstance(grid_size, (list, tuple)):
grid_size = np.array(grid_size)
if isinstance(grid_size, np.ndarray):
grid_size = torch.tensor(grid_size, dtype=torch.int32, device=coords.device)
grid_size = [grid_size] * 3
grid_size = torch.tensor(grid_size, dtype=torch.int32, device=device)
voxel_size = (aabb[1] - aabb[0]) / grid_size
# Extract mesh
N = dual_vertices.shape[0]
if hashmap_builder is None:
device = coords.device
_, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item())
flat_keys = coords[:, 0].long() * (H * D)
flat_keys.add_(coords[:, 1].long() * D)
flat_keys.add_(coords[:, 2].long())
values = torch.arange(N, dtype=torch.long, device=device)
torch_hashmap = TorchHashMap(flat_keys, values, 0xffffffff)
else:
torch_hashmap = hashmap_builder(coords, grid_size)
_, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item())
flat_keys = coords[:, 0].long() * (H * D)
flat_keys.add_(coords[:, 1].long() * D)
flat_keys.add_(coords[:, 2].long())
values = torch.arange(N, dtype=torch.int32, device=device)
torch_hashmap = TorchHashMap(flat_keys, values, 0xffffffff)
# Find connected voxels — direct gather instead of materializing the full [N, 3, 4, 3]
n_idx, axis_idx = intersected_flag.nonzero(as_tuple=True) # (M,), (M,)
@ -1253,55 +1025,34 @@ class ChannelLayerNorm32(LayerNorm32):
return x
class UpsampleBlock3d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
mode = "conv",
):
assert mode in ["conv", "nearest"], f"Invalid mode {mode}"
def __init__(self, in_channels: int, out_channels: int):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
if mode == "conv":
self.conv = nn.Conv3d(in_channels, out_channels*8, 3, padding=1)
elif mode == "nearest":
assert in_channels == out_channels, "Nearest mode requires in_channels to be equal to out_channels"
self.conv = ops.Conv3d(in_channels, out_channels * 8, 3, padding=1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if hasattr(self, "conv"):
x = self.conv(x)
return pixel_shuffle_3d(x, 2)
else:
return F.interpolate(x, scale_factor=2, mode="nearest")
def norm_layer(norm_type: str, *args, **kwargs) -> nn.Module:
return ChannelLayerNorm32(*args, **kwargs)
return pixel_shuffle_3d(self.conv(x), 2)
class ResBlock3d(nn.Module):
def __init__(
self,
channels: int,
out_channels: Optional[int] = None,
norm_type = "layer",
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.norm1 = norm_layer(norm_type, channels)
self.norm2 = norm_layer(norm_type, self.out_channels)
self.conv1 = nn.Conv3d(channels, self.out_channels, 3, padding=1)
self.conv2 = nn.Conv3d(self.out_channels, self.out_channels, 3, padding=1)
self.skip_connection = nn.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity()
self.norm1 = ChannelLayerNorm32(channels)
self.norm2 = ChannelLayerNorm32(self.out_channels)
self.conv1 = ops.Conv3d(channels, self.out_channels, 3, padding=1)
self.conv2 = ops.Conv3d(self.out_channels, self.out_channels, 3, padding=1)
self.skip_connection = ops.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity()
def forward(self, x: torch.Tensor) -> torch.Tensor:
h = self.norm1(x)
h = F.silu(h)
dtype = next(self.conv1.parameters()).dtype
h = h.to(dtype)
h = self.conv1(h)
h = self.norm2(h)
h = F.silu(h)
@ -1318,8 +1069,6 @@ class SparseStructureDecoder(nn.Module):
num_res_blocks: int,
channels: List[int],
num_res_blocks_middle: int = 2,
norm_type = "layer",
use_fp16: bool = True,
):
super().__init__()
self.out_channels = out_channels
@ -1327,11 +1076,8 @@ class SparseStructureDecoder(nn.Module):
self.num_res_blocks = num_res_blocks
self.channels = channels
self.num_res_blocks_middle = num_res_blocks_middle
self.norm_type = norm_type
self.use_fp16 = use_fp16
self.dtype = torch.float16 if use_fp16 else torch.float32
self.input_layer = nn.Conv3d(latent_channels, channels[0], 3, padding=1)
self.input_layer = ops.Conv3d(latent_channels, channels[0], 3, padding=1)
self.middle_block = nn.Sequential(*[
ResBlock3d(channels[0], channels[0])
@ -1350,92 +1096,82 @@ class SparseStructureDecoder(nn.Module):
)
self.out_layer = nn.Sequential(
norm_layer(norm_type, channels[-1]),
ChannelLayerNorm32(channels[-1]),
nn.SiLU(),
nn.Conv3d(channels[-1], out_channels, 3, padding=1)
ops.Conv3d(channels[-1], out_channels, 3, padding=1)
)
if use_fp16:
self.convert_to_fp16()
def device(self) -> torch.device:
return next(self.parameters()).device
def convert_to_fp16(self) -> None:
self.use_fp16 = True
self.dtype = torch.float16
self.blocks.apply(convert_module_to_f16)
self.middle_block.apply(convert_module_to_f16)
def forward(self, x: torch.Tensor) -> torch.Tensor:
dtype = next(self.input_layer.parameters()).dtype
x = x.to(dtype)
h = self.input_layer(x)
h = h.type(self.dtype)
h = self.middle_block(h)
for block in self.blocks:
h = block(h)
h = h.type(x.dtype)
h = self.out_layer(h)
return h
class Vae(nn.Module):
def __init__(self, init_txt_model, init_txt_model_only, operations=None):
class ShapeVae(nn.Module):
"""Decoder bundle from the Trellis2 shape checkpoint: structure decoder
(32^3 SS latent 64^3 dense occupancy) and shape decoder (sparse latent
mesh + per-stage subdivisions)."""
def __init__(self):
super().__init__()
operations = operations or torch.nn
if init_txt_model or init_txt_model_only:
self.txt_dec = SparseUnetVaeDecoder(
out_channels=6,
model_channels=[1024, 512, 256, 128, 64],
latent_channels=32,
num_blocks=[4, 16, 8, 4, 0],
block_type=["SparseConvNeXtBlock3d"] * 5,
up_block_type=["SparseResBlockC2S3d"] * 4,
block_args=[{}, {}, {}, {}, {}],
pred_subdiv=False
)
if not init_txt_model_only:
self.shape_dec = FlexiDualGridVaeDecoder(
resolution=256,
model_channels=[1024, 512, 256, 128, 64],
latent_channels=32,
num_blocks=[4, 16, 8, 4, 0],
block_type=["SparseConvNeXtBlock3d"] * 5,
up_block_type=["SparseResBlockC2S3d"] * 4,
block_args=[{}, {}, {}, {}, {}],
)
self.struct_dec = SparseStructureDecoder(
out_channels=1,
latent_channels=8,
num_res_blocks=2,
num_res_blocks_middle=2,
channels=[512, 128, 32],
)
self.shape_dec = FlexiDualGridVaeDecoder(
resolution=256,
model_channels=[1024, 512, 256, 128, 64],
latent_channels=32,
num_blocks=[4, 16, 8, 4, 0],
block_type=["SparseConvNeXtBlock3d"] * 5,
up_block_type=["SparseResBlockC2S3d"] * 4,
block_args=[{}, {}, {}, {}, {}],
)
self.struct_dec = SparseStructureDecoder(
out_channels=1,
latent_channels=8,
num_res_blocks=2,
num_res_blocks_middle=2,
channels=[512, 128, 32],
)
self.register_buffer("resolution", torch.tensor(1024.0), persistent=False)
@torch.no_grad()
def decode_shape_slat(self, slat, resolution: int):
def decode_structure(self, x: torch.Tensor) -> torch.Tensor:
weight = self.struct_dec.input_layer.weight
x = x.to(dtype=weight.dtype, device=weight.device)
return self.struct_dec(x)
def decode_shape_slat(self, slat: 'SparseTensor', resolution: int):
weight = self.shape_dec.from_latent.weight
slat = slat.to(dtype=weight.dtype, device=weight.device)
self.shape_dec.set_resolution(resolution)
return self.shape_dec(slat, return_subs=True)
@torch.no_grad()
def decode_tex_slat(self, slat, subs):
if self.txt_dec is None:
raise ValueError("Checkpoint doesn't include texture model")
def upsample_shape(self, slat: 'SparseTensor', upsample_times: int) -> torch.Tensor:
weight = self.shape_dec.from_latent.weight
slat = slat.to(dtype=weight.dtype, device=weight.device)
return self.shape_dec.upsample(slat, upsample_times)
class TextureVae(nn.Module):
"""Decoder bundle from the Trellis2 texture checkpoint: sparse 3D
per-voxel color decoder, guided by subdivisions from a prior shape decode."""
def __init__(self):
super().__init__()
self.txt_dec = SparseUnetVaeDecoder(
out_channels=6,
model_channels=[1024, 512, 256, 128, 64],
latent_channels=32,
num_blocks=[4, 16, 8, 4, 0],
block_type=["SparseConvNeXtBlock3d"] * 5,
up_block_type=["SparseResBlockC2S3d"] * 4,
block_args=[{}, {}, {}, {}, {}],
pred_subdiv=False,
)
self.register_buffer("resolution", torch.tensor(1024.0), persistent=False)
def decode_tex_slat(self, slat: 'SparseTensor', subs):
weight = self.txt_dec.from_latent.weight
slat = slat.to(dtype=weight.dtype, device=weight.device)
return self.txt_dec(slat, guide_subs=subs) * 0.5 + 0.5
# shouldn't be called (placeholder)
@torch.no_grad()
def decode(
self,
shape_slat: SparseTensor,
tex_slat: SparseTensor,
resolution: int,
):
meshes, subs = self.decode_shape_slat(shape_slat, resolution)
tex_voxels = self.decode_tex_slat(tex_slat, subs)
return tex_voxels

View File

@ -529,18 +529,18 @@ class VAE:
self.first_stage_model = StageC_coder()
self.downscale_ratio = 32
self.latent_channels = 16
elif "shape_dec.blocks.1.16.to_subdiv.weight" in sd or "txt_dec.blocks.3.4.conv2.weight" in sd: # trellis2 or trellis2 texture only
init_txt_model = False
init_txt_model_only = False
if "shape_dec.blocks.1.16.to_subdiv.weight" not in sd:
init_txt_model_only = True
if "txt_dec.blocks.1.16.norm1.weight" in sd:
init_txt_model = True
elif "shape_dec.blocks.1.16.to_subdiv.weight" in sd: # trellis2 shape vae (struct_dec + shape_dec)
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
# TODO
self.memory_used_decode = lambda shape, dtype: (2500 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (2500 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
self.first_stage_model = comfy.ldm.trellis2.vae.Vae(init_txt_model, init_txt_model_only= init_txt_model_only)
self.first_stage_model = comfy.ldm.trellis2.vae.ShapeVae()
elif "txt_dec.blocks.3.4.conv2.weight" in sd: # trellis2 texture vae
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
# TODO
self.memory_used_decode = lambda shape, dtype: (2500 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (2500 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
self.first_stage_model = comfy.ldm.trellis2.vae.TextureVae()
elif "decoder.conv_in.weight" in sd:
if sd['decoder.conv_in.weight'].shape[1] == 64:
ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}

View File

@ -205,8 +205,8 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
coords_list = [stage_tensor.coords for stage_tensor in stage_tensors]
subs.append(SparseTensor.from_tensor_list(feats_list, coords_list))
face_list = [m.faces for m in mesh]
vert_list = [m.vertices for m in mesh]
vert_list = [v.float() for v, f in mesh]
face_list = [f.int() for v, f in mesh]
if all(v.shape == vert_list[0].shape for v in vert_list) and all(f.shape == face_list[0].shape for f in face_list):
mesh = Types.MESH(vertices=torch.stack(vert_list), faces=torch.stack(face_list))
else:
@ -286,12 +286,12 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode):
sample_tensor = samples["samples"]
sample_tensor = sample_tensor[:, :8]
batch_number = prepare_trellis_vae_for_decode(vae, sample_tensor.shape)
decoder = vae.first_stage_model.struct_dec
shape_vae = vae.first_stage_model
load_device = comfy.model_management.get_torch_device()
decoded_batches = []
for start in range(0, sample_tensor.shape[0], batch_number):
sample_chunk = sample_tensor[start:start + batch_number].to(load_device)
decoded_batches.append(decoder(sample_chunk) > 0)
decoded_batches.append(shape_vae.decode_structure(sample_chunk) > 0)
decoded = torch.cat(decoded_batches, dim=0)
current_res = decoded.shape[2]
@ -349,10 +349,9 @@ class Trellis2UpsampleStage(IO.ComfyNode):
prepare_trellis_vae_for_decode(vae, shape_latent["samples"].shape)
coord_counts = shape_latent.get("coord_counts")
decoder = vae.first_stage_model.shape_dec
shape_vae = vae.first_stage_model
lr_resolution = 512
target_resolution = int(target_resolution)
decoder_dtype = next(decoder.parameters()).dtype
# Decode each sample's HR coords, then search for the largest hr_resolution
# that fits under max_tokens across all samples.
@ -361,8 +360,7 @@ class Trellis2UpsampleStage(IO.ComfyNode):
shape_latent["samples"], shape_latent["coords"], coord_counts,
)
slat = shape_norm(feats.to(device), coords_512.to(device))
slat.feats = slat.feats.to(decoder_dtype)
sample_hr_coords = [decoder.upsample(slat, upsample_times=4)]
sample_hr_coords = [shape_vae.upsample_shape(slat, upsample_times=4)]
else:
items = split_batched_sparse_latent(
shape_latent["samples"], shape_latent["coords"], coord_counts,
@ -372,8 +370,7 @@ class Trellis2UpsampleStage(IO.ComfyNode):
coords_i = coords_i.to(device).clone()
coords_i[:, 0] = 0
slat_i = shape_norm(feats_i.to(device), coords_i)
slat_i.feats = slat_i.feats.to(decoder_dtype)
sample_hr_coords.append(decoder.upsample(slat_i, upsample_times=4))
sample_hr_coords.append(shape_vae.upsample_shape(slat_i, upsample_times=4))
# Resolution search — cache the final iteration's quantized unique tensors
# so we don't recompute .unique() per sample after picking hr_resolution.