model init working

This commit is contained in:
Yousef Rafat 2026-02-03 21:10:20 +02:00
parent f76e3a11b5
commit d6573fd26d
6 changed files with 639 additions and 12 deletions

View File

@ -746,6 +746,8 @@ class Hunyuan3Dv2_1(LatentFormat):
latent_channels = 64 latent_channels = 64
latent_dimensions = 1 latent_dimensions = 1
# Latent format for Trellis2 structured (sparse) latents.
class Trellis2(LatentFormat): # TODO
    latent_channels = 32
    # NOTE(review): sibling formats (Hunyuan3Dv2*) also set latent_dimensions;
    # confirm whether Trellis2's sparse latents need it before sampling is wired up.
class Hunyuan3Dv2mini(LatentFormat): class Hunyuan3Dv2mini(LatentFormat):
latent_channels = 64 latent_channels = 64
latent_dimensions = 1 latent_dimensions = 1

View File

@ -1,8 +1,192 @@
# will contain every cuda -> pytorch operation # will contain every cuda -> pytorch operation
import math
import torch import torch
from typing import Dict from typing import Dict, Callable
NO_TRITION = False
try:
import triton
import triton.language as tl
heuristics = {
'valid_kernel': lambda args: args['valid_kernel'](args['B1']),
'valid_kernel_seg': lambda args: args['valid_kernel_seg'](args['B1']),
}
# Masked implicit-GEMM forward kernel for submanifold sparse 3D convolution.
# Each program instance computes one (B1 x B2) tile of the dense (N, Co)
# output, visiting only the kernel taps that have at least one valid
# neighbor within its block of gray-code-sorted rows.
# NOTE(review): the autotune decorator is commented out and the host wrapper
# does not pass B1/B2/BK either, so these constexpr tile sizes are never
# supplied -- launching this kernel as-is should fail; confirm before use.
#@triton_autotune(
#    configs=config.autotune_config,
#    key=['LOGN', 'Ci', 'Co', 'V', 'allow_tf32'],
#)
@triton.heuristics(heuristics)
@triton.jit
def sparse_submanifold_conv_fwd_masked_implicit_gemm_kernel(
    input,        # (N, Ci) input features
    weight,       # (Co, V, Ci) flattened conv weights
    bias,         # (Co,) or None
    neighbor,     # (N, V) neighbor index map; 0xffffffff = no neighbor
    sorted_idx,   # (N,) row order sorted by decoded gray code
    output,       # (N, Co) output features
    # Tensor dimensions
    N, LOGN, Ci, Co, V: tl.constexpr,
    # Meta-parameters
    B1: tl.constexpr, # Block size for N dimension
    B2: tl.constexpr, # Block size for Co dimension
    BK: tl.constexpr, # Block size for K dimension (V * Ci)
    allow_tf32: tl.constexpr, # Allow TF32 precision for matmuls
    # Heuristic parameters (resolved by @triton.heuristics above)
    valid_kernel,
    valid_kernel_seg,
):
    block_id = tl.program_id(axis=0)
    block_dim_co = tl.cdiv(Co, B2)
    block_id_co = block_id % block_dim_co
    block_id_n = block_id // block_dim_co
    # Create pointers for submatrices of A and B.
    num_k = tl.cdiv(Ci, BK)  # Number of blocks in K dimension
    # Range of kernel taps with at least one valid neighbor in this row-block
    # (precomputed by neighbor_map_post_process_for_masked_implicit_gemm_2).
    valid_kernel_start = tl.load(valid_kernel_seg + block_id_n)
    valid_kernel_seglen = tl.load(valid_kernel_seg + block_id_n + 1) - valid_kernel_start
    offset_n = block_id_n * B1 + tl.arange(0, B1)
    n_mask = offset_n < N
    offset_sorted_n = tl.load(sorted_idx + offset_n, mask=n_mask, other=0) # (B1,)
    offset_co = (block_id_co * B2 + tl.arange(0, B2)) % Co # (B2,)
    offset_k = tl.arange(0, BK) # (BK,)
    # Create a block of the output matrix C.
    accumulator = tl.zeros((B1, B2), dtype=tl.float32)
    # Iterate along the V*Ci reduction dimension, visiting only valid taps.
    for k in range(num_k * valid_kernel_seglen):
        v = k // num_k
        bk = k % num_k
        v = tl.load(valid_kernel + valid_kernel_start + v)
        # Calculate pointers to input matrix.
        neighbor_offset_n = tl.load(neighbor + offset_sorted_n * V + v) # (B1,)
        input_ptr = input + bk * BK + (neighbor_offset_n[:, None].to(tl.int64) * Ci + offset_k[None, :]) # (B1, BK)
        # Calculate pointers to weight matrix.
        weight_ptr = weight + v * Ci + bk * BK + (offset_co[None, :] * V * Ci + offset_k[:, None]) # (BK, B2)
        # Load the next block of input and weight.
        # Rows whose neighbor is the 0xffffffff sentinel contribute zeros.
        neigh_mask = neighbor_offset_n != 0xffffffff
        k_mask = offset_k < Ci - bk * BK
        input_block = tl.load(input_ptr, mask=neigh_mask[:, None] & k_mask[None, :], other=0.0)
        weight_block = tl.load(weight_ptr, mask=k_mask[:, None], other=0.0)
        # Accumulate along the K dimension.
        accumulator = tl.dot(input_block, weight_block, accumulator,
                             input_precision='tf32' if allow_tf32 else 'ieee') # (B1, B2)
    c = accumulator.to(input.type.element_ty)
    # add bias
    if bias is not None:
        bias_block = tl.load(bias + offset_co)
        c += bias_block[None, :]
    # Write back the block of the output matrix with masks.
    out_offset_n = offset_sorted_n
    out_offset_co = block_id_co * B2 + tl.arange(0, B2)
    out_ptr = output + (out_offset_n[:, None] * Co + out_offset_co[None, :])
    out_mask = n_mask[:, None] & (out_offset_co[None, :] < Co)
    tl.store(out_ptr, c, mask=out_mask)
