mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 21:20:49 +08:00
Cleanup VAE
This commit is contained in:
parent
4585a731c1
commit
3edbf7c4a7
@ -1,8 +1,8 @@
|
||||
# will contain every cuda -> pytorch operation
|
||||
|
||||
from typing import Optional, Tuple
|
||||
import torch
|
||||
|
||||
import comfy.model_management
|
||||
|
||||
UINT32_SENTINEL = 0xFFFFFFFF
|
||||
|
||||
|
||||
@ -26,9 +26,7 @@ class TorchHashMap:
|
||||
self.default_value = torch.tensor(default_value, dtype=torch.long, device=device)
|
||||
self._n = self.sorted_keys.numel()
|
||||
|
||||
# Chunk size for lookup_flat. At ~530M flat keys (large mesh extraction),
|
||||
# the unchunked path allocates ~5 full-size int64 temporaries (4 GB each) +
|
||||
# bool masks + the int32 output. Chunking caps each transient to ~CHUNK rows.
|
||||
# Chunk size for lookup_flat, caps each transient to ~CHUNK rows.
|
||||
_LOOKUP_CHUNK = 1 << 23 # 8M rows ≈ 64 MB per int64 temp
|
||||
|
||||
def lookup_flat(self, flat_keys: torch.Tensor) -> torch.Tensor:
|
||||
@ -119,57 +117,13 @@ def build_submanifold_neighbor_map(
|
||||
|
||||
def get_recommended_chunk_mem(
|
||||
device=None,
|
||||
safety_fraction: float = 0.4,
|
||||
safety_fraction: float = 0.2,
|
||||
min_gb: float = 0.25,
|
||||
max_gb: float = 8.0,
|
||||
max_gb: float = 2.0,
|
||||
):
|
||||
|
||||
if device is None:
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
else:
|
||||
device = torch.device(device)
|
||||
|
||||
if device.type == 'cuda':
|
||||
try:
|
||||
idx = device.index if device.index is not None else 0
|
||||
free_bytes, total_bytes = torch.cuda.mem_get_info(idx)
|
||||
free_gb = free_bytes / (1024 ** 3)
|
||||
total_gb = total_bytes / (1024 ** 3)
|
||||
|
||||
recommended = free_gb * safety_fraction
|
||||
result = max(min_gb, min(recommended, max_gb))
|
||||
return result
|
||||
|
||||
except Exception:
|
||||
try:
|
||||
idx = device.index if device.index is not None else 0
|
||||
total_gb = torch.cuda.get_device_properties(idx).total_memory / (1024 ** 3)
|
||||
except Exception:
|
||||
total_gb = 16.0
|
||||
|
||||
if total_gb < 12:
|
||||
result = 0.5
|
||||
elif total_gb < 16:
|
||||
result = 0.75
|
||||
elif total_gb < 24:
|
||||
result = 1.0
|
||||
elif total_gb < 32:
|
||||
result = 2.0
|
||||
elif total_gb < 48:
|
||||
result = 4.0
|
||||
else:
|
||||
result = 6.0
|
||||
return result
|
||||
|
||||
else:
|
||||
try:
|
||||
import psutil
|
||||
avail_gb = psutil.virtual_memory().available / (1024 ** 3)
|
||||
recommended = avail_gb * safety_fraction
|
||||
result = max(min_gb, min(recommended, max_gb))
|
||||
return result
|
||||
except ImportError:
|
||||
return min_gb
|
||||
"""Pick a chunk-memory budget (in GB) for sparse conv batching."""
|
||||
free_gb = comfy.model_management.get_free_memory(device) / (1024 ** 3)
|
||||
return max(min_gb, min(free_gb * safety_fraction, max_gb))
|
||||
|
||||
def sparse_submanifold_conv3d(
|
||||
feats: torch.Tensor,
|
||||
@ -179,24 +133,16 @@ def sparse_submanifold_conv3d(
|
||||
bias: Optional[torch.Tensor],
|
||||
neighbor_cache: Optional[torch.Tensor],
|
||||
dilation: tuple,
|
||||
max_chunk_mem_gb: float = 6.0,
|
||||
accumulate_f32: bool = True,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
|
||||
if feats.shape[0] == 0:
|
||||
Co = weight.shape[0]
|
||||
return torch.empty((0, Co), device=feats.device, dtype=feats.dtype), None
|
||||
|
||||
if len(shape) == 5:
|
||||
_, _, W, H, D = shape
|
||||
else:
|
||||
W, H, D = shape
|
||||
W, H, D = shape
|
||||
|
||||
Co, Kw, Kh, Kd, Ci = weight.shape
|
||||
V = Kw * Kh * Kd
|
||||
device = feats.device
|
||||
sentinel = -1
|
||||
max_chunk_mem_gb = get_recommended_chunk_mem(device)
|
||||
|
||||
if neighbor_cache is None:
|
||||
b_stride = W * H * D
|
||||
@ -219,91 +165,37 @@ def sparse_submanifold_conv3d(
|
||||
neighbor = neighbor_cache
|
||||
|
||||
N_pts = feats.shape[0]
|
||||
sentinel = -1
|
||||
|
||||
if accumulate_f32:
|
||||
weight_T = weight.view(Co, V * Ci).to(torch.float32).T.contiguous()
|
||||
else:
|
||||
weight_T = weight.view(Co, V * Ci).to(feats.dtype).T.contiguous()
|
||||
weight_T = weight.view(Co, V * Ci).T
|
||||
|
||||
output = torch.empty(N_pts, Co, device=device, dtype=feats.dtype)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Chunk size from memory budget
|
||||
# ------------------------------------------------------------------
|
||||
bytes_per_elem = 4 if accumulate_f32 else feats.element_size()
|
||||
mem_per_row = V * Ci * bytes_per_elem
|
||||
# Chunk size from memory budget. The dominant peak is `gathered`, of shape (chunk, V, Ci) in feats.dtype.
|
||||
max_chunk_mem_gb = get_recommended_chunk_mem(device)
|
||||
mem_per_row = V * Ci * feats.element_size()
|
||||
max_chunk_mem = max_chunk_mem_gb * (1024 ** 3)
|
||||
chunk_size = max(1, int(max_chunk_mem / mem_per_row))
|
||||
chunk_size = min(chunk_size, N_pts)
|
||||
|
||||
# fp32 matmul scratch — sized to the largest chunk, reused each iteration.
|
||||
chunk_buf = torch.empty(chunk_size, Co, device=device, dtype=torch.float32) if accumulate_f32 else None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Chunked forward pass
|
||||
# Each iteration:
|
||||
# 1. gather (chunk, V, Ci) – memory bound
|
||||
# 2. mask zero invalids – in-place, no extra alloc
|
||||
# 3. reshape (chunk, V*Ci)
|
||||
# 4. GEMM (chunk, V*Ci) @ (V*Ci, Co) → (chunk, Co) – cuBLAS
|
||||
# written into the scratch buf (fp32) or output slice (fp16) via out=
|
||||
# 5. (fp32 path) cast scratch chunk to fp16 and copy into output slice
|
||||
# ------------------------------------------------------------------
|
||||
for start in range(0, N_pts, chunk_size):
|
||||
end = min(start + chunk_size, N_pts)
|
||||
actual_chunk = end - start
|
||||
|
||||
# (chunk, V) int32
|
||||
chunk_neighbor = neighbor[start:end]
|
||||
chunk_valid = chunk_neighbor != sentinel
|
||||
# clamp(-1 -> 0) keeps invalid indices in-range so the gather is safe
|
||||
chunk_idx = chunk_neighbor.clamp(min=0)
|
||||
|
||||
# Clamp sentinel -1 → 0 for safe indexing. No clone of the full map.
|
||||
chunk_idx = chunk_neighbor.clamp(min=0).long()
|
||||
|
||||
# Gather: (chunk, V, Ci). Memory-bound, single index_select.
|
||||
# (chunk, V, Ci) gather, then in-place zero of invalid neighbors.
|
||||
gathered = feats[chunk_idx]
|
||||
|
||||
# Zero invalid neighbours in-place. gathered is a fresh tensor from
|
||||
# advanced indexing, so in-place mutation is safe.
|
||||
gathered.mul_(chunk_valid.unsqueeze(-1))
|
||||
|
||||
# Reshape to (chunk, V*Ci)
|
||||
# GEMM (chunk, V*Ci) @ (V*Ci, Co) -> (chunk, Co), written to output[start:end].
|
||||
gathered_flat = gathered.view(actual_chunk, V * Ci)
|
||||
if accumulate_f32:
|
||||
gathered_flat = gathered_flat.to(torch.float32)
|
||||
torch.matmul(gathered_flat, weight_T, out=chunk_buf[:actual_chunk])
|
||||
output[start:end] = chunk_buf[:actual_chunk].to(feats.dtype)
|
||||
else:
|
||||
torch.matmul(gathered_flat, weight_T, out=output[start:end])
|
||||
torch.matmul(gathered_flat, weight_T, out=output[start:end])
|
||||
|
||||
if bias is not None:
|
||||
output += bias.unsqueeze(0).to(output.dtype)
|
||||
|
||||
return output, neighbor
|
||||
|
||||
class Mesh:
|
||||
def __init__(self,
|
||||
vertices,
|
||||
faces,
|
||||
vertex_attrs=None
|
||||
):
|
||||
self.vertices = vertices.float()
|
||||
self.faces = faces.int()
|
||||
self.vertex_attrs = vertex_attrs
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self.vertices.device
|
||||
|
||||
def to(self, device, non_blocking=False):
|
||||
return Mesh(
|
||||
self.vertices.to(device, non_blocking=non_blocking),
|
||||
self.faces.to(device, non_blocking=non_blocking),
|
||||
self.vertex_attrs.to(device, non_blocking=non_blocking) if self.vertex_attrs is not None else None,
|
||||
)
|
||||
|
||||
def cuda(self, non_blocking=False):
|
||||
return self.to('cuda', non_blocking=non_blocking)
|
||||
|
||||
def cpu(self):
|
||||
return self.to('cpu')
|
||||
|
||||
@ -1,12 +1,13 @@
|
||||
import math
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from fractions import Fraction
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Any, Dict, Optional, overload, Union, Tuple
|
||||
from comfy.ldm.trellis2.flexgemm import TorchHashMap, Mesh, sparse_submanifold_conv3d
|
||||
from typing import List, Any, Dict, Optional, overload, Union
|
||||
import comfy.ops
|
||||
from comfy.ldm.trellis2.flexgemm import TorchHashMap, sparse_submanifold_conv3d
|
||||
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
|
||||
def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor:
|
||||
@ -73,10 +74,10 @@ def sparse_conv3d_forward(self, x):
|
||||
out = x.replace(out)
|
||||
return out
|
||||
|
||||
class LayerNorm32(nn.LayerNorm):
|
||||
class LayerNorm32(ops.LayerNorm):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
w = self.weight.to(x.dtype) if self.weight is not None else None
|
||||
b = self.bias.to(x.dtype) if self.bias is not None else None
|
||||
w = self.weight.to(x) if self.weight is not None else None
|
||||
b = self.bias.to(x) if self.bias is not None else None
|
||||
return F.layer_norm(x, self.normalized_shape, w, b, self.eps)
|
||||
|
||||
class SparseConvNeXtBlock3d(nn.Module):
|
||||
@ -93,12 +94,13 @@ class SparseConvNeXtBlock3d(nn.Module):
|
||||
self.norm = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
|
||||
self.conv = SparseConv3d(channels, channels, 3)
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(channels, int(channels * mlp_ratio)),
|
||||
ops.Linear(channels, int(channels * mlp_ratio)),
|
||||
nn.SiLU(inplace=True),
|
||||
nn.Linear(int(channels * mlp_ratio), channels),
|
||||
ops.Linear(int(channels * mlp_ratio), channels),
|
||||
)
|
||||
|
||||
def _forward(self, x):
|
||||
x = x.to(dtype=self.conv.weight.dtype, device=self.conv.weight.device)
|
||||
h = self.conv(x)
|
||||
h = h.replace(self.norm(h.feats))
|
||||
h = h.replace(self.mlp(h.feats))
|
||||
@ -141,7 +143,7 @@ class SparseSpatial2Channel(nn.Module):
|
||||
|
||||
out = SparseTensor(new_feats.reshape(new_coords.shape[0], -1), new_coords, None if x._shape is None else torch.Size([x._shape[0], x._shape[1] * self.factor ** DIM]))
|
||||
out._scale = tuple([s * self.factor for s in x._scale])
|
||||
out._spatial_cache = x._spatial_cache
|
||||
out._spatial_cache = dict(x._spatial_cache)
|
||||
|
||||
if cache is None:
|
||||
x.register_spatial_cache(f'spatial2channel_{self.factor}', (new_coords, idx, subidx))
|
||||
@ -180,7 +182,7 @@ class SparseChannel2Spatial(nn.Module):
|
||||
out = SparseTensor(new_feats, new_coords, None if x._shape is None else torch.Size([x._shape[0], x._shape[1] // self.factor ** DIM]))
|
||||
out._scale = tuple([s / self.factor for s in x._scale])
|
||||
if cache is not None: # only keep cache when subdiv following it
|
||||
out._spatial_cache = x._spatial_cache
|
||||
out._spatial_cache = dict(x._spatial_cache)
|
||||
return out
|
||||
|
||||
class SparseResBlockC2S3d(nn.Module):
|
||||
@ -226,10 +228,6 @@ class SparseResBlockC2S3d(nn.Module):
|
||||
else:
|
||||
return h
|
||||
|
||||
@dataclass
|
||||
class config:
|
||||
CONV = "flexgemm"
|
||||
FLEX_GEMM_HASHMAP_RATIO = 2.0
|
||||
|
||||
class VarLenTensor:
|
||||
|
||||
@ -238,18 +236,6 @@ class VarLenTensor:
|
||||
self.layout = layout if layout is not None else [slice(0, feats.shape[0])]
|
||||
self._cache = {}
|
||||
|
||||
@staticmethod
|
||||
def layout_from_seqlen(seqlen: list) -> List[slice]:
|
||||
"""
|
||||
Create a layout from a tensor of sequence lengths.
|
||||
"""
|
||||
layout = []
|
||||
start = 0
|
||||
for l in seqlen:
|
||||
layout.append(slice(start, start + l))
|
||||
start += l
|
||||
return layout
|
||||
|
||||
@staticmethod
|
||||
def from_tensor_list(tensor_list: List[torch.Tensor]) -> 'VarLenTensor':
|
||||
"""
|
||||
@ -376,29 +362,22 @@ class VarLenTensor:
|
||||
feats=feats,
|
||||
layout=self.layout,
|
||||
)
|
||||
new_tensor._cache = self._cache
|
||||
# Shallow-copy so derived tensors don't share-by-reference the cache
|
||||
# dict — see SparseTensor.replace for rationale.
|
||||
new_tensor._cache = dict(self._cache)
|
||||
return new_tensor
|
||||
|
||||
def to_dense(self, max_length=None) -> torch.Tensor:
|
||||
N = len(self)
|
||||
L = max_length or self.seqlen.max().item()
|
||||
spatial = self.feats.shape[1:]
|
||||
idx = torch.arange(L, device=self.device).unsqueeze(0).expand(N, L)
|
||||
mask = (idx < self.seqlen.unsqueeze(1))
|
||||
mapping = mask.reshape(-1).cumsum(dim=0) - 1
|
||||
dense = self.feats[mapping]
|
||||
dense = dense.reshape(N, L, *spatial)
|
||||
return dense, mask
|
||||
|
||||
def __neg__(self) -> 'VarLenTensor':
|
||||
return self.replace(-self.feats)
|
||||
|
||||
def __elemwise__(self, other: Union[torch.Tensor, 'VarLenTensor'], op: callable) -> 'VarLenTensor':
|
||||
if isinstance(other, torch.Tensor):
|
||||
# Try per-batch [B, C] -> per-token [T, C] broadcast. RuntimeError
|
||||
# fires for incompatible shapes; fall through and let op() handle.
|
||||
try:
|
||||
other = torch.broadcast_to(other, self.shape)
|
||||
other = other[self.batch_boardcast_map]
|
||||
except:
|
||||
except RuntimeError:
|
||||
pass
|
||||
if isinstance(other, VarLenTensor):
|
||||
other = other.feats
|
||||
@ -459,40 +438,6 @@ class VarLenTensor:
|
||||
new_tensor = VarLenTensor(feats=new_feats, layout=new_layout)
|
||||
return new_tensor
|
||||
|
||||
def reduce(self, op: str, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor:
|
||||
if isinstance(dim, int):
|
||||
dim = (dim,)
|
||||
|
||||
if op =='mean':
|
||||
red = self.feats.mean(dim=dim, keepdim=keepdim)
|
||||
elif op =='sum':
|
||||
red = self.feats.sum(dim=dim, keepdim=keepdim)
|
||||
elif op == 'prod':
|
||||
red = self.feats.prod(dim=dim, keepdim=keepdim)
|
||||
else:
|
||||
raise ValueError(f"Unsupported reduce operation: {op}")
|
||||
|
||||
if dim is None or 0 in dim:
|
||||
return red
|
||||
|
||||
red = torch.segment_reduce(red, reduce=op, lengths=self.seqlen)
|
||||
return red
|
||||
|
||||
def mean(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor:
|
||||
return self.reduce(op='mean', dim=dim, keepdim=keepdim)
|
||||
|
||||
def sum(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor:
|
||||
return self.reduce(op='sum', dim=dim, keepdim=keepdim)
|
||||
|
||||
def prod(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor:
|
||||
return self.reduce(op='prod', dim=dim, keepdim=keepdim)
|
||||
|
||||
def std(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor:
|
||||
mean = self.mean(dim=dim, keepdim=True)
|
||||
mean2 = self.replace(self.feats ** 2).mean(dim=dim, keepdim=True)
|
||||
std = (mean2 - mean ** 2).sqrt()
|
||||
return std
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"VarLenTensor(shape={self.shape}, dtype={self.dtype}, device={self.device})"
|
||||
|
||||
@ -507,8 +452,6 @@ def varlen_unbind(input: VarLenTensor, dim: int) -> Union[List[VarLenTensor]]:
|
||||
|
||||
class SparseTensor(VarLenTensor):
|
||||
|
||||
SparseTensorData = None
|
||||
|
||||
@overload
|
||||
def __init__(self, feats: torch.Tensor, coords: torch.Tensor, shape: Optional[torch.Size] = None, **kwargs): ...
|
||||
|
||||
@ -516,14 +459,6 @@ class SparseTensor(VarLenTensor):
|
||||
def __init__(self, data, shape: Optional[torch.Size] = None, **kwargs): ...
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
# Lazy import of sparse tensor backend
|
||||
if self.SparseTensorData is None:
|
||||
import importlib
|
||||
if config.CONV == 'torchsparse':
|
||||
self.SparseTensorData = importlib.import_module('torchsparse').SparseTensor
|
||||
elif config.CONV == 'spconv':
|
||||
self.SparseTensorData = importlib.import_module('spconv.pytorch').SparseConvTensor
|
||||
|
||||
method_id = 0
|
||||
if len(args) != 0:
|
||||
method_id = 0 if isinstance(args[0], torch.Tensor) else 1
|
||||
@ -542,17 +477,10 @@ class SparseTensor(VarLenTensor):
|
||||
shape = kwargs['shape']
|
||||
del kwargs['shape']
|
||||
|
||||
if config.CONV == 'torchsparse':
|
||||
self.data = self.SparseTensorData(feats, coords, **kwargs)
|
||||
elif config.CONV == 'spconv':
|
||||
spatial_shape = list(coords.max(0)[0] + 1)
|
||||
self.data = self.SparseTensorData(feats.reshape(feats.shape[0], -1), coords, spatial_shape[1:], spatial_shape[0], **kwargs)
|
||||
self.data._features = feats
|
||||
else:
|
||||
self.data = {
|
||||
'feats': feats,
|
||||
'coords': coords,
|
||||
}
|
||||
self.data = {
|
||||
'feats': feats,
|
||||
'coords': coords,
|
||||
}
|
||||
elif method_id == 1:
|
||||
data, shape = args + (None,) * (2 - len(args))
|
||||
if 'data' in kwargs:
|
||||
@ -581,17 +509,6 @@ class SparseTensor(VarLenTensor):
|
||||
coords = torch.cat(coords, dim=0)
|
||||
return SparseTensor(feats, coords)
|
||||
|
||||
def to_tensor_list(self) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
||||
"""
|
||||
Convert a SparseTensor to list of tensors.
|
||||
"""
|
||||
feats_list = []
|
||||
coords_list = []
|
||||
for s in self.layout:
|
||||
feats_list.append(self.feats[s])
|
||||
coords_list.append(self.coords[s])
|
||||
return feats_list, coords_list
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.layout)
|
||||
|
||||
@ -634,39 +551,19 @@ class SparseTensor(VarLenTensor):
|
||||
|
||||
@property
|
||||
def feats(self) -> torch.Tensor:
|
||||
if config.CONV == 'torchsparse':
|
||||
return self.data.F
|
||||
elif config.CONV == 'spconv':
|
||||
return self.data.features
|
||||
else:
|
||||
return self.data['feats']
|
||||
return self.data['feats']
|
||||
|
||||
@feats.setter
|
||||
def feats(self, value: torch.Tensor):
|
||||
if config.CONV == 'torchsparse':
|
||||
self.data.F = value
|
||||
elif config.CONV == 'spconv':
|
||||
self.data.features = value
|
||||
else:
|
||||
self.data['feats'] = value
|
||||
self.data['feats'] = value
|
||||
|
||||
@property
|
||||
def coords(self) -> torch.Tensor:
|
||||
if config.CONV == 'torchsparse':
|
||||
return self.data.C
|
||||
elif config.CONV == 'spconv':
|
||||
return self.data.indices
|
||||
else:
|
||||
return self.data['coords']
|
||||
return self.data['coords']
|
||||
|
||||
@coords.setter
|
||||
def coords(self, value: torch.Tensor):
|
||||
if config.CONV == 'torchsparse':
|
||||
self.data.C = value
|
||||
elif config.CONV == 'spconv':
|
||||
self.data.indices = value
|
||||
else:
|
||||
self.data['coords'] = value
|
||||
self.data['coords'] = value
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
@ -773,71 +670,20 @@ class SparseTensor(VarLenTensor):
|
||||
return sparse_unbind(self, dim)
|
||||
|
||||
def replace(self, feats: torch.Tensor, coords: Optional[torch.Tensor] = None) -> 'SparseTensor':
|
||||
if config.CONV == 'torchsparse':
|
||||
new_data = self.SparseTensorData(
|
||||
feats=feats,
|
||||
coords=self.data.coords if coords is None else coords,
|
||||
stride=self.data.stride,
|
||||
spatial_range=self.data.spatial_range,
|
||||
)
|
||||
new_data._caches = self.data._caches
|
||||
elif config.CONV == 'spconv':
|
||||
new_data = self.SparseTensorData(
|
||||
self.data.features.reshape(self.data.features.shape[0], -1),
|
||||
self.data.indices,
|
||||
self.data.spatial_shape,
|
||||
self.data.batch_size,
|
||||
self.data.grid,
|
||||
self.data.voxel_num,
|
||||
self.data.indice_dict
|
||||
)
|
||||
new_data._features = feats
|
||||
new_data.benchmark = self.data.benchmark
|
||||
new_data.benchmark_record = self.data.benchmark_record
|
||||
new_data.thrust_allocator = self.data.thrust_allocator
|
||||
new_data._timer = self.data._timer
|
||||
new_data.force_algo = self.data.force_algo
|
||||
new_data.int8_scale = self.data.int8_scale
|
||||
if coords is not None:
|
||||
new_data.indices = coords
|
||||
else:
|
||||
new_data = {
|
||||
'feats': feats,
|
||||
'coords': self.data['coords'] if coords is None else coords,
|
||||
}
|
||||
new_tensor = SparseTensor(
|
||||
new_data = {
|
||||
'feats': feats,
|
||||
'coords': self.data['coords'] if coords is None else coords,
|
||||
}
|
||||
return SparseTensor(
|
||||
new_data,
|
||||
shape=torch.Size([self._shape[0]] + list(feats.shape[1:])) if self._shape is not None else None,
|
||||
scale=self._scale,
|
||||
spatial_cache=self._spatial_cache
|
||||
# Shallow-copy the cache: each derived tensor gets its own dict, so
|
||||
# adding/overwriting entries on one doesn't leak to siblings.
|
||||
# Cached tensors themselves are still shared by reference (safe
|
||||
# because they're read-only after populate).
|
||||
spatial_cache=dict(self._spatial_cache),
|
||||
)
|
||||
return new_tensor
|
||||
|
||||
def to_dense(self) -> torch.Tensor:
|
||||
if config.CONV == 'torchsparse':
|
||||
return self.data.dense()
|
||||
elif config.CONV == 'spconv':
|
||||
return self.data.dense()
|
||||
else:
|
||||
spatial_shape = self.spatial_shape
|
||||
ret = torch.zeros(*self.shape, *spatial_shape, dtype=self.dtype, device=self.device)
|
||||
idx = [self.coords[:, 0], slice(None)] + self.coords[:, 1:].unbind(1)
|
||||
ret[tuple(idx)] = self.feats
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def full(aabb, dim, value, dtype=torch.float32, device=None) -> 'SparseTensor':
|
||||
N, C = dim
|
||||
x = torch.arange(aabb[0], aabb[3] + 1)
|
||||
y = torch.arange(aabb[1], aabb[4] + 1)
|
||||
z = torch.arange(aabb[2], aabb[5] + 1)
|
||||
coords = torch.stack(torch.meshgrid(x, y, z, indexing='ij'), dim=-1).reshape(-1, 3)
|
||||
coords = torch.cat([
|
||||
torch.arange(N).view(-1, 1).repeat(1, coords.shape[0]).view(-1, 1),
|
||||
coords.repeat(N, 1),
|
||||
], dim=1).to(dtype=torch.int32, device=device)
|
||||
feats = torch.full((coords.shape[0], C), value, dtype=dtype, device=device)
|
||||
return SparseTensor(feats=feats, coords=coords)
|
||||
|
||||
def __merge_sparse_cache(self, other: 'SparseTensor') -> dict:
|
||||
new_cache = {}
|
||||
@ -853,10 +699,12 @@ class SparseTensor(VarLenTensor):
|
||||
|
||||
def __elemwise__(self, other: Union[torch.Tensor, VarLenTensor], op: callable) -> 'SparseTensor':
|
||||
if isinstance(other, torch.Tensor):
|
||||
# Try per-batch [B, C] -> per-voxel [N, C] broadcast. RuntimeError
|
||||
# fires for incompatible shapes; fall through and let op() handle.
|
||||
try:
|
||||
other = torch.broadcast_to(other, self.shape)
|
||||
other = other[self.batch_boardcast_map]
|
||||
except:
|
||||
except RuntimeError:
|
||||
pass
|
||||
if isinstance(other, VarLenTensor):
|
||||
other = other.feats
|
||||
@ -901,12 +749,6 @@ class SparseTensor(VarLenTensor):
|
||||
new_tensor.register_spatial_cache('layout', new_layout)
|
||||
return new_tensor
|
||||
|
||||
def clear_spatial_cache(self) -> None:
|
||||
"""
|
||||
Clear all spatial caches.
|
||||
"""
|
||||
self._spatial_cache = {}
|
||||
|
||||
def register_spatial_cache(self, key, value) -> None:
|
||||
"""
|
||||
Register a spatial cache.
|
||||
@ -961,7 +803,7 @@ def sparse_unbind(input: SparseTensor, dim: int) -> List[SparseTensor]:
|
||||
|
||||
# allow operations.Linear inheritance
|
||||
class SparseLinear:
|
||||
def __new__(cls, in_features, out_features, bias=True, device=None, dtype=None, operations=nn, *args, **kwargs):
|
||||
def __new__(cls, in_features, out_features, bias=True, device=None, dtype=None, operations=ops, *args, **kwargs):
|
||||
class _SparseLinear(operations.Linear):
|
||||
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
|
||||
super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
|
||||
@ -971,24 +813,6 @@ class SparseLinear:
|
||||
|
||||
return _SparseLinear(in_features, out_features, bias=bias, device=device, dtype=dtype, *args, **kwargs)
|
||||
|
||||
MIX_PRECISION_MODULES = (
|
||||
nn.Conv1d,
|
||||
nn.Conv2d,
|
||||
nn.Conv3d,
|
||||
nn.ConvTranspose1d,
|
||||
nn.ConvTranspose2d,
|
||||
nn.ConvTranspose3d,
|
||||
nn.Linear,
|
||||
SparseConv3d,
|
||||
SparseLinear,
|
||||
)
|
||||
|
||||
|
||||
def convert_module_to_f16(l):
|
||||
if isinstance(l, MIX_PRECISION_MODULES):
|
||||
for p in l.parameters():
|
||||
p.data = p.data.half()
|
||||
|
||||
class SparseUnetVaeDecoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@ -999,17 +823,13 @@ class SparseUnetVaeDecoder(nn.Module):
|
||||
block_type: List[str],
|
||||
up_block_type: List[str],
|
||||
block_args: List[Dict[str, Any]],
|
||||
use_fp16: bool = False,
|
||||
pred_subdiv: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = out_channels
|
||||
self.model_channels = model_channels
|
||||
self.num_blocks = num_blocks
|
||||
self.use_fp16 = use_fp16
|
||||
self.pred_subdiv = pred_subdiv
|
||||
self.dtype = torch.float16 if use_fp16 else torch.float32
|
||||
self.low_vram = False
|
||||
|
||||
self.output_layer = SparseLinear(model_channels[-1], out_channels)
|
||||
self.from_latent = SparseLinear(latent_channels, model_channels[0])
|
||||
@ -1033,17 +853,9 @@ class SparseUnetVaeDecoder(nn.Module):
|
||||
**block_args[i],
|
||||
)
|
||||
)
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return next(self.parameters()).device
|
||||
|
||||
def forward(self, x: SparseTensor, guide_subs: Optional[List[SparseTensor]] = None, return_subs: bool = False) -> SparseTensor:
|
||||
|
||||
dtype = next(self.from_latent.parameters()).dtype
|
||||
device = next(self.from_latent.parameters()).device
|
||||
x.feats = x.feats.to(dtype).to(device)
|
||||
h = self.from_latent(x)
|
||||
h = h.type(self.dtype)
|
||||
subs = []
|
||||
for i, res in enumerate(self.blocks):
|
||||
for j, block in enumerate(res):
|
||||
@ -1055,7 +867,6 @@ class SparseUnetVaeDecoder(nn.Module):
|
||||
h = block(h, subdiv=guide_subs[i] if guide_subs is not None else None)
|
||||
else:
|
||||
h = block(h)
|
||||
h = h.type(x.feats.dtype)
|
||||
h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
|
||||
h = self.output_layer(h)
|
||||
if return_subs:
|
||||
@ -1064,9 +875,7 @@ class SparseUnetVaeDecoder(nn.Module):
|
||||
return h
|
||||
|
||||
def upsample(self, x: SparseTensor, upsample_times: int) -> torch.Tensor:
|
||||
|
||||
h = self.from_latent(x)
|
||||
h = h.type(self.dtype)
|
||||
for i, res in enumerate(self.blocks):
|
||||
if i == upsample_times:
|
||||
return h.coords
|
||||
@ -1087,13 +896,9 @@ class FlexiDualGridVaeDecoder(SparseUnetVaeDecoder):
|
||||
up_block_type: List[str],
|
||||
block_args: List[Dict[str, Any]],
|
||||
voxel_margin: float = 0.5,
|
||||
use_fp16: bool = False,
|
||||
):
|
||||
self.resolution = resolution
|
||||
self.voxel_margin = voxel_margin
|
||||
# cache for a TorchHashMap instance
|
||||
self._torch_hashmap_cache = None
|
||||
|
||||
super().__init__(
|
||||
7,
|
||||
model_channels,
|
||||
@ -1102,22 +907,11 @@ class FlexiDualGridVaeDecoder(SparseUnetVaeDecoder):
|
||||
block_type,
|
||||
up_block_type,
|
||||
block_args,
|
||||
use_fp16,
|
||||
)
|
||||
|
||||
def set_resolution(self, resolution: int) -> None:
|
||||
self.resolution = resolution
|
||||
|
||||
def _build_or_get_hashmap(self, coords: torch.Tensor, grid_size: torch.Tensor):
|
||||
device = coords.device
|
||||
N = coords.shape[0]
|
||||
_, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item())
|
||||
flat_keys = coords[:, 0].long() * (H * D)
|
||||
flat_keys.add_(coords[:, 1].long() * D)
|
||||
flat_keys.add_(coords[:, 2].long())
|
||||
values = torch.arange(N, dtype=torch.int32, device=device)
|
||||
return TorchHashMap(flat_keys, values, 0xffffffff)
|
||||
|
||||
def forward(self, x: SparseTensor, gt_intersected: SparseTensor = None, **kwargs):
|
||||
decoded = super().forward(x, **kwargs)
|
||||
out_list = list(decoded) if isinstance(decoded, tuple) else [decoded]
|
||||
@ -1125,13 +919,11 @@ class FlexiDualGridVaeDecoder(SparseUnetVaeDecoder):
|
||||
vertices = h.replace((1 + 2 * self.voxel_margin) * F.sigmoid(h.feats[..., 0:3]) - self.voxel_margin)
|
||||
intersected = h.replace(h.feats[..., 3:6] > 0)
|
||||
quad_lerp = h.replace(F.softplus(h.feats[..., 6:7]))
|
||||
mesh = [Mesh(*flexible_dual_grid_to_mesh(
|
||||
mesh = [flexible_dual_grid_to_mesh(
|
||||
v.coords[:, 1:], v.feats, i.feats, q.feats,
|
||||
aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
|
||||
grid_size=self.resolution,
|
||||
train=False,
|
||||
hashmap_builder=self._build_or_get_hashmap,
|
||||
)) for v, i, q in zip(vertices, intersected, quad_lerp)]
|
||||
) for v, i, q in zip(vertices, intersected, quad_lerp)]
|
||||
out_list[0] = mesh
|
||||
return out_list[0] if len(out_list) == 1 else tuple(out_list)
|
||||
|
||||
@ -1140,11 +932,9 @@ def flexible_dual_grid_to_mesh(
|
||||
dual_vertices: torch.Tensor,
|
||||
intersected_flag: torch.Tensor,
|
||||
split_weight: Union[torch.Tensor, None],
|
||||
aabb: Union[list, tuple, np.ndarray, torch.Tensor],
|
||||
voxel_size: Union[float, list, tuple, np.ndarray, torch.Tensor] = None,
|
||||
grid_size: Union[int, list, tuple, np.ndarray, torch.Tensor] = None,
|
||||
train: bool = False,
|
||||
hashmap_builder=None, # optional callable for building/caching a TorchHashMap
|
||||
aabb: Union[list, tuple, torch.Tensor],
|
||||
voxel_size: Union[float, list, tuple, torch.Tensor] = None,
|
||||
grid_size: Union[int, list, tuple, torch.Tensor] = None,
|
||||
):
|
||||
|
||||
device = coords.device
|
||||
@ -1159,46 +949,28 @@ def flexible_dual_grid_to_mesh(
|
||||
flexible_dual_grid_to_mesh.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=device, requires_grad=False)
|
||||
if not hasattr(flexible_dual_grid_to_mesh, "quad_split_2") or flexible_dual_grid_to_mesh.quad_split_2.device != device:
|
||||
flexible_dual_grid_to_mesh.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=device, requires_grad=False)
|
||||
if not hasattr(flexible_dual_grid_to_mesh, "quad_split_train") or flexible_dual_grid_to_mesh.quad_split_train.device != device:
|
||||
flexible_dual_grid_to_mesh.quad_split_train = torch.tensor([0, 1, 4, 1, 2, 4, 2, 3, 4, 3, 0, 4], dtype=torch.long, device=device, requires_grad=False)
|
||||
|
||||
# AABB
|
||||
if isinstance(aabb, (list, tuple)):
|
||||
aabb = np.array(aabb)
|
||||
if isinstance(aabb, np.ndarray):
|
||||
aabb = torch.tensor(aabb, dtype=torch.float32, device=device)
|
||||
aabb = torch.tensor(aabb, dtype=torch.float32, device=device)
|
||||
|
||||
# Voxel size
|
||||
if voxel_size is not None:
|
||||
if isinstance(voxel_size, float):
|
||||
voxel_size = [voxel_size, voxel_size, voxel_size]
|
||||
if isinstance(voxel_size, (list, tuple)):
|
||||
voxel_size = np.array(voxel_size)
|
||||
if isinstance(voxel_size, np.ndarray):
|
||||
voxel_size = torch.tensor(voxel_size, dtype=torch.float32, device=coords.device)
|
||||
voxel_size = [voxel_size] * 3
|
||||
voxel_size = torch.tensor(voxel_size, dtype=torch.float32, device=device)
|
||||
grid_size = ((aabb[1] - aabb[0]) / voxel_size).round().int()
|
||||
else:
|
||||
if isinstance(grid_size, int):
|
||||
grid_size = [grid_size, grid_size, grid_size]
|
||||
if isinstance(grid_size, (list, tuple)):
|
||||
grid_size = np.array(grid_size)
|
||||
if isinstance(grid_size, np.ndarray):
|
||||
grid_size = torch.tensor(grid_size, dtype=torch.int32, device=coords.device)
|
||||
grid_size = [grid_size] * 3
|
||||
grid_size = torch.tensor(grid_size, dtype=torch.int32, device=device)
|
||||
voxel_size = (aabb[1] - aabb[0]) / grid_size
|
||||
|
||||
# Extract mesh
|
||||
N = dual_vertices.shape[0]
|
||||
|
||||
if hashmap_builder is None:
|
||||
device = coords.device
|
||||
_, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item())
|
||||
flat_keys = coords[:, 0].long() * (H * D)
|
||||
flat_keys.add_(coords[:, 1].long() * D)
|
||||
flat_keys.add_(coords[:, 2].long())
|
||||
values = torch.arange(N, dtype=torch.long, device=device)
|
||||
torch_hashmap = TorchHashMap(flat_keys, values, 0xffffffff)
|
||||
else:
|
||||
torch_hashmap = hashmap_builder(coords, grid_size)
|
||||
_, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item())
|
||||
flat_keys = coords[:, 0].long() * (H * D)
|
||||
flat_keys.add_(coords[:, 1].long() * D)
|
||||
flat_keys.add_(coords[:, 2].long())
|
||||
values = torch.arange(N, dtype=torch.int32, device=device)
|
||||
torch_hashmap = TorchHashMap(flat_keys, values, 0xffffffff)
|
||||
|
||||
# Find connected voxels — direct gather instead of materializing the full [N, 3, 4, 3]
|
||||
n_idx, axis_idx = intersected_flag.nonzero(as_tuple=True) # (M,), (M,)
|
||||
@ -1253,55 +1025,34 @@ class ChannelLayerNorm32(LayerNorm32):
|
||||
return x
|
||||
|
||||
class UpsampleBlock3d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
mode = "conv",
|
||||
):
|
||||
assert mode in ["conv", "nearest"], f"Invalid mode {mode}"
|
||||
|
||||
def __init__(self, in_channels: int, out_channels: int):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
if mode == "conv":
|
||||
self.conv = nn.Conv3d(in_channels, out_channels*8, 3, padding=1)
|
||||
elif mode == "nearest":
|
||||
assert in_channels == out_channels, "Nearest mode requires in_channels to be equal to out_channels"
|
||||
self.conv = ops.Conv3d(in_channels, out_channels * 8, 3, padding=1)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if hasattr(self, "conv"):
|
||||
x = self.conv(x)
|
||||
return pixel_shuffle_3d(x, 2)
|
||||
else:
|
||||
return F.interpolate(x, scale_factor=2, mode="nearest")
|
||||
|
||||
def norm_layer(norm_type: str, *args, **kwargs) -> nn.Module:
|
||||
return ChannelLayerNorm32(*args, **kwargs)
|
||||
return pixel_shuffle_3d(self.conv(x), 2)
|
||||
|
||||
class ResBlock3d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
out_channels: Optional[int] = None,
|
||||
norm_type = "layer",
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
|
||||
self.norm1 = norm_layer(norm_type, channels)
|
||||
self.norm2 = norm_layer(norm_type, self.out_channels)
|
||||
self.conv1 = nn.Conv3d(channels, self.out_channels, 3, padding=1)
|
||||
self.conv2 = nn.Conv3d(self.out_channels, self.out_channels, 3, padding=1)
|
||||
self.skip_connection = nn.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity()
|
||||
self.norm1 = ChannelLayerNorm32(channels)
|
||||
self.norm2 = ChannelLayerNorm32(self.out_channels)
|
||||
self.conv1 = ops.Conv3d(channels, self.out_channels, 3, padding=1)
|
||||
self.conv2 = ops.Conv3d(self.out_channels, self.out_channels, 3, padding=1)
|
||||
self.skip_connection = ops.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
h = self.norm1(x)
|
||||
h = F.silu(h)
|
||||
dtype = next(self.conv1.parameters()).dtype
|
||||
h = h.to(dtype)
|
||||
h = self.conv1(h)
|
||||
h = self.norm2(h)
|
||||
h = F.silu(h)
|
||||
@ -1318,8 +1069,6 @@ class SparseStructureDecoder(nn.Module):
|
||||
num_res_blocks: int,
|
||||
channels: List[int],
|
||||
num_res_blocks_middle: int = 2,
|
||||
norm_type = "layer",
|
||||
use_fp16: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = out_channels
|
||||
@ -1327,11 +1076,8 @@ class SparseStructureDecoder(nn.Module):
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.channels = channels
|
||||
self.num_res_blocks_middle = num_res_blocks_middle
|
||||
self.norm_type = norm_type
|
||||
self.use_fp16 = use_fp16
|
||||
self.dtype = torch.float16 if use_fp16 else torch.float32
|
||||
|
||||
self.input_layer = nn.Conv3d(latent_channels, channels[0], 3, padding=1)
|
||||
self.input_layer = ops.Conv3d(latent_channels, channels[0], 3, padding=1)
|
||||
|
||||
self.middle_block = nn.Sequential(*[
|
||||
ResBlock3d(channels[0], channels[0])
|
||||
@ -1350,92 +1096,82 @@ class SparseStructureDecoder(nn.Module):
|
||||
)
|
||||
|
||||
self.out_layer = nn.Sequential(
|
||||
norm_layer(norm_type, channels[-1]),
|
||||
ChannelLayerNorm32(channels[-1]),
|
||||
nn.SiLU(),
|
||||
nn.Conv3d(channels[-1], out_channels, 3, padding=1)
|
||||
ops.Conv3d(channels[-1], out_channels, 3, padding=1)
|
||||
)
|
||||
|
||||
if use_fp16:
|
||||
self.convert_to_fp16()
|
||||
|
||||
def device(self) -> torch.device:
|
||||
return next(self.parameters()).device
|
||||
|
||||
def convert_to_fp16(self) -> None:
|
||||
self.use_fp16 = True
|
||||
self.dtype = torch.float16
|
||||
self.blocks.apply(convert_module_to_f16)
|
||||
self.middle_block.apply(convert_module_to_f16)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
dtype = next(self.input_layer.parameters()).dtype
|
||||
x = x.to(dtype)
|
||||
h = self.input_layer(x)
|
||||
|
||||
h = h.type(self.dtype)
|
||||
h = self.middle_block(h)
|
||||
for block in self.blocks:
|
||||
h = block(h)
|
||||
|
||||
h = h.type(x.dtype)
|
||||
h = self.out_layer(h)
|
||||
return h
|
||||
|
||||
class Vae(nn.Module):
|
||||
def __init__(self, init_txt_model, init_txt_model_only, operations=None):
|
||||
|
||||
class ShapeVae(nn.Module):
|
||||
"""Decoder bundle from the Trellis2 shape checkpoint: structure decoder
|
||||
(32^3 SS latent → 64^3 dense occupancy) and shape decoder (sparse latent →
|
||||
mesh + per-stage subdivisions)."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
operations = operations or torch.nn
|
||||
if init_txt_model or init_txt_model_only:
|
||||
self.txt_dec = SparseUnetVaeDecoder(
|
||||
out_channels=6,
|
||||
model_channels=[1024, 512, 256, 128, 64],
|
||||
latent_channels=32,
|
||||
num_blocks=[4, 16, 8, 4, 0],
|
||||
block_type=["SparseConvNeXtBlock3d"] * 5,
|
||||
up_block_type=["SparseResBlockC2S3d"] * 4,
|
||||
block_args=[{}, {}, {}, {}, {}],
|
||||
pred_subdiv=False
|
||||
)
|
||||
|
||||
if not init_txt_model_only:
|
||||
self.shape_dec = FlexiDualGridVaeDecoder(
|
||||
resolution=256,
|
||||
model_channels=[1024, 512, 256, 128, 64],
|
||||
latent_channels=32,
|
||||
num_blocks=[4, 16, 8, 4, 0],
|
||||
block_type=["SparseConvNeXtBlock3d"] * 5,
|
||||
up_block_type=["SparseResBlockC2S3d"] * 4,
|
||||
block_args=[{}, {}, {}, {}, {}],
|
||||
)
|
||||
|
||||
self.struct_dec = SparseStructureDecoder(
|
||||
out_channels=1,
|
||||
latent_channels=8,
|
||||
num_res_blocks=2,
|
||||
num_res_blocks_middle=2,
|
||||
channels=[512, 128, 32],
|
||||
)
|
||||
self.shape_dec = FlexiDualGridVaeDecoder(
|
||||
resolution=256,
|
||||
model_channels=[1024, 512, 256, 128, 64],
|
||||
latent_channels=32,
|
||||
num_blocks=[4, 16, 8, 4, 0],
|
||||
block_type=["SparseConvNeXtBlock3d"] * 5,
|
||||
up_block_type=["SparseResBlockC2S3d"] * 4,
|
||||
block_args=[{}, {}, {}, {}, {}],
|
||||
)
|
||||
self.struct_dec = SparseStructureDecoder(
|
||||
out_channels=1,
|
||||
latent_channels=8,
|
||||
num_res_blocks=2,
|
||||
num_res_blocks_middle=2,
|
||||
channels=[512, 128, 32],
|
||||
)
|
||||
self.register_buffer("resolution", torch.tensor(1024.0), persistent=False)
|
||||
|
||||
@torch.no_grad()
|
||||
def decode_shape_slat(self, slat, resolution: int):
|
||||
def decode_structure(self, x: torch.Tensor) -> torch.Tensor:
|
||||
weight = self.struct_dec.input_layer.weight
|
||||
x = x.to(dtype=weight.dtype, device=weight.device)
|
||||
return self.struct_dec(x)
|
||||
|
||||
def decode_shape_slat(self, slat: 'SparseTensor', resolution: int):
|
||||
weight = self.shape_dec.from_latent.weight
|
||||
slat = slat.to(dtype=weight.dtype, device=weight.device)
|
||||
self.shape_dec.set_resolution(resolution)
|
||||
return self.shape_dec(slat, return_subs=True)
|
||||
|
||||
@torch.no_grad()
|
||||
def decode_tex_slat(self, slat, subs):
|
||||
if self.txt_dec is None:
|
||||
raise ValueError("Checkpoint doesn't include texture model")
|
||||
def upsample_shape(self, slat: 'SparseTensor', upsample_times: int) -> torch.Tensor:
|
||||
weight = self.shape_dec.from_latent.weight
|
||||
slat = slat.to(dtype=weight.dtype, device=weight.device)
|
||||
return self.shape_dec.upsample(slat, upsample_times)
|
||||
|
||||
|
||||
class TextureVae(nn.Module):
|
||||
"""Decoder bundle from the Trellis2 texture checkpoint: sparse 3D
|
||||
per-voxel color decoder, guided by subdivisions from a prior shape decode."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.txt_dec = SparseUnetVaeDecoder(
|
||||
out_channels=6,
|
||||
model_channels=[1024, 512, 256, 128, 64],
|
||||
latent_channels=32,
|
||||
num_blocks=[4, 16, 8, 4, 0],
|
||||
block_type=["SparseConvNeXtBlock3d"] * 5,
|
||||
up_block_type=["SparseResBlockC2S3d"] * 4,
|
||||
block_args=[{}, {}, {}, {}, {}],
|
||||
pred_subdiv=False,
|
||||
)
|
||||
self.register_buffer("resolution", torch.tensor(1024.0), persistent=False)
|
||||
|
||||
def decode_tex_slat(self, slat: 'SparseTensor', subs):
|
||||
weight = self.txt_dec.from_latent.weight
|
||||
slat = slat.to(dtype=weight.dtype, device=weight.device)
|
||||
return self.txt_dec(slat, guide_subs=subs) * 0.5 + 0.5
|
||||
|
||||
# shouldn't be called (placeholder)
|
||||
@torch.no_grad()
|
||||
def decode(
|
||||
self,
|
||||
shape_slat: SparseTensor,
|
||||
tex_slat: SparseTensor,
|
||||
resolution: int,
|
||||
):
|
||||
meshes, subs = self.decode_shape_slat(shape_slat, resolution)
|
||||
tex_voxels = self.decode_tex_slat(tex_slat, subs)
|
||||
return tex_voxels
|
||||
|
||||
16
comfy/sd.py
16
comfy/sd.py
@ -529,18 +529,18 @@ class VAE:
|
||||
self.first_stage_model = StageC_coder()
|
||||
self.downscale_ratio = 32
|
||||
self.latent_channels = 16
|
||||
elif "shape_dec.blocks.1.16.to_subdiv.weight" in sd or "txt_dec.blocks.3.4.conv2.weight" in sd: # trellis2 or trellis2 texture only
|
||||
init_txt_model = False
|
||||
init_txt_model_only = False
|
||||
if "shape_dec.blocks.1.16.to_subdiv.weight" not in sd:
|
||||
init_txt_model_only = True
|
||||
if "txt_dec.blocks.1.16.norm1.weight" in sd:
|
||||
init_txt_model = True
|
||||
elif "shape_dec.blocks.1.16.to_subdiv.weight" in sd: # trellis2 shape vae (struct_dec + shape_dec)
|
||||
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
# TODO
|
||||
self.memory_used_decode = lambda shape, dtype: (2500 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
|
||||
self.memory_used_encode = lambda shape, dtype: (2500 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
|
||||
self.first_stage_model = comfy.ldm.trellis2.vae.Vae(init_txt_model, init_txt_model_only= init_txt_model_only)
|
||||
self.first_stage_model = comfy.ldm.trellis2.vae.ShapeVae()
|
||||
elif "txt_dec.blocks.3.4.conv2.weight" in sd: # trellis2 texture vae
|
||||
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
# TODO
|
||||
self.memory_used_decode = lambda shape, dtype: (2500 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
|
||||
self.memory_used_encode = lambda shape, dtype: (2500 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
|
||||
self.first_stage_model = comfy.ldm.trellis2.vae.TextureVae()
|
||||
elif "decoder.conv_in.weight" in sd:
|
||||
if sd['decoder.conv_in.weight'].shape[1] == 64:
|
||||
ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
|
||||
|
||||
@ -205,8 +205,8 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
|
||||
coords_list = [stage_tensor.coords for stage_tensor in stage_tensors]
|
||||
subs.append(SparseTensor.from_tensor_list(feats_list, coords_list))
|
||||
|
||||
face_list = [m.faces for m in mesh]
|
||||
vert_list = [m.vertices for m in mesh]
|
||||
vert_list = [v.float() for v, f in mesh]
|
||||
face_list = [f.int() for v, f in mesh]
|
||||
if all(v.shape == vert_list[0].shape for v in vert_list) and all(f.shape == face_list[0].shape for f in face_list):
|
||||
mesh = Types.MESH(vertices=torch.stack(vert_list), faces=torch.stack(face_list))
|
||||
else:
|
||||
@ -286,12 +286,12 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode):
|
||||
sample_tensor = samples["samples"]
|
||||
sample_tensor = sample_tensor[:, :8]
|
||||
batch_number = prepare_trellis_vae_for_decode(vae, sample_tensor.shape)
|
||||
decoder = vae.first_stage_model.struct_dec
|
||||
shape_vae = vae.first_stage_model
|
||||
load_device = comfy.model_management.get_torch_device()
|
||||
decoded_batches = []
|
||||
for start in range(0, sample_tensor.shape[0], batch_number):
|
||||
sample_chunk = sample_tensor[start:start + batch_number].to(load_device)
|
||||
decoded_batches.append(decoder(sample_chunk) > 0)
|
||||
decoded_batches.append(shape_vae.decode_structure(sample_chunk) > 0)
|
||||
decoded = torch.cat(decoded_batches, dim=0)
|
||||
current_res = decoded.shape[2]
|
||||
|
||||
@ -349,10 +349,9 @@ class Trellis2UpsampleStage(IO.ComfyNode):
|
||||
prepare_trellis_vae_for_decode(vae, shape_latent["samples"].shape)
|
||||
|
||||
coord_counts = shape_latent.get("coord_counts")
|
||||
decoder = vae.first_stage_model.shape_dec
|
||||
shape_vae = vae.first_stage_model
|
||||
lr_resolution = 512
|
||||
target_resolution = int(target_resolution)
|
||||
decoder_dtype = next(decoder.parameters()).dtype
|
||||
|
||||
# Decode each sample's HR coords, then search for the largest hr_resolution
|
||||
# that fits under max_tokens across all samples.
|
||||
@ -361,8 +360,7 @@ class Trellis2UpsampleStage(IO.ComfyNode):
|
||||
shape_latent["samples"], shape_latent["coords"], coord_counts,
|
||||
)
|
||||
slat = shape_norm(feats.to(device), coords_512.to(device))
|
||||
slat.feats = slat.feats.to(decoder_dtype)
|
||||
sample_hr_coords = [decoder.upsample(slat, upsample_times=4)]
|
||||
sample_hr_coords = [shape_vae.upsample_shape(slat, upsample_times=4)]
|
||||
else:
|
||||
items = split_batched_sparse_latent(
|
||||
shape_latent["samples"], shape_latent["coords"], coord_counts,
|
||||
@ -372,8 +370,7 @@ class Trellis2UpsampleStage(IO.ComfyNode):
|
||||
coords_i = coords_i.to(device).clone()
|
||||
coords_i[:, 0] = 0
|
||||
slat_i = shape_norm(feats_i.to(device), coords_i)
|
||||
slat_i.feats = slat_i.feats.to(decoder_dtype)
|
||||
sample_hr_coords.append(decoder.upsample(slat_i, upsample_times=4))
|
||||
sample_hr_coords.append(shape_vae.upsample_shape(slat_i, upsample_times=4))
|
||||
|
||||
# Resolution search — cache the final iteration's quantized unique tensors
|
||||
# so we don't recompute .unique() per sample after picking hr_resolution.
|
||||
|
||||
Loading…
Reference in New Issue
Block a user