diff --git a/comfy/ldm/trellis2/flexgemm.py b/comfy/ldm/trellis2/flexgemm.py index e22f2fe98..e33b50376 100644 --- a/comfy/ldm/trellis2/flexgemm.py +++ b/comfy/ldm/trellis2/flexgemm.py @@ -1,8 +1,8 @@ -# will contain every cuda -> pytorch operation - from typing import Optional, Tuple import torch +import comfy.model_management + UINT32_SENTINEL = 0xFFFFFFFF @@ -26,9 +26,7 @@ class TorchHashMap: self.default_value = torch.tensor(default_value, dtype=torch.long, device=device) self._n = self.sorted_keys.numel() - # Chunk size for lookup_flat. At ~530M flat keys (large mesh extraction), - # the unchunked path allocates ~5 full-size int64 temporaries (4 GB each) + - # bool masks + the int32 output. Chunking caps each transient to ~CHUNK rows. + # Chunk size for lookup_flat, caps each transient to ~CHUNK rows. _LOOKUP_CHUNK = 1 << 23 # 8M rows ≈ 64 MB per int64 temp def lookup_flat(self, flat_keys: torch.Tensor) -> torch.Tensor: @@ -119,57 +117,13 @@ def build_submanifold_neighbor_map( def get_recommended_chunk_mem( device=None, - safety_fraction: float = 0.4, + safety_fraction: float = 0.2, min_gb: float = 0.25, - max_gb: float = 8.0, + max_gb: float = 2.0, ): - - if device is None: - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - else: - device = torch.device(device) - - if device.type == 'cuda': - try: - idx = device.index if device.index is not None else 0 - free_bytes, total_bytes = torch.cuda.mem_get_info(idx) - free_gb = free_bytes / (1024 ** 3) - total_gb = total_bytes / (1024 ** 3) - - recommended = free_gb * safety_fraction - result = max(min_gb, min(recommended, max_gb)) - return result - - except Exception: - try: - idx = device.index if device.index is not None else 0 - total_gb = torch.cuda.get_device_properties(idx).total_memory / (1024 ** 3) - except Exception: - total_gb = 16.0 - - if total_gb < 12: - result = 0.5 - elif total_gb < 16: - result = 0.75 - elif total_gb < 24: - result = 1.0 - elif total_gb < 32: - result = 2.0 - elif total_gb < 48: - result = 4.0 - else: - result = 6.0 - return result - - else: - try: - import psutil - avail_gb = psutil.virtual_memory().available / (1024 ** 3) - recommended = avail_gb * safety_fraction - result = max(min_gb, min(recommended, max_gb)) - return result - except ImportError: - return min_gb + """Pick a chunk-memory budget (in GB) for sparse conv batching.""" + free_gb = comfy.model_management.get_free_memory(device) / (1024 ** 3) + return max(min_gb, min(free_gb * safety_fraction, max_gb)) def sparse_submanifold_conv3d( feats: torch.Tensor, @@ -179,24 +133,16 @@ def sparse_submanifold_conv3d( bias: Optional[torch.Tensor], neighbor_cache: Optional[torch.Tensor], dilation: tuple, - max_chunk_mem_gb: float = 6.0, - accumulate_f32: bool = True, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - if feats.shape[0] == 0: Co = weight.shape[0] return torch.empty((0, Co), device=feats.device, dtype=feats.dtype), None - if len(shape) == 5: - _, _, W, H, D = shape - else: - W, H, D = shape + W, H, D = shape Co, Kw, Kh, Kd, Ci = weight.shape V = Kw * Kh * Kd device = feats.device - sentinel = -1 - max_chunk_mem_gb = get_recommended_chunk_mem(device) if neighbor_cache is None: b_stride = W * H * D @@ -219,91 +165,37 @@ def sparse_submanifold_conv3d( neighbor = neighbor_cache N_pts = feats.shape[0] + sentinel = -1 - if accumulate_f32: - weight_T = weight.view(Co, V * Ci).to(torch.float32).T.contiguous() - else: - weight_T = weight.view(Co, V * Ci).to(feats.dtype).T.contiguous() + weight_T = weight.view(Co, V * Ci).T output = torch.empty(N_pts, Co, device=device, dtype=feats.dtype) - # ------------------------------------------------------------------ - # Chunk size from memory budget - # ------------------------------------------------------------------ - bytes_per_elem = 4 if accumulate_f32 else feats.element_size() - mem_per_row = V * Ci * bytes_per_elem + # Chunk size from memory budget. The dominant peak is `gathered`, of shape (chunk, V, Ci) in feats.dtype. + max_chunk_mem_gb = get_recommended_chunk_mem(device) + mem_per_row = V * Ci * feats.element_size() max_chunk_mem = max_chunk_mem_gb * (1024 ** 3) chunk_size = max(1, int(max_chunk_mem / mem_per_row)) chunk_size = min(chunk_size, N_pts) - # fp32 matmul scratch — sized to the largest chunk, reused each iteration. - chunk_buf = torch.empty(chunk_size, Co, device=device, dtype=torch.float32) if accumulate_f32 else None - - # ------------------------------------------------------------------ - # Chunked forward pass - # Each iteration: - # 1. gather (chunk, V, Ci) – memory bound - # 2. mask zero invalids – in-place, no extra alloc - # 3. reshape (chunk, V*Ci) - # 4. GEMM (chunk, V*Ci) @ (V*Ci, Co) → (chunk, Co) – cuBLAS - # written into the scratch buf (fp32) or output slice (fp16) via out= - # 5. (fp32 path) cast scratch chunk to fp16 and copy into output slice - # ------------------------------------------------------------------ for start in range(0, N_pts, chunk_size): end = min(start + chunk_size, N_pts) actual_chunk = end - start - # (chunk, V) int32 chunk_neighbor = neighbor[start:end] chunk_valid = chunk_neighbor != sentinel + # clamp(-1 -> 0) keeps invalid indices in-range so the gather is safe + chunk_idx = chunk_neighbor.clamp(min=0) - # Clamp sentinel -1 → 0 for safe indexing. No clone of the full map. - chunk_idx = chunk_neighbor.clamp(min=0).long() - - # Gather: (chunk, V, Ci). Memory-bound, single index_select. + # (chunk, V, Ci) gather, then in-place zero of invalid neighbors. gathered = feats[chunk_idx] - - # Zero invalid neighbours in-place. gathered is a fresh tensor from - # advanced indexing, so in-place mutation is safe. gathered.mul_(chunk_valid.unsqueeze(-1)) - # Reshape to (chunk, V*Ci) + # GEMM (chunk, V*Ci) @ (V*Ci, Co) -> (chunk, Co), written to output[start:end]. gathered_flat = gathered.view(actual_chunk, V * Ci) - if accumulate_f32: - gathered_flat = gathered_flat.to(torch.float32) - torch.matmul(gathered_flat, weight_T, out=chunk_buf[:actual_chunk]) - output[start:end] = chunk_buf[:actual_chunk].to(feats.dtype) - else: - torch.matmul(gathered_flat, weight_T, out=output[start:end]) + torch.matmul(gathered_flat, weight_T, out=output[start:end]) if bias is not None: output += bias.unsqueeze(0).to(output.dtype) return output, neighbor - -class Mesh: - def __init__(self, - vertices, - faces, - vertex_attrs=None - ): - self.vertices = vertices.float() - self.faces = faces.int() - self.vertex_attrs = vertex_attrs - - @property - def device(self): - return self.vertices.device - - def to(self, device, non_blocking=False): - return Mesh( - self.vertices.to(device, non_blocking=non_blocking), - self.faces.to(device, non_blocking=non_blocking), - self.vertex_attrs.to(device, non_blocking=non_blocking) if self.vertex_attrs is not None else None, - ) - - def cuda(self, non_blocking=False): - return self.to('cuda', non_blocking=non_blocking) - - def cpu(self): - return self.to('cpu') diff --git a/comfy/ldm/trellis2/vae.py b/comfy/ldm/trellis2/vae.py index 020f68616..ec607fad3 100644 --- a/comfy/ldm/trellis2/vae.py +++ b/comfy/ldm/trellis2/vae.py @@ -1,12 +1,13 @@ import math import torch -import numpy as np import torch.nn as nn import torch.nn.functional as F from fractions import Fraction -from dataclasses import dataclass -from typing import List, Any, Dict, Optional, overload, Union, Tuple -from comfy.ldm.trellis2.flexgemm import TorchHashMap, Mesh, sparse_submanifold_conv3d +from typing import List, Any, Dict, Optional, overload, Union +import comfy.ops +from comfy.ldm.trellis2.flexgemm import TorchHashMap, sparse_submanifold_conv3d + +ops = comfy.ops.disable_weight_init def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor: @@ -73,10 +74,10 @@ def sparse_conv3d_forward(self, x): out = x.replace(out) return out -class LayerNorm32(nn.LayerNorm): +class LayerNorm32(ops.LayerNorm): def forward(self, x: torch.Tensor) -> torch.Tensor: - w = self.weight.to(x.dtype) if self.weight is not None else None - b = self.bias.to(x.dtype) if self.bias is not None else None + w = self.weight.to(x) if self.weight is not None else None + b = self.bias.to(x) if self.bias is not None else None return F.layer_norm(x, self.normalized_shape, w, b, self.eps) class SparseConvNeXtBlock3d(nn.Module): @@ -93,12 +94,13 @@ class SparseConvNeXtBlock3d(nn.Module): self.norm = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) self.conv = SparseConv3d(channels, channels, 3) self.mlp = nn.Sequential( - nn.Linear(channels, int(channels * mlp_ratio)), + ops.Linear(channels, int(channels * mlp_ratio)), nn.SiLU(inplace=True), - nn.Linear(int(channels * mlp_ratio), channels), + ops.Linear(int(channels * mlp_ratio), channels), ) def _forward(self, x): + x = x.to(dtype=self.conv.weight.dtype, device=self.conv.weight.device) h = self.conv(x) h = h.replace(self.norm(h.feats)) h = h.replace(self.mlp(h.feats)) @@ -141,7 +143,7 @@ class SparseSpatial2Channel(nn.Module): out = SparseTensor(new_feats.reshape(new_coords.shape[0], -1), new_coords, None if x._shape is None else torch.Size([x._shape[0], x._shape[1] * self.factor ** DIM])) out._scale = tuple([s * self.factor for s in x._scale]) - out._spatial_cache = x._spatial_cache + out._spatial_cache = dict(x._spatial_cache) if cache is None: x.register_spatial_cache(f'spatial2channel_{self.factor}', (new_coords, idx, subidx)) @@ -180,7 +182,7 @@ class SparseChannel2Spatial(nn.Module): out = SparseTensor(new_feats, new_coords, None if x._shape is None else torch.Size([x._shape[0], x._shape[1] // self.factor ** DIM])) out._scale = tuple([s / self.factor for s in x._scale]) if cache is not None: # only keep cache when subdiv following it - out._spatial_cache = x._spatial_cache + out._spatial_cache = dict(x._spatial_cache) return out class SparseResBlockC2S3d(nn.Module): @@ -226,10 +228,6 @@ class SparseResBlockC2S3d(nn.Module): else: return h -@dataclass -class config: - CONV = "flexgemm" - FLEX_GEMM_HASHMAP_RATIO = 2.0 class VarLenTensor: @@ -238,18 +236,6 @@ class VarLenTensor: self.layout = layout if layout is not None else [slice(0, feats.shape[0])] self._cache = {} - @staticmethod - def layout_from_seqlen(seqlen: list) -> List[slice]: - """ - Create a layout from a tensor of sequence lengths. - """ - layout = [] - start = 0 - for l in seqlen: - layout.append(slice(start, start + l)) - start += l - return layout - @staticmethod def from_tensor_list(tensor_list: List[torch.Tensor]) -> 'VarLenTensor': """ @@ -376,29 +362,22 @@ class VarLenTensor: feats=feats, layout=self.layout, ) - new_tensor._cache = self._cache + # Shallow-copy so derived tensors don't share-by-reference the cache + # dict — see SparseTensor.replace for rationale. + new_tensor._cache = dict(self._cache) return new_tensor - def to_dense(self, max_length=None) -> torch.Tensor: - N = len(self) - L = max_length or self.seqlen.max().item() - spatial = self.feats.shape[1:] - idx = torch.arange(L, device=self.device).unsqueeze(0).expand(N, L) - mask = (idx < self.seqlen.unsqueeze(1)) - mapping = mask.reshape(-1).cumsum(dim=0) - 1 - dense = self.feats[mapping] - dense = dense.reshape(N, L, *spatial) - return dense, mask - def __neg__(self) -> 'VarLenTensor': return self.replace(-self.feats) def __elemwise__(self, other: Union[torch.Tensor, 'VarLenTensor'], op: callable) -> 'VarLenTensor': if isinstance(other, torch.Tensor): + # Try per-batch [B, C] -> per-token [T, C] broadcast. RuntimeError + # fires for incompatible shapes; fall through and let op() handle. try: other = torch.broadcast_to(other, self.shape) other = other[self.batch_boardcast_map] - except: + except RuntimeError: pass if isinstance(other, VarLenTensor): other = other.feats @@ -459,40 +438,6 @@ class VarLenTensor: new_tensor = VarLenTensor(feats=new_feats, layout=new_layout) return new_tensor - def reduce(self, op: str, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor: - if isinstance(dim, int): - dim = (dim,) - - if op =='mean': - red = self.feats.mean(dim=dim, keepdim=keepdim) - elif op =='sum': - red = self.feats.sum(dim=dim, keepdim=keepdim) - elif op == 'prod': - red = self.feats.prod(dim=dim, keepdim=keepdim) - else: - raise ValueError(f"Unsupported reduce operation: {op}") - - if dim is None or 0 in dim: - return red - - red = torch.segment_reduce(red, reduce=op, lengths=self.seqlen) - return red - - def mean(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor: - return self.reduce(op='mean', dim=dim, keepdim=keepdim) - - def sum(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor: - return self.reduce(op='sum', dim=dim, keepdim=keepdim) - - def prod(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor: - return self.reduce(op='prod', dim=dim, keepdim=keepdim) - - def std(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor: - mean = self.mean(dim=dim, keepdim=True) - mean2 = self.replace(self.feats ** 2).mean(dim=dim, keepdim=True) - std = (mean2 - mean ** 2).sqrt() - return std - def __repr__(self) -> str: return f"VarLenTensor(shape={self.shape}, dtype={self.dtype}, device={self.device})" @@ -507,8 +452,6 @@ def varlen_unbind(input: VarLenTensor, dim: int) -> Union[List[VarLenTensor]]: class SparseTensor(VarLenTensor): - SparseTensorData = None - @overload def __init__(self, feats: torch.Tensor, coords: torch.Tensor, shape: Optional[torch.Size] = None, **kwargs): ... @@ -516,14 +459,6 @@ class SparseTensor(VarLenTensor): def __init__(self, data, shape: Optional[torch.Size] = None, **kwargs): ... def __init__(self, *args, **kwargs): - # Lazy import of sparse tensor backend - if self.SparseTensorData is None: - import importlib - if config.CONV == 'torchsparse': - self.SparseTensorData = importlib.import_module('torchsparse').SparseTensor - elif config.CONV == 'spconv': - self.SparseTensorData = importlib.import_module('spconv.pytorch').SparseConvTensor - method_id = 0 if len(args) != 0: method_id = 0 if isinstance(args[0], torch.Tensor) else 1 @@ -542,17 +477,10 @@ class SparseTensor(VarLenTensor): shape = kwargs['shape'] del kwargs['shape'] - if config.CONV == 'torchsparse': - self.data = self.SparseTensorData(feats, coords, **kwargs) - elif config.CONV == 'spconv': - spatial_shape = list(coords.max(0)[0] + 1) - self.data = self.SparseTensorData(feats.reshape(feats.shape[0], -1), coords, spatial_shape[1:], spatial_shape[0], **kwargs) - self.data._features = feats - else: - self.data = { - 'feats': feats, - 'coords': coords, - } + self.data = { + 'feats': feats, + 'coords': coords, + } elif method_id == 1: data, shape = args + (None,) * (2 - len(args)) if 'data' in kwargs: @@ -581,17 +509,6 @@ class SparseTensor(VarLenTensor): coords = torch.cat(coords, dim=0) return SparseTensor(feats, coords) - def to_tensor_list(self) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: - """ - Convert a SparseTensor to list of tensors. - """ - feats_list = [] - coords_list = [] - for s in self.layout: - feats_list.append(self.feats[s]) - coords_list.append(self.coords[s]) - return feats_list, coords_list - def __len__(self) -> int: return len(self.layout) @@ -634,39 +551,19 @@ class SparseTensor(VarLenTensor): @property def feats(self) -> torch.Tensor: - if config.CONV == 'torchsparse': - return self.data.F - elif config.CONV == 'spconv': - return self.data.features - else: - return self.data['feats'] + return self.data['feats'] @feats.setter def feats(self, value: torch.Tensor): - if config.CONV == 'torchsparse': - self.data.F = value - elif config.CONV == 'spconv': - self.data.features = value - else: - self.data['feats'] = value + self.data['feats'] = value @property def coords(self) -> torch.Tensor: - if config.CONV == 'torchsparse': - return self.data.C - elif config.CONV == 'spconv': - return self.data.indices - else: - return self.data['coords'] + return self.data['coords'] @coords.setter def coords(self, value: torch.Tensor): - if config.CONV == 'torchsparse': - self.data.C = value - elif config.CONV == 'spconv': - self.data.indices = value - else: - self.data['coords'] = value + self.data['coords'] = value @property def dtype(self): @@ -773,71 +670,20 @@ class SparseTensor(VarLenTensor): return sparse_unbind(self, dim) def replace(self, feats: torch.Tensor, coords: Optional[torch.Tensor] = None) -> 'SparseTensor': - if config.CONV == 'torchsparse': - new_data = self.SparseTensorData( - feats=feats, - coords=self.data.coords if coords is None else coords, - stride=self.data.stride, - spatial_range=self.data.spatial_range, - ) - new_data._caches = self.data._caches - elif config.CONV == 'spconv': - new_data = self.SparseTensorData( - self.data.features.reshape(self.data.features.shape[0], -1), - self.data.indices, - self.data.spatial_shape, - self.data.batch_size, - self.data.grid, - self.data.voxel_num, - self.data.indice_dict - ) - new_data._features = feats - new_data.benchmark = self.data.benchmark - new_data.benchmark_record = self.data.benchmark_record - new_data.thrust_allocator = self.data.thrust_allocator - new_data._timer = self.data._timer - new_data.force_algo = self.data.force_algo - new_data.int8_scale = self.data.int8_scale - if coords is not None: - new_data.indices = coords - else: - new_data = { - 'feats': feats, - 'coords': self.data['coords'] if coords is None else coords, - } - new_tensor = SparseTensor( + new_data = { + 'feats': feats, + 'coords': self.data['coords'] if coords is None else coords, + } + return SparseTensor( new_data, shape=torch.Size([self._shape[0]] + list(feats.shape[1:])) if self._shape is not None else None, scale=self._scale, - spatial_cache=self._spatial_cache + # Shallow-copy the cache: each derived tensor gets its own dict, so + # adding/overwriting entries on one doesn't leak to siblings. + # Cached tensors themselves are still shared by reference (safe + # because they're read-only after populate). + spatial_cache=dict(self._spatial_cache), ) - return new_tensor - - def to_dense(self) -> torch.Tensor: - if config.CONV == 'torchsparse': - return self.data.dense() - elif config.CONV == 'spconv': - return self.data.dense() - else: - spatial_shape = self.spatial_shape - ret = torch.zeros(*self.shape, *spatial_shape, dtype=self.dtype, device=self.device) - idx = [self.coords[:, 0], slice(None)] + self.coords[:, 1:].unbind(1) - ret[tuple(idx)] = self.feats - return ret - - @staticmethod - def full(aabb, dim, value, dtype=torch.float32, device=None) -> 'SparseTensor': - N, C = dim - x = torch.arange(aabb[0], aabb[3] + 1) - y = torch.arange(aabb[1], aabb[4] + 1) - z = torch.arange(aabb[2], aabb[5] + 1) - coords = torch.stack(torch.meshgrid(x, y, z, indexing='ij'), dim=-1).reshape(-1, 3) - coords = torch.cat([ - torch.arange(N).view(-1, 1).repeat(1, coords.shape[0]).view(-1, 1), - coords.repeat(N, 1), - ], dim=1).to(dtype=torch.int32, device=device) - feats = torch.full((coords.shape[0], C), value, dtype=dtype, device=device) - return SparseTensor(feats=feats, coords=coords) def __merge_sparse_cache(self, other: 'SparseTensor') -> dict: new_cache = {} @@ -853,10 +699,12 @@ class SparseTensor(VarLenTensor): def __elemwise__(self, other: Union[torch.Tensor, VarLenTensor], op: callable) -> 'SparseTensor': if isinstance(other, torch.Tensor): + # Try per-batch [B, C] -> per-voxel [N, C] broadcast. RuntimeError + # fires for incompatible shapes; fall through and let op() handle. try: other = torch.broadcast_to(other, self.shape) other = other[self.batch_boardcast_map] - except: + except RuntimeError: pass if isinstance(other, VarLenTensor): other = other.feats @@ -901,12 +749,6 @@ class SparseTensor(VarLenTensor): new_tensor.register_spatial_cache('layout', new_layout) return new_tensor - def clear_spatial_cache(self) -> None: - """ - Clear all spatial caches. - """ - self._spatial_cache = {} - def register_spatial_cache(self, key, value) -> None: """ Register a spatial cache. @@ -961,7 +803,7 @@ def sparse_unbind(input: SparseTensor, dim: int) -> List[SparseTensor]: # allow operations.Linear inheritance class SparseLinear: - def __new__(cls, in_features, out_features, bias=True, device=None, dtype=None, operations=nn, *args, **kwargs): + def __new__(cls, in_features, out_features, bias=True, device=None, dtype=None, operations=ops, *args, **kwargs): class _SparseLinear(operations.Linear): def __init__(self, in_features, out_features, bias=True, device=None, dtype=None): super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype) @@ -971,24 +813,6 @@ class SparseLinear: return _SparseLinear(in_features, out_features, bias=bias, device=device, dtype=dtype, *args, **kwargs) -MIX_PRECISION_MODULES = ( - nn.Conv1d, - nn.Conv2d, - nn.Conv3d, - nn.ConvTranspose1d, - nn.ConvTranspose2d, - nn.ConvTranspose3d, - nn.Linear, - SparseConv3d, - SparseLinear, -) - - -def convert_module_to_f16(l): - if isinstance(l, MIX_PRECISION_MODULES): - for p in l.parameters(): - p.data = p.data.half() - class SparseUnetVaeDecoder(nn.Module): def __init__( self, @@ -999,17 +823,13 @@ class SparseUnetVaeDecoder(nn.Module): block_type: List[str], up_block_type: List[str], block_args: List[Dict[str, Any]], - use_fp16: bool = False, pred_subdiv: bool = True, ): super().__init__() self.out_channels = out_channels self.model_channels = model_channels self.num_blocks = num_blocks - self.use_fp16 = use_fp16 self.pred_subdiv = pred_subdiv - self.dtype = torch.float16 if use_fp16 else torch.float32 - self.low_vram = False self.output_layer = SparseLinear(model_channels[-1], out_channels) self.from_latent = SparseLinear(latent_channels, model_channels[0]) @@ -1033,17 +853,9 @@ class SparseUnetVaeDecoder(nn.Module): **block_args[i], ) ) - @property - def device(self) -> torch.device: - return next(self.parameters()).device def forward(self, x: SparseTensor, guide_subs: Optional[List[SparseTensor]] = None, return_subs: bool = False) -> SparseTensor: - - dtype = next(self.from_latent.parameters()).dtype - device = next(self.from_latent.parameters()).device - x.feats = x.feats.to(dtype).to(device) h = self.from_latent(x) - h = h.type(self.dtype) subs = [] for i, res in enumerate(self.blocks): for j, block in enumerate(res): @@ -1055,7 +867,6 @@ class SparseUnetVaeDecoder(nn.Module): h = block(h, subdiv=guide_subs[i] if guide_subs is not None else None) else: h = block(h) - h = h.type(x.feats.dtype) h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) h = self.output_layer(h) if return_subs: @@ -1064,9 +875,7 @@ class SparseUnetVaeDecoder(nn.Module): return h def upsample(self, x: SparseTensor, upsample_times: int) -> torch.Tensor: - h = self.from_latent(x) - h = h.type(self.dtype) for i, res in enumerate(self.blocks): if i == upsample_times: return h.coords @@ -1087,13 +896,9 @@ class FlexiDualGridVaeDecoder(SparseUnetVaeDecoder): up_block_type: List[str], block_args: List[Dict[str, Any]], voxel_margin: float = 0.5, - use_fp16: bool = False, ): self.resolution = resolution self.voxel_margin = voxel_margin - # cache for a TorchHashMap instance - self._torch_hashmap_cache = None - super().__init__( 7, model_channels, @@ -1102,22 +907,11 @@ class FlexiDualGridVaeDecoder(SparseUnetVaeDecoder): block_type, up_block_type, block_args, - use_fp16, ) def set_resolution(self, resolution: int) -> None: self.resolution = resolution - def _build_or_get_hashmap(self, coords: torch.Tensor, grid_size: torch.Tensor): - device = coords.device - N = coords.shape[0] - _, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item()) - flat_keys = coords[:, 0].long() * (H * D) - flat_keys.add_(coords[:, 1].long() * D) - flat_keys.add_(coords[:, 2].long()) - values = torch.arange(N, dtype=torch.int32, device=device) - return TorchHashMap(flat_keys, values, 0xffffffff) - def forward(self, x: SparseTensor, gt_intersected: SparseTensor = None, **kwargs): decoded = super().forward(x, **kwargs) out_list = list(decoded) if isinstance(decoded, tuple) else [decoded] @@ -1125,13 +919,11 @@ class FlexiDualGridVaeDecoder(SparseUnetVaeDecoder): vertices = h.replace((1 + 2 * self.voxel_margin) * F.sigmoid(h.feats[..., 0:3]) - self.voxel_margin) intersected = h.replace(h.feats[..., 3:6] > 0) quad_lerp = h.replace(F.softplus(h.feats[..., 6:7])) - mesh = [Mesh(*flexible_dual_grid_to_mesh( + mesh = [flexible_dual_grid_to_mesh( v.coords[:, 1:], v.feats, i.feats, q.feats, aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]], grid_size=self.resolution, - train=False, - hashmap_builder=self._build_or_get_hashmap, - )) for v, i, q in zip(vertices, intersected, quad_lerp)] + ) for v, i, q in zip(vertices, intersected, quad_lerp)] out_list[0] = mesh return out_list[0] if len(out_list) == 1 else tuple(out_list) @@ -1140,11 +932,9 @@ def flexible_dual_grid_to_mesh( dual_vertices: torch.Tensor, intersected_flag: torch.Tensor, split_weight: Union[torch.Tensor, None], - aabb: Union[list, tuple, np.ndarray, torch.Tensor], - voxel_size: Union[float, list, tuple, np.ndarray, torch.Tensor] = None, - grid_size: Union[int, list, tuple, np.ndarray, torch.Tensor] = None, - train: bool = False, - hashmap_builder=None, # optional callable for building/caching a TorchHashMap + aabb: Union[list, tuple, torch.Tensor], + voxel_size: Union[float, list, tuple, torch.Tensor] = None, + grid_size: Union[int, list, tuple, torch.Tensor] = None, ): device = coords.device @@ -1159,46 +949,28 @@ def flexible_dual_grid_to_mesh( flexible_dual_grid_to_mesh.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=device, requires_grad=False) if not hasattr(flexible_dual_grid_to_mesh, "quad_split_2") or flexible_dual_grid_to_mesh.quad_split_2.device != device: flexible_dual_grid_to_mesh.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=device, requires_grad=False) - if not hasattr(flexible_dual_grid_to_mesh, "quad_split_train") or flexible_dual_grid_to_mesh.quad_split_train.device != device: - flexible_dual_grid_to_mesh.quad_split_train = torch.tensor([0, 1, 4, 1, 2, 4, 2, 3, 4, 3, 0, 4], dtype=torch.long, device=device, requires_grad=False) - # AABB - if isinstance(aabb, (list, tuple)): - aabb = np.array(aabb) - if isinstance(aabb, np.ndarray): - aabb = torch.tensor(aabb, dtype=torch.float32, device=device) + aabb = torch.tensor(aabb, dtype=torch.float32, device=device) - # Voxel size if voxel_size is not None: if isinstance(voxel_size, float): - voxel_size = [voxel_size, voxel_size, voxel_size] - if isinstance(voxel_size, (list, tuple)): - voxel_size = np.array(voxel_size) - if isinstance(voxel_size, np.ndarray): - voxel_size = torch.tensor(voxel_size, dtype=torch.float32, device=coords.device) + voxel_size = [voxel_size] * 3 + voxel_size = torch.tensor(voxel_size, dtype=torch.float32, device=device) grid_size = ((aabb[1] - aabb[0]) / voxel_size).round().int() else: if isinstance(grid_size, int): - grid_size = [grid_size, grid_size, grid_size] - if isinstance(grid_size, (list, tuple)): - grid_size = np.array(grid_size) - if isinstance(grid_size, np.ndarray): - grid_size = torch.tensor(grid_size, dtype=torch.int32, device=coords.device) + grid_size = [grid_size] * 3 + grid_size = torch.tensor(grid_size, dtype=torch.int32, device=device) voxel_size = (aabb[1] - aabb[0]) / grid_size # Extract mesh N = dual_vertices.shape[0] - - if hashmap_builder is None: - device = coords.device - _, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item()) - flat_keys = coords[:, 0].long() * (H * D) - flat_keys.add_(coords[:, 1].long() * D) - flat_keys.add_(coords[:, 2].long()) - values = torch.arange(N, dtype=torch.long, device=device) - torch_hashmap = TorchHashMap(flat_keys, values, 0xffffffff) - else: - torch_hashmap = hashmap_builder(coords, grid_size) + _, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item()) + flat_keys = coords[:, 0].long() * (H * D) + flat_keys.add_(coords[:, 1].long() * D) + flat_keys.add_(coords[:, 2].long()) + values = torch.arange(N, dtype=torch.int32, device=device) + torch_hashmap = TorchHashMap(flat_keys, values, 0xffffffff) # Find connected voxels — direct gather instead of materializing the full [N, 3, 4, 3] n_idx, axis_idx = intersected_flag.nonzero(as_tuple=True) # (M,), (M,) @@ -1253,55 +1025,34 @@ class ChannelLayerNorm32(LayerNorm32): return x class UpsampleBlock3d(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - mode = "conv", - ): - assert mode in ["conv", "nearest"], f"Invalid mode {mode}" - + def __init__(self, in_channels: int, out_channels: int): super().__init__() self.in_channels = in_channels self.out_channels = out_channels - - if mode == "conv": - self.conv = nn.Conv3d(in_channels, out_channels*8, 3, padding=1) - elif mode == "nearest": - assert in_channels == out_channels, "Nearest mode requires in_channels to be equal to out_channels" + self.conv = ops.Conv3d(in_channels, out_channels * 8, 3, padding=1) def forward(self, x: torch.Tensor) -> torch.Tensor: - if hasattr(self, "conv"): - x = self.conv(x) - return pixel_shuffle_3d(x, 2) - else: - return F.interpolate(x, scale_factor=2, mode="nearest") - -def norm_layer(norm_type: str, *args, **kwargs) -> nn.Module: - return ChannelLayerNorm32(*args, **kwargs) + return pixel_shuffle_3d(self.conv(x), 2) class ResBlock3d(nn.Module): def __init__( self, channels: int, out_channels: Optional[int] = None, - norm_type = "layer", ): super().__init__() self.channels = channels self.out_channels = out_channels or channels - self.norm1 = norm_layer(norm_type, channels) - self.norm2 = norm_layer(norm_type, self.out_channels) - self.conv1 = nn.Conv3d(channels, self.out_channels, 3, padding=1) - self.conv2 = nn.Conv3d(self.out_channels, self.out_channels, 3, padding=1) - self.skip_connection = nn.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity() + self.norm1 = ChannelLayerNorm32(channels) + self.norm2 = ChannelLayerNorm32(self.out_channels) + self.conv1 = ops.Conv3d(channels, self.out_channels, 3, padding=1) + self.conv2 = ops.Conv3d(self.out_channels, self.out_channels, 3, padding=1) + self.skip_connection = ops.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity() def forward(self, x: torch.Tensor) -> torch.Tensor: h = self.norm1(x) h = F.silu(h) - dtype = next(self.conv1.parameters()).dtype - h = h.to(dtype) h = self.conv1(h) h = self.norm2(h) h = F.silu(h) @@ -1318,8 +1069,6 @@ class SparseStructureDecoder(nn.Module): num_res_blocks: int, channels: List[int], num_res_blocks_middle: int = 2, - norm_type = "layer", - use_fp16: bool = True, ): super().__init__() self.out_channels = out_channels @@ -1327,11 +1076,8 @@ class SparseStructureDecoder(nn.Module): self.num_res_blocks = num_res_blocks self.channels = channels self.num_res_blocks_middle = num_res_blocks_middle - self.norm_type = norm_type - self.use_fp16 = use_fp16 - self.dtype = torch.float16 if use_fp16 else torch.float32 - self.input_layer = nn.Conv3d(latent_channels, channels[0], 3, padding=1) + self.input_layer = ops.Conv3d(latent_channels, channels[0], 3, padding=1) self.middle_block = nn.Sequential(*[ ResBlock3d(channels[0], channels[0]) @@ -1350,92 +1096,82 @@ class SparseStructureDecoder(nn.Module): ) self.out_layer = nn.Sequential( - norm_layer(norm_type, channels[-1]), + ChannelLayerNorm32(channels[-1]), nn.SiLU(), - nn.Conv3d(channels[-1], out_channels, 3, padding=1) + ops.Conv3d(channels[-1], out_channels, 3, padding=1) ) - if use_fp16: - self.convert_to_fp16() - - def device(self) -> torch.device: - return next(self.parameters()).device - - def convert_to_fp16(self) -> None: - self.use_fp16 = True - self.dtype = torch.float16 - self.blocks.apply(convert_module_to_f16) - self.middle_block.apply(convert_module_to_f16) - def forward(self, x: torch.Tensor) -> torch.Tensor: - dtype = next(self.input_layer.parameters()).dtype - x = x.to(dtype) h = self.input_layer(x) - - h = h.type(self.dtype) h = self.middle_block(h) for block in self.blocks: h = block(h) - - h = h.type(x.dtype) h = self.out_layer(h) return h -class Vae(nn.Module): - def __init__(self, init_txt_model, init_txt_model_only, operations=None): + +class ShapeVae(nn.Module): + """Decoder bundle from the Trellis2 shape checkpoint: structure decoder + (32^3 SS latent → 64^3 dense occupancy) and shape decoder (sparse latent → + mesh + per-stage subdivisions).""" + + def __init__(self): super().__init__() - operations = operations or torch.nn - if init_txt_model or init_txt_model_only: - self.txt_dec = SparseUnetVaeDecoder( - out_channels=6, - model_channels=[1024, 512, 256, 128, 64], - latent_channels=32, - num_blocks=[4, 16, 8, 4, 0], - block_type=["SparseConvNeXtBlock3d"] * 5, - up_block_type=["SparseResBlockC2S3d"] * 4, - block_args=[{}, {}, {}, {}, {}], - pred_subdiv=False - ) - - if not init_txt_model_only: - self.shape_dec = FlexiDualGridVaeDecoder( - resolution=256, - model_channels=[1024, 512, 256, 128, 64], - latent_channels=32, - num_blocks=[4, 16, 8, 4, 0], - block_type=["SparseConvNeXtBlock3d"] * 5, - up_block_type=["SparseResBlockC2S3d"] * 4, - block_args=[{}, {}, {}, {}, {}], - ) - - self.struct_dec = SparseStructureDecoder( - out_channels=1, - latent_channels=8, - num_res_blocks=2, - num_res_blocks_middle=2, - channels=[512, 128, 32], - ) + self.shape_dec = FlexiDualGridVaeDecoder( + resolution=256, + model_channels=[1024, 512, 256, 128, 64], + latent_channels=32, + num_blocks=[4, 16, 8, 4, 0], + block_type=["SparseConvNeXtBlock3d"] * 5, + up_block_type=["SparseResBlockC2S3d"] * 4, + block_args=[{}, {}, {}, {}, {}], + ) + self.struct_dec = SparseStructureDecoder( + out_channels=1, + latent_channels=8, + num_res_blocks=2, + num_res_blocks_middle=2, + channels=[512, 128, 32], + ) self.register_buffer("resolution", torch.tensor(1024.0), persistent=False) - @torch.no_grad() - def decode_shape_slat(self, slat, resolution: int): + def decode_structure(self, x: torch.Tensor) -> torch.Tensor: + weight = self.struct_dec.input_layer.weight + x = x.to(dtype=weight.dtype, device=weight.device) + return self.struct_dec(x) + + def decode_shape_slat(self, slat: 'SparseTensor', resolution: int): + weight = self.shape_dec.from_latent.weight + slat = slat.to(dtype=weight.dtype, device=weight.device) self.shape_dec.set_resolution(resolution) return self.shape_dec(slat, return_subs=True) - @torch.no_grad() - def decode_tex_slat(self, slat, subs): - if self.txt_dec is None: - raise ValueError("Checkpoint doesn't include texture model") + def upsample_shape(self, slat: 'SparseTensor', upsample_times: int) -> torch.Tensor: + weight = self.shape_dec.from_latent.weight + slat = slat.to(dtype=weight.dtype, device=weight.device) + return self.shape_dec.upsample(slat, upsample_times) + + +class TextureVae(nn.Module): + """Decoder bundle from the Trellis2 texture checkpoint: sparse 3D + per-voxel color decoder, guided by subdivisions from a prior shape decode.""" + + def __init__(self): + super().__init__() + self.txt_dec = SparseUnetVaeDecoder( + out_channels=6, + model_channels=[1024, 512, 256, 128, 64], + latent_channels=32, + num_blocks=[4, 16, 8, 4, 0], + block_type=["SparseConvNeXtBlock3d"] * 5, + up_block_type=["SparseResBlockC2S3d"] * 4, + block_args=[{}, {}, {}, {}, {}], + pred_subdiv=False, + ) + self.register_buffer("resolution", torch.tensor(1024.0), persistent=False) + + def decode_tex_slat(self, slat: 'SparseTensor', subs): + weight = self.txt_dec.from_latent.weight + slat = slat.to(dtype=weight.dtype, device=weight.device) return self.txt_dec(slat, guide_subs=subs) * 0.5 + 0.5 - # shouldn't be called (placeholder) - @torch.no_grad() - def decode( - self, - shape_slat: SparseTensor, - tex_slat: SparseTensor, - resolution: int, - ): - meshes, subs = self.decode_shape_slat(shape_slat, resolution) - tex_voxels = self.decode_tex_slat(tex_slat, subs) - return tex_voxels diff --git a/comfy/sd.py b/comfy/sd.py index 8bbf77ce9..309fb7763 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -529,18 +529,18 @@ class VAE: self.first_stage_model = StageC_coder() self.downscale_ratio = 32 self.latent_channels = 16 - elif "shape_dec.blocks.1.16.to_subdiv.weight" in sd or "txt_dec.blocks.3.4.conv2.weight" in sd: # trellis2 or trellis2 texture only - init_txt_model = False - init_txt_model_only = False - if "shape_dec.blocks.1.16.to_subdiv.weight" not in sd: - init_txt_model_only = True - if "txt_dec.blocks.1.16.norm1.weight" in sd: - init_txt_model = True + elif "shape_dec.blocks.1.16.to_subdiv.weight" in sd: # trellis2 shape vae (struct_dec + shape_dec) self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] # TODO self.memory_used_decode = lambda shape, dtype: (2500 * shape[2] * shape[3]) * model_management.dtype_size(dtype) self.memory_used_encode = lambda shape, dtype: (2500 * shape[2] * shape[3]) * model_management.dtype_size(dtype) - self.first_stage_model = comfy.ldm.trellis2.vae.Vae(init_txt_model, init_txt_model_only= init_txt_model_only) + self.first_stage_model = comfy.ldm.trellis2.vae.ShapeVae() + elif "txt_dec.blocks.3.4.conv2.weight" in sd: # trellis2 texture vae + self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] + # TODO + self.memory_used_decode = lambda shape, dtype: (2500 * shape[2] * shape[3]) * model_management.dtype_size(dtype) + self.memory_used_encode = lambda shape, dtype: (2500 * shape[2] * shape[3]) * model_management.dtype_size(dtype) + self.first_stage_model = comfy.ldm.trellis2.vae.TextureVae() elif "decoder.conv_in.weight" in sd: if sd['decoder.conv_in.weight'].shape[1] == 64: ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True} diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 7bc22074e..266f67043 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -205,8 +205,8 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): coords_list = [stage_tensor.coords for stage_tensor in stage_tensors] subs.append(SparseTensor.from_tensor_list(feats_list, coords_list)) - face_list = [m.faces for m in mesh] - vert_list = [m.vertices for m in mesh] + vert_list = [v.float() for v, f in mesh] + face_list = [f.int() for v, f in mesh] if all(v.shape == vert_list[0].shape for v in vert_list) and all(f.shape == face_list[0].shape for f in face_list): mesh = Types.MESH(vertices=torch.stack(vert_list), faces=torch.stack(face_list)) else: @@ -286,12 +286,12 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode): sample_tensor = samples["samples"] sample_tensor = sample_tensor[:, :8] batch_number = prepare_trellis_vae_for_decode(vae, sample_tensor.shape) - decoder = vae.first_stage_model.struct_dec + shape_vae = vae.first_stage_model load_device = comfy.model_management.get_torch_device() decoded_batches = [] for start in range(0, sample_tensor.shape[0], batch_number): sample_chunk = sample_tensor[start:start + batch_number].to(load_device) - decoded_batches.append(decoder(sample_chunk) > 0) + decoded_batches.append(shape_vae.decode_structure(sample_chunk) > 0) decoded = torch.cat(decoded_batches, dim=0) current_res = decoded.shape[2] @@ -349,10 +349,9 @@ class Trellis2UpsampleStage(IO.ComfyNode): prepare_trellis_vae_for_decode(vae, shape_latent["samples"].shape) coord_counts = shape_latent.get("coord_counts") - decoder = vae.first_stage_model.shape_dec + shape_vae = vae.first_stage_model lr_resolution = 512 target_resolution = int(target_resolution) - decoder_dtype = next(decoder.parameters()).dtype # Decode each sample's HR coords, then search for the largest hr_resolution # that fits under max_tokens across all samples. @@ -361,8 +360,7 @@ class Trellis2UpsampleStage(IO.ComfyNode): shape_latent["samples"], shape_latent["coords"], coord_counts, ) slat = shape_norm(feats.to(device), coords_512.to(device)) - slat.feats = slat.feats.to(decoder_dtype) - sample_hr_coords = [decoder.upsample(slat, upsample_times=4)] + sample_hr_coords = [shape_vae.upsample_shape(slat, upsample_times=4)] else: items = split_batched_sparse_latent( shape_latent["samples"], shape_latent["coords"], coord_counts, @@ -372,8 +370,7 @@ class Trellis2UpsampleStage(IO.ComfyNode): coords_i = coords_i.to(device).clone() coords_i[:, 0] = 0 slat_i = shape_norm(feats_i.to(device), coords_i) - slat_i.feats = slat_i.feats.to(decoder_dtype) - sample_hr_coords.append(decoder.upsample(slat_i, upsample_times=4)) + sample_hr_coords.append(shape_vae.upsample_shape(slat_i, upsample_times=4)) # Resolution search — cache the final iteration's quantized unique tensors # so we don't recompute .unique() per sample after picking hr_resolution.