def sparse_submanifold_conv_fwd_masked_implicit_gemm_splitk(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    neighbor: torch.Tensor,
    sorted_idx: torch.Tensor,
    valid_kernel: Callable[[int], torch.Tensor],
    valid_kernel_seg: Callable[[int], torch.Tensor],
) -> torch.Tensor:
    """Host-side launcher for the masked implicit-GEMM triton kernel.

    input: (N, Ci) features; weight: (Co, V, Ci) flattened weights;
    neighbor: (N, V) neighbor map. Returns the (N, Co) output features.
    The valid_kernel / valid_kernel_seg callables are resolved per block
    size by the kernel's @triton.heuristics decorator.
    """
    N, Ci, Co, V = neighbor.shape[0], input.shape[1], weight.shape[0], weight.shape[1]
    LOGN = int(math.log2(N))
    output = torch.empty((N, Co), device=input.device, dtype=input.dtype)
    # One program per (row-block, Co-block) pair.
    # NOTE(review): META['B1'] / META['B2'] are only populated by autotune
    # configs, and the kernel's autotune decorator is commented out; B1/B2/BK
    # are also not passed below -- this launch will fail until the tile sizes
    # are wired up.
    grid = lambda META: (triton.cdiv(Co, META['B2']) * triton.cdiv(N, META['B1']),)
    sparse_submanifold_conv_fwd_masked_implicit_gemm_kernel[grid](
        input, weight, bias, neighbor, sorted_idx, output,
        N, LOGN, Ci, Co, V, #
        valid_kernel=valid_kernel,
        valid_kernel_seg=valid_kernel_seg,
        # NOTE(review): torch.cuda.is_tf32_supported does not appear to be a
        # public torch API -- verify; torch.backends.cuda.matmul.allow_tf32
        # may be what was intended.
        allow_tf32=torch.cuda.is_tf32_supported(),
    )
    return output
except:
NO_TRITION = True
def compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device):
    """Return a (Kw*Kh*Kd, 3) tensor of dilated kernel tap offsets.

    Taps are enumerated with vx outermost and vz innermost, matching the
    flat tap index used by the CUDA kernel (v = vx*Kh*Kd + vy*Kd + vz).
    """
    offsets = [
        (vx * Dw, vy * Dh, vz * Dd)
        for vx in range(Kw)
        for vy in range(Kh)
        for vz in range(Kd)
    ]
    return torch.tensor(offsets, device=device)
def build_submanifold_neighbor_map(
    hashmap,
    coords: torch.Tensor,
    W, H, D,
    Kw, Kh, Kd,
    Dw, Dh, Dd,
):
    """Build the (M, V) neighbor-index map for a submanifold convolution.

    For each voxel and each kernel tap v, looks up the flat spatial key of
    the neighboring voxel in `hashmap`. Only the first half of the taps
    (plus the center) is looked up; the second half is filled by symmetry.

    Args:
        hashmap: lookup object exposing `default_value` and
            `lookup_flat(flat_keys) -> indices` (misses -> default_value).
        coords: (M, 4) integer tensor of (batch, x, y, z) coordinates.
        W, H, D: dense spatial extent.
        Kw, Kh, Kd: kernel size per axis.
        Dw, Dh, Dd: dilation per axis.

    Returns:
        (M, V) long tensor; entry [i, v] is the row index of voxel i's
        neighbor at tap v, or `hashmap.default_value` when absent.
    """
    device = coords.device
    M = coords.shape[0]
    V = Kw * Kh * Kd
    half_V = V // 2 + 1  # center tap sits at index half_V - 1
    INVALID = hashmap.default_value
    neighbor = torch.full((M, V), INVALID, device=device, dtype=torch.long)
    b = coords[:, 0]
    x = coords[:, 1]
    y = coords[:, 2]
    z = coords[:, 3]
    offsets = compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device)
    # Base coordinate of kernel tap (0, 0, 0): the voxel position shifted
    # back by half the dilated kernel extent.
    # Fix: keep these 1-D of shape (M,). The previous code built (M, 1)
    # tensors via x[:, None] and then indexed column v (`ox[:, v]`), which
    # is out of bounds for every tap v >= 1.
    ox = x - (Kw // 2) * Dw
    oy = y - (Kh // 2) * Dh
    oz = z - (Kd // 2) * Dd
    for v in range(half_V):
        if v == half_V - 1:
            # Center tap: every voxel is its own neighbor.
            neighbor[:, v] = torch.arange(M, device=device)
            continue
        dx, dy, dz = offsets[v]
        kx = ox + dx
        ky = oy + dy
        kz = oz + dz
        valid = (
            (kx >= 0) & (kx < W) &
            (ky >= 0) & (ky < H) &
            (kz >= 0) & (kz < D)
        )
        flat = (
            b * (W * H * D) +
            kx * (H * D) +
            ky * D +
            kz
        )
        flat = flat[valid]
        idx = torch.nonzero(valid, as_tuple=False).squeeze(1)
        found = hashmap.lookup_flat(flat)
        neighbor[idx, v] = found
        # Symmetric write: if j is i's neighbor at tap v, then i is j's
        # neighbor at the mirrored tap V - 1 - v.
        valid_found = found != INVALID
        neighbor[found[valid_found], V - 1 - v] = idx[valid_found]
    return neighbor
class TorchHashMap: class TorchHashMap:
def __init__(self, keys: torch.Tensor, values: torch.Tensor, default_value: int): def __init__(self, keys: torch.Tensor, values: torch.Tensor, default_value: int):
@ -22,6 +206,207 @@ class TorchHashMap:
out[found] = self.sorted_vals[idx[found]] out[found] = self.sorted_vals[idx[found]]
return out return out
UINT32_SENTINEL = 0xFFFFFFFF
def neighbor_map_post_process_for_masked_implicit_gemm_1(neighbor_map):
    """First post-processing pass over the (N, V) neighbor map.

    Packs each row's tap-validity bits into one integer ("gray code"),
    sorts rows by the gray-to-binary-decoded value (so rows with similar
    masks cluster into the same GEMM block), and compacts the valid
    (row, neighbor) index pairs in tap-major order.

    Returns:
        gray_code: (N,) int32 validity bitmask per row (bit v = tap v valid).
        sorted_idx: (N,) row permutation sorted by the decoded mask.
        valid_signal_i: (K,) source row index per valid neighbor entry.
        valid_signal_o: (K,) neighbor row index per valid neighbor entry.
        seg: (V + 1,) prefix offsets of valid entries per kernel tap.
    """
    device = neighbor_map.device
    N, V = neighbor_map.shape
    neigh = neighbor_map.to(torch.long)
    # 0xFFFFFFFF marks "no neighbor" in the map.
    sentinel = torch.tensor(UINT32_SENTINEL, dtype=torch.long, device=device)
    # Tap-major flattening: entry v*N + n corresponds to (tap v, row n).
    neigh_map_T = neigh.t().reshape(-1)
    neigh_mask_T = (neigh_map_T != sentinel).to(torch.int32)
    mask = (neigh != sentinel).to(torch.long)
    # Pack the V validity bits of each row into a single integer.
    powers = (1 << torch.arange(V, dtype=torch.long, device=device))
    gray_long = (mask * powers).sum(dim=1)
    gray_code = gray_long.to(torch.int32)
    # Gray -> binary decode (cumulative XOR of right shifts); rows are then
    # sorted by the decoded value.
    binary_long = gray_long.clone()
    for v in range(1, V):
        binary_long ^= (gray_long >> v)
    binary_code = binary_long.to(torch.int32)
    sorted_idx = torch.argsort(binary_code)
    prefix_sum_neighbor_mask = torch.cumsum(neigh_mask_T.to(torch.int32), dim=0) # (V*N,)
    total_valid_signal = int(prefix_sum_neighbor_mask[-1].item()) if prefix_sum_neighbor_mask.numel() > 0 else 0
    if total_valid_signal > 0:
        # Compact the valid entries: slot k gets the k-th valid
        # (source row, neighbor row) pair in tap-major order.
        valid_signal_i = torch.empty((total_valid_signal,), dtype=torch.long, device=device)
        valid_signal_o = torch.empty((total_valid_signal,), dtype=torch.long, device=device)
        pos = torch.nonzero(neigh_mask_T, as_tuple=True)[0]
        to = (prefix_sum_neighbor_mask[pos] - 1).to(torch.long)
        valid_signal_i[to] = (pos % N).to(torch.long)
        valid_signal_o[to] = neigh_map_T[pos].to(torch.long)
    else:
        valid_signal_i = torch.empty((0,), dtype=torch.long, device=device)
        valid_signal_o = torch.empty((0,), dtype=torch.long, device=device)
    # Per-tap segment boundaries into the valid_signal arrays: seg[v] is the
    # number of valid entries in taps [0, v).
    seg = torch.empty((V + 1,), dtype=torch.long, device=device)
    seg[0] = 0
    if V > 0:
        idxs = (torch.arange(1, V + 1, device=device, dtype=torch.long) * N) - 1
        seg[1:] = prefix_sum_neighbor_mask[idxs].to(torch.long)
    else:
        pass
    return gray_code, sorted_idx, valid_signal_i, valid_signal_o, seg
def _popcount_int32_tensor(x: torch.Tensor) -> torch.Tensor:
x = x.to(torch.int64)
m1 = torch.tensor(0x5555555555555555, dtype=torch.int64, device=x.device)
m2 = torch.tensor(0x3333333333333333, dtype=torch.int64, device=x.device)
m4 = torch.tensor(0x0F0F0F0F0F0F0F0F, dtype=torch.int64, device=x.device)
h01 = torch.tensor(0x0101010101010101, dtype=torch.int64, device=x.device)
x = x - ((x >> 1) & m1)
x = (x & m2) + ((x >> 2) & m2)
x = (x + (x >> 4)) & m4
x = (x * h01) >> 56
return x.to(torch.int32)
def neighbor_map_post_process_for_masked_implicit_gemm_2(
    gray_code: torch.Tensor, # [N], int32-like (non-negative)
    sorted_idx: torch.Tensor, # [N], long (indexing into gray_code)
    block_size: int
):
    """Second post-processing pass: valid kernel taps per row-block.

    Groups the gray-code-sorted rows into blocks of `block_size`, ORs the
    validity bitmasks inside each block, and emits for every block the
    ascending list of set-bit positions (kernel taps with at least one
    valid neighbor somewhere in the block) plus prefix offsets.

    Returns:
        valid_kernel_idx: (K,) int32 concatenated tap indices per block.
        seg: (num_blocks + 1,) int32 prefix offsets into valid_kernel_idx.
    """
    device = gray_code.device
    N = gray_code.numel()
    # num of blocks (same as CUDA)
    num_blocks = (N + block_size - 1) // block_size
    # Ensure dtypes
    gray_long = gray_code.to(torch.int64) # safer to OR in 64-bit then cast
    sorted_idx = sorted_idx.to(torch.long)
    # 1) Group gray_code by blocks (in sorted row order) and OR each block.
    # pad the last block with zeros if necessary so we can reshape
    pad = num_blocks * block_size - N
    if pad > 0:
        pad_vals = torch.zeros((pad,), dtype=torch.int64, device=device)
        gray_padded = torch.cat([gray_long[sorted_idx], pad_vals], dim=0)
    else:
        gray_padded = gray_long[sorted_idx]
    # reshape to (num_blocks, block_size) and compute bitwise_or across dim=1
    gray_blocks = gray_padded.view(num_blocks, block_size) # each row = block entries
    # reduce with bitwise_or
    reduced_code = gray_blocks[:, 0].clone()
    for i in range(1, block_size):
        reduced_code |= gray_blocks[:, i]
    reduced_code = reduced_code.to(torch.int32) # match CUDA int32
    # 2) compute seglen (popcount per reduced_code) and seg (prefix sum)
    seglen_counts = _popcount_int32_tensor(reduced_code.to(torch.int64)).to(torch.int32) # [num_blocks]
    # seg: length num_blocks+1, seg[0]=0, seg[i+1]=cumsum(seglen_counts) up to i
    seg = torch.empty((num_blocks + 1,), dtype=torch.int32, device=device)
    seg[0] = 0
    if num_blocks > 0:
        seg[1:] = torch.cumsum(seglen_counts, dim=0)
    total = int(seg[-1].item())
    # 3) scatter - produce valid_kernel_idx as concatenated ascending set-bit
    #    positions for each reduced_code row
    if total == 0:
        valid_kernel_idx = torch.empty((0,), dtype=torch.int32, device=device)
        return valid_kernel_idx, seg
    # Highest set bit anywhere bounds the number of tap positions to scan.
    max_val = int(reduced_code.max().item())
    V = max_val.bit_length() if max_val > 0 else 0
    # If you know V externally, pass it instead or set here explicitly.
    if V == 0:
        # no bits set anywhere
        valid_kernel_idx = torch.empty((0,), dtype=torch.int32, device=device)
        return valid_kernel_idx, seg
    # build mask of shape (num_blocks, V): True where bit is set
    bit_pos = torch.arange(0, V, dtype=torch.int64, device=device) # [V]
    # shifted = reduced_code[:, None] >> bit_pos[None, :]
    shifted = reduced_code.to(torch.int64).unsqueeze(1) >> bit_pos.unsqueeze(0)
    bits = (shifted & 1).to(torch.bool) # (num_blocks, V)
    # Boolean indexing walks rows in order, so positions come out grouped by
    # block and ascending within each block -- matching seg's offsets.
    positions = bit_pos.unsqueeze(0).expand(num_blocks, V)
    valid_positions = positions[bits]
    valid_kernel_idx = valid_positions.to(torch.int32).contiguous()
    return valid_kernel_idx, seg
def sparse_submanifold_conv3d(feats, coords, shape, weight, bias, neighbor_cache, dilation):
    """Submanifold sparse 3D convolution via the masked implicit-GEMM kernel.

    Args:
        feats: (N, Ci) voxel features.
        coords: (N, 4) integer (batch, x, y, z) voxel coordinates.
        shape: dense shape, either (N, C, W, H, D) or (W, H, D).
        weight: (Co, K1, K2, K3, Ci) convolution weights.
        bias: (Co,) tensor or None.
        neighbor_cache: previously computed (N, V) neighbor map, or None.
        dilation: 3-sequence of per-axis dilations.

    Returns:
        (output, neighbor) -- the (N, Co) output features and the neighbor
        map, which the caller may cache for subsequent calls.
    """
    if len(shape) == 5:
        N, C, W, H, D = shape
    else:
        W, H, D = shape
    # NOTE(review): unpack order here is (Co, Kw, Kh, Kd, Ci) while
    # sparse_conv3d_forward unpacks (Co, Kd, Kh, Kw, Ci); harmless for cubic
    # kernels but the two should be unified.
    Co, Kw, Kh, Kd, Ci = weight.shape
    if neighbor_cache is None:
        # Build the voxel hashmap only when the neighbor map has to be
        # (re)computed; the cached path does not need it.
        b_stride = W * H * D
        x_stride = H * D
        y_stride = D
        z_stride = 1
        flat_keys = (coords[:, 0].long() * b_stride +
                     coords[:, 1].long() * x_stride +
                     coords[:, 2].long() * y_stride +
                     coords[:, 3].long() * z_stride)
        vals = torch.arange(coords.shape[0], dtype=torch.int32, device=coords.device)
        hashmap = TorchHashMap(flat_keys, vals, 0xFFFFFFFF)
        neighbor = build_submanifold_neighbor_map(
            hashmap, coords, W, H, D, Kw, Kh, Kd,
            dilation[0], dilation[1], dilation[2]
        )
    else:
        neighbor = neighbor_cache
    block_size = 128
    gray_code, sorted_idx, valid_signal_i, valid_signal_o, valid_signal_seg = \
        neighbor_map_post_process_for_masked_implicit_gemm_1(neighbor)
    valid_kernel, valid_kernel_seg = \
        neighbor_map_post_process_for_masked_implicit_gemm_2(gray_code, sorted_idx, block_size)
    # The kernel's @triton.heuristics resolves these as callables of the
    # block size, so wrap the precomputed tensors.
    valid_kernel_fn = lambda b_size: valid_kernel
    valid_kernel_seg_fn = lambda b_size: valid_kernel_seg
    # Flatten the three kernel axes: (Co, K1, K2, K3, Ci) -> (Co, V, Ci).
    weight_flat = weight.contiguous().view(Co, -1, Ci)
    out = sparse_submanifold_conv_fwd_masked_implicit_gemm_splitk(
        feats,
        weight_flat,
        bias,
        neighbor,
        sorted_idx,
        valid_kernel_fn,
        valid_kernel_seg_fn
    )
    return out, neighbor
class Voxel: class Voxel:
def __init__( def __init__(
self, self,

View File

@ -408,7 +408,7 @@ class SLatFlowModel(nn.Module):
self.qk_rms_norm_cross = qk_rms_norm_cross self.qk_rms_norm_cross = qk_rms_norm_cross
self.dtype = dtype self.dtype = dtype
self.t_embedder = TimestepEmbedder(model_channels) self.t_embedder = TimestepEmbedder(model_channels, device=device, dtype=dtype, operations=operations)
if share_mod: if share_mod:
self.adaLN_modulation = nn.Sequential( self.adaLN_modulation = nn.Sequential(
nn.SiLU(), nn.SiLU(),
@ -485,15 +485,25 @@ class Trellis2(nn.Module):
qk_rms_norm = True, qk_rms_norm = True,
qk_rms_norm_cross = True, qk_rms_norm_cross = True,
dtype=None, device=None, operations=None): dtype=None, device=None, operations=None):
super().__init__()
args = { args = {
"out_channels":out_channels, "num_blocks":num_blocks, "cond_channels" :cond_channels, "out_channels":out_channels, "num_blocks":num_blocks, "cond_channels" :cond_channels,
"model_channels":model_channels, "num_heads":num_heads, "mlp_ratio": mlp_ratio, "share_mod": share_mod, "model_channels":model_channels, "num_heads":num_heads, "mlp_ratio": mlp_ratio, "share_mod": share_mod,
"qk_rms_norm": qk_rms_norm, "qk_rms_norm_cross": qk_rms_norm_cross, "device": device, "dtype": dtype, "operations": operations "qk_rms_norm": qk_rms_norm, "qk_rms_norm_cross": qk_rms_norm_cross, "device": device, "dtype": dtype, "operations": operations
} }
# TODO: update the names/checkpoints # TODO: update the names/checkpoints
self.img2shape = SLatFlowModel(resolution, in_channels=in_channels, *args) self.img2shape = SLatFlowModel(resolution=resolution, in_channels=in_channels, **args)
self.shape2txt = SLatFlowModel(resolution, in_channels=in_channels*2, *args) self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args)
self.shape_generation = True self.shape_generation = True
def forward(self, x, timestep, context): def forward(self, x, timestep, context, **kwargs):
pass # TODO add mode
mode = kwargs.get("mode", "shape_generation")
mode = "texture_generation" if mode == 1 else "shape_generation"
if mode == "shape_generation":
out = self.img2shape(x, timestep, context)
if mode == "texture_generation":
out = self.shape2txt(x, timestep, context)
return out

View File

@ -1,3 +1,4 @@
import math
import torch import torch
import torch.nn as nn import torch.nn as nn
from typing import List, Any, Dict, Optional, overload, Union, Tuple from typing import List, Any, Dict, Optional, overload, Union, Tuple
@ -5,12 +6,219 @@ from fractions import Fraction
import torch.nn.functional as F import torch.nn.functional as F
from dataclasses import dataclass from dataclasses import dataclass
import numpy as np import numpy as np
from cumesh import TorchHashMap, Mesh, MeshWithVoxel from cumesh import TorchHashMap, Mesh, MeshWithVoxel, sparse_submanifold_conv3d
class SparseConv3d(nn.Module):
    """Submanifold sparse 3D convolution module.

    A thin nn.Module shell: parameter setup and the forward computation are
    delegated to the module-level helpers ``sparse_conv3d_init`` and
    ``sparse_conv3d_forward``.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
        super().__init__()
        sparse_conv3d_init(
            self, in_channels, out_channels, kernel_size,
            stride, dilation, padding, bias, indice_key,
        )

    def forward(self, x):
        return sparse_conv3d_forward(self, x)
def sparse_conv3d_init(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
    """Initialize parameters for a submanifold sparse 3D convolution.

    Populates ``self`` (an nn.Module) with a ``weight`` parameter stored in
    (Co, K1, K2, K3, Ci) layout and an optional ``bias``.

    Note: ``padding`` and ``indice_key`` are accepted for interface
    compatibility but are currently unused.
    """
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.kernel_size = tuple(kernel_size) if isinstance(kernel_size, (list, tuple)) else (kernel_size, ) * 3
    self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, ) * 3
    self.dilation = tuple(dilation) if isinstance(dilation, (list, tuple)) else (dilation, ) * 3
    self.weight = nn.Parameter(torch.empty((out_channels, in_channels, *self.kernel_size)))
    if bias:
        self.bias = nn.Parameter(torch.empty(out_channels))
    else:
        self.register_parameter("bias", None)
    # Fix: the weight was previously left as uninitialized memory (only the
    # bias got an init). Use the torch.nn.Conv3d default initialization.
    torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
    if self.bias is not None:
        fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
        if fan_in != 0:
            bound = 1 / math.sqrt(fan_in)
            torch.nn.init.uniform_(self.bias, -bound, bound)
    # Permute weight (Co, Ci, Kd, Kh, Kw) -> (Co, Kd, Kh, Kw, Ci)
    self.weight = nn.Parameter(self.weight.permute(0, 2, 3, 4, 1).contiguous())
def sparse_conv3d_forward(self, x):
    """Apply the sparse submanifold convolution to SparseTensor ``x``.

    Reuses a neighbor map cached on ``x`` for this kernel size / dilation
    when available; otherwise computes one and registers it in the cache.
    """
    # check if neighbor map is already computed
    # NOTE(review): the weight is unpacked as (Co, Kd, Kh, Kw, Ci) here but
    # as (Co, Kw, Kh, Kd, Ci) inside sparse_submanifold_conv3d; harmless for
    # cubic kernels, but the order should be unified.
    Co, Kd, Kh, Kw, Ci = self.weight.shape
    neighbor_cache_key = f'SubMConv3d_neighbor_cache_{Kw}x{Kh}x{Kd}_dilation{self.dilation}'
    neighbor_cache = x.get_spatial_cache(neighbor_cache_key)
    out, neighbor_cache_ = sparse_submanifold_conv3d(
        x.feats,
        x.coords,
        torch.Size([*x.shape, *x.spatial_shape]),
        self.weight,
        self.bias,
        neighbor_cache,
        self.dilation
    )
    if neighbor_cache is None:
        # First call for this configuration: cache the neighbor map on x.
        x.register_spatial_cache(neighbor_cache_key, neighbor_cache_)
    out = x.replace(out)
    return out
class LayerNorm32(nn.LayerNorm):
    """LayerNorm that always normalizes in float32.

    The input is upcast to float32, normalized by the parent class, and the
    result is cast back to the caller's original dtype, keeping fp16/bf16
    activations numerically stable through the normalization.
    """
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_dtype = x.dtype
        normed = super().forward(x.float())
        return normed.to(orig_dtype)
class SparseConvNeXtBlock3d(nn.Module):
    """ConvNeXt-style residual block for sparse voxel features.

    3x3x3 sparse conv -> fp32 LayerNorm -> pointwise MLP, with a residual
    connection around the whole block.
    """
    def __init__(
        self,
        channels: int,
        mlp_ratio: float = 4.0,
        use_checkpoint: bool = False,
    ):
        super().__init__()
        self.channels = channels
        # Stored for interface compatibility; checkpointing is not wired up.
        self.use_checkpoint = use_checkpoint
        self.norm = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
        self.conv = SparseConv3d(channels, channels, 3)
        hidden = int(channels * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(channels, hidden),
            nn.SiLU(),
            nn.Linear(hidden, channels),
        )

    def _forward(self, x):
        residual = x
        out = self.conv(x)
        out = out.replace(self.norm(out.feats))
        out = out.replace(self.mlp(out.feats))
        return out + residual

    def forward(self, x):
        return self._forward(x)
class SparseSpatial2Channel(nn.Module):
    """Fold a ``factor``^DIM spatial neighborhood of sparse voxels into channels.

    Downscales coordinates by ``factor`` and stacks the features of the
    up-to-``factor``^DIM child voxels of each coarse cell along the channel
    dimension (missing children contribute zeros). The index mapping is
    cached on the tensor so SparseChannel2Spatial can invert it exactly.
    """
    def __init__(self, factor: int = 2):
        super(SparseSpatial2Channel, self).__init__()
        self.factor = factor

    def forward(self, x):
        DIM = x.coords.shape[-1] - 1  # spatial dims; coords carry batch first
        cache = x.get_spatial_cache(f'spatial2channel_{self.factor}')
        if cache is None:
            # Coarse coordinates: integer-divide the spatial axes only.
            coord = list(x.coords.unbind(dim=-1))
            for i in range(DIM):
                coord[i+1] = coord[i+1] // self.factor
            # Child slot of each voxel inside its coarse cell, flattened
            # to a single integer in [0, factor**DIM).
            subidx = x.coords[:, 1:] % self.factor
            subidx = sum([subidx[..., i] * self.factor ** i for i in range(DIM)])
            # Pack (batch, coarse coords) into one code to deduplicate cells.
            MAX = [(s + self.factor - 1) // self.factor for s in x.spatial_shape]
            OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1]
            code = sum([c * o for c, o in zip(coord, OFFSET)])
            code, idx = code.unique(return_inverse=True)
            # Unpack the unique codes back into (batch, x, y, z) coordinates.
            new_coords = torch.stack(
                [code // OFFSET[0]] +
                [(code // OFFSET[i+1]) % MAX[i] for i in range(DIM)],
                dim=-1
            )
        else:
            new_coords, idx, subidx = cache
        # Scatter each voxel's features into its child slot; untouched slots
        # stay zero.
        new_feats = torch.zeros(new_coords.shape[0] * self.factor ** DIM, x.feats.shape[1], device=x.feats.device, dtype=x.feats.dtype)
        new_feats[idx * self.factor ** DIM + subidx] = x.feats
        out = SparseTensor(new_feats.reshape(new_coords.shape[0], -1), new_coords, None if x._shape is None else torch.Size([x._shape[0], x._shape[1] * self.factor ** DIM]))
        out._scale = tuple([s * self.factor for s in x._scale])
        out._spatial_cache = x._spatial_cache
        if cache is None:
            # Cache both directions so a later SparseChannel2Spatial can
            # invert this transform without a subdivision tensor.
            x.register_spatial_cache(f'spatial2channel_{self.factor}', (new_coords, idx, subidx))
            out.register_spatial_cache(f'channel2spatial_{self.factor}', (x.coords, idx, subidx))
            out.register_spatial_cache('shape', torch.Size(MAX))
        return out
class SparseChannel2Spatial(nn.Module):
    """Inverse of SparseSpatial2Channel: unfold channels into ``factor``^DIM voxels.

    Uses the index mapping cached by a preceding SparseSpatial2Channel when
    available; otherwise a ``subdivision`` tensor must provide the occupancy
    mask saying which child slots of each coarse voxel exist.
    """
    def __init__(self, factor: int = 2):
        super(SparseChannel2Spatial, self).__init__()
        self.factor = factor

    def forward(self, x, subdivision = None):
        DIM = x.coords.shape[-1] - 1
        cache = x.get_spatial_cache(f'channel2spatial_{self.factor}')
        if cache is None:
            if subdivision is None:
                raise ValueError('Cache not found. Provide subdivision tensor or pair SparseChannel2Spatial with SparseSpatial2Channel.')
            else:
                sub = subdivision.feats # [N, self.factor ** DIM] occupancy mask
                N_leaf = sub.sum(dim=-1) # [N] occupied children per voxel
                subidx = sub.nonzero()[:, -1]  # flat child slot per new voxel
                new_coords = x.coords.clone().detach()
                new_coords[:, 1:] *= self.factor
                new_coords = torch.repeat_interleave(new_coords, N_leaf, dim=0, output_size=subidx.shape[0])
                # Decode the flat child slot back into per-axis offsets.
                for i in range(DIM):
                    new_coords[:, i+1] += subidx // self.factor ** i % self.factor
                idx = torch.repeat_interleave(torch.arange(x.coords.shape[0], device=x.device), N_leaf, dim=0, output_size=subidx.shape[0])
        else:
            new_coords, idx, subidx = cache
        # Gather each child voxel's feature slice out of the channel dim.
        x_feats = x.feats.reshape(x.feats.shape[0] * self.factor ** DIM, -1)
        new_feats = x_feats[idx * self.factor ** DIM + subidx]
        out = SparseTensor(new_feats, new_coords, None if x._shape is None else torch.Size([x._shape[0], x._shape[1] // self.factor ** DIM]))
        out._scale = tuple([s / self.factor for s in x._scale])
        if cache is not None: # only keep cache when subdiv following it
            out._spatial_cache = x._spatial_cache
        return out
class SparseResBlockC2S3d(nn.Module):
    """Residual channel-to-spatial (x2 upsampling) block for sparse voxels.

    conv1 expands to ``out_channels * 8`` channels, SparseChannel2Spatial
    then unfolds them into the 2x2x2 children of each voxel, conv2 refines
    the result, and a repeat-interleave skip connection is added.

    Args:
        channels: input channel count.
        out_channels: output channel count (defaults to ``channels``).
        use_checkpoint: stored but not acted on here.
        pred_subdiv: when True, an internal linear head predicts the 8-way
            subdivision logits and ``_forward`` returns ``(h, subdiv)``.
    """
    def __init__(
        self,
        channels: int,
        out_channels: Optional[int] = None,
        use_checkpoint: bool = False,
        pred_subdiv: bool = True,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_checkpoint = use_checkpoint
        self.pred_subdiv = pred_subdiv
        self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
        self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6)
        self.conv1 = SparseConv3d(channels, self.out_channels * 8, 3)
        self.conv2 = SparseConv3d(self.out_channels, self.out_channels, 3)
        # After the x2 channel-to-spatial step the residual path carries
        # channels // 8 features; repeat them up to out_channels.
        # Fix: use self.out_channels -- the raw `out_channels` argument may be
        # None (its default), which made this division raise at call time.
        self.skip_connection = lambda x: x.replace(x.feats.repeat_interleave(self.out_channels // (channels // 8), dim=1))
        if pred_subdiv:
            self.to_subdiv = SparseLinear(channels, 8)
        self.updown = SparseChannel2Spatial(2)

    def _forward(self, x, subdiv = None):
        if self.pred_subdiv:
            subdiv = self.to_subdiv(x)
        h = x.replace(self.norm1(x.feats))
        h = h.replace(F.silu(h.feats))
        h = self.conv1(h)
        # Binarize logits into the occupancy mask used for unfolding.
        subdiv_binarized = subdiv.replace(subdiv.feats > 0) if subdiv is not None else None
        h = self.updown(h, subdiv_binarized)
        x = self.updown(x, subdiv_binarized)
        h = h.replace(self.norm2(h.feats))
        h = h.replace(F.silu(h.feats))
        h = self.conv2(h)
        h = h + self.skip_connection(x)
        if self.pred_subdiv:
            return h, subdiv
        else:
            return h
# TODO: determine which conv they actually use
@dataclass @dataclass
class config: class config:
CONV = "none" CONV = "flexgemm"
FLEX_GEMM_HASHMAP_RATIO = 2.0
# TODO post processing # TODO post processing
def simplify(self, target_num_faces: int, verbose: bool=False, options: dict={}): def simplify(self, target_num_faces: int, verbose: bool=False, options: dict={}):
@ -1131,6 +1339,7 @@ def flexible_dual_grid_to_mesh(
class Vae(nn.Module): class Vae(nn.Module):
def __init__(self, config, operations=None): def __init__(self, config, operations=None):
super().__init__()
operations = operations or torch.nn operations = operations or torch.nn
self.txt_dec = SparseUnetVaeDecoder( self.txt_dec = SparseUnetVaeDecoder(
@ -1139,7 +1348,8 @@ class Vae(nn.Module):
latent_channels=32, latent_channels=32,
num_blocks=[4, 16, 8, 4, 0], num_blocks=[4, 16, 8, 4, 0],
block_type=["SparseConvNeXtBlock3d"] * 5, block_type=["SparseConvNeXtBlock3d"] * 5,
up_block_type=["SparseResBlockS2C3d"] * 4, up_block_type=["SparseResBlockC2S3d"] * 4,
block_args=[{}, {}, {}, {}, {}],
pred_subdiv=False pred_subdiv=False
) )
@ -1149,7 +1359,8 @@ class Vae(nn.Module):
latent_channels=32, latent_channels=32,
num_blocks=[4, 16, 8, 4, 0], num_blocks=[4, 16, 8, 4, 0],
block_type=["SparseConvNeXtBlock3d"] * 5, block_type=["SparseConvNeXtBlock3d"] * 5,
up_block_type=["SparseResBlockS2C3d"] * 4, up_block_type=["SparseResBlockC2S3d"] * 4,
block_args=[{}, {}, {}, {}, {}],
) )
def decode_shape_slat(self, slat, resolution: int): def decode_shape_slat(self, slat, resolution: int):

View File

@ -50,6 +50,7 @@ import comfy.ldm.omnigen.omnigen2
import comfy.ldm.qwen_image.model import comfy.ldm.qwen_image.model
import comfy.ldm.kandinsky5.model import comfy.ldm.kandinsky5.model
import comfy.ldm.anima.model import comfy.ldm.anima.model
import comfy.ldm.trellis2.model
import comfy.model_management import comfy.model_management
import comfy.patcher_extension import comfy.patcher_extension
@ -1455,6 +1456,13 @@ class WAN22(WAN21):
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
return latent_image return latent_image
class Trellis2(BaseModel):
    """BaseModel wrapper for the Trellis2 structured-latent flow model."""
    def __init__(self, model_config, model_type=ModelType.FLOW, device=None, unet_model=comfy.ldm.trellis2.model.Trellis2):
        # Pass device/unet_model as keywords, matching the sibling 3D model
        # wrappers (e.g. Hunyuan3Dv2) and guarding against positional drift
        # in BaseModel's signature.
        super().__init__(model_config, model_type, device=device, unet_model=unet_model)

    def extra_conds(self, **kwargs):
        # Placeholder: no Trellis2-specific conditioning wired up yet.
        return super().extra_conds(**kwargs)
class Hunyuan3Dv2(BaseModel): class Hunyuan3Dv2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None): def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2) super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)

View File

@ -1242,6 +1242,17 @@ class WAN22_T2V(WAN21_T2V):
out = model_base.WAN22(self, image_to_video=True, device=device) out = model_base.WAN22(self, image_to_video=True, device=device)
return out return out
class Trellis2(supported_models_base.BASE):
    """Supported-model entry for Trellis2 checkpoints."""
    # Matched against the detected model config to identify Trellis2 weights.
    unet_config = {
        "image_model": "trellis2"
    }
    # Latent format declared in latent_formats (32 latent channels).
    latent_format = latent_formats.Trellis2
    # Checkpoint key prefix under which the VAE weights are stored.
    vae_key_prefix = ["vae."]
    def get_model(self, state_dict, prefix="", device=None):
        # state_dict/prefix are unused here; weights are loaded by the caller
        # after construction.
        return model_base.Trellis2(self, device=device)
class Hunyuan3Dv2(supported_models_base.BASE): class Hunyuan3Dv2(supported_models_base.BASE):
unet_config = { unet_config = {
"image_model": "hunyuan3d2", "image_model": "hunyuan3d2",
@ -1596,6 +1607,6 @@ class Kandinsky5Image(Kandinsky5):
return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect)) return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, Trellis2]
models += [SVD_img2vid] models += [SVD_img2vid]