model init working

This commit is contained in:
Yousef Rafat 2026-02-03 21:10:20 +02:00
parent f76e3a11b5
commit d6573fd26d
6 changed files with 639 additions and 12 deletions

View File

@ -746,6 +746,8 @@ class Hunyuan3Dv2_1(LatentFormat):
latent_channels = 64
latent_dimensions = 1
class Trellis2(LatentFormat): # TODO: fill in scale/statistics once known
    # 32-channel structured latent (Trellis2 sparse-latent VAE);
    # latent_dimensions is inherited from LatentFormat — TODO confirm intended
    latent_channels = 32
class Hunyuan3Dv2mini(LatentFormat):
    # 1-D latent with 64 channels (same layout as Hunyuan3Dv2_1 above)
    latent_channels = 64
    latent_dimensions = 1

View File

@ -1,8 +1,192 @@
# will contain every cuda -> pytorch operation
import math
import torch
from typing import Dict
from typing import Dict, Callable
NO_TRITION = False
try:
import triton
import triton.language as tl
# Heuristics: materialize the per-launch valid-kernel tensors by calling the
# user-supplied callables with the chosen N-block size B1.
heuristics = {
    'valid_kernel': lambda args: args['valid_kernel'](args['B1']),
    'valid_kernel_seg': lambda args: args['valid_kernel_seg'](args['B1']),
}
#@triton_autotune(
#    configs=config.autotune_config,
#    key=['LOGN', 'Ci', 'Co', 'V', 'allow_tf32'],
#)
@triton.heuristics(heuristics)
@triton.jit
def sparse_submanifold_conv_fwd_masked_implicit_gemm_kernel(
    # Tensor pointers
    input,
    weight,
    bias,
    neighbor,
    sorted_idx,
    output,
    # Tensor dimensions
    N, LOGN, Ci, Co, V: tl.constexpr,
    # Meta-parameters
    B1: tl.constexpr,  # Block size for N dimension
    B2: tl.constexpr,  # Block size for Co dimension
    BK: tl.constexpr,  # Block size for K dimension (V * Ci)
    allow_tf32: tl.constexpr,  # Allow TF32 precision for matmuls
    # Heuristic parameters (filled in by the `heuristics` dict above)
    valid_kernel,
    valid_kernel_seg,
):
    # Forward submanifold sparse conv as a masked implicit GEMM: each program
    # computes one (B1, B2) output tile — B1 voxels (in sorted order) times
    # B2 output channels — accumulating over active kernel taps and Ci chunks.
    block_id = tl.program_id(axis=0)
    block_dim_co = tl.cdiv(Co, B2)
    block_id_co = block_id % block_dim_co
    block_id_n = block_id // block_dim_co
    # Create pointers for submatrices of A and B.
    num_k = tl.cdiv(Ci, BK) # Number of blocks in K dimension
    # Only taps active somewhere in this row-block are visited;
    # valid_kernel_seg holds per-block offsets into valid_kernel.
    valid_kernel_start = tl.load(valid_kernel_seg + block_id_n)
    valid_kernel_seglen = tl.load(valid_kernel_seg + block_id_n + 1) - valid_kernel_start
    offset_n = block_id_n * B1 + tl.arange(0, B1)
    n_mask = offset_n < N
    # sorted_idx groups voxels with similar tap-occupancy into the same block
    offset_sorted_n = tl.load(sorted_idx + offset_n, mask=n_mask, other=0) # (B1,)
    offset_co = (block_id_co * B2 + tl.arange(0, B2)) % Co # (B2,)
    offset_k = tl.arange(0, BK) # (BK,)
    # Create a block of the output matrix C.
    accumulator = tl.zeros((B1, B2), dtype=tl.float32)
    # Iterate along V*Ci dimension (active taps x Ci chunks).
    for k in range(num_k * valid_kernel_seglen):
        v = k // num_k
        bk = k % num_k
        # translate the dense tap counter into the actual tap index
        v = tl.load(valid_kernel + valid_kernel_start + v)
        # Calculate pointers to input matrix (gather rows via neighbor map).
        neighbor_offset_n = tl.load(neighbor + offset_sorted_n * V + v) # (B1,)
        input_ptr = input + bk * BK + (neighbor_offset_n[:, None].to(tl.int64) * Ci + offset_k[None, :]) # (B1, BK)
        # Calculate pointers to weight matrix (layout (Co, V, Ci)).
        weight_ptr = weight + v * Ci + bk * BK + (offset_co[None, :] * V * Ci + offset_k[:, None]) # (BK, B2)
        # Load the next block of input and weight.
        # 0xffffffff marks "no neighbor"; those lanes load zeros.
        neigh_mask = neighbor_offset_n != 0xffffffff
        k_mask = offset_k < Ci - bk * BK
        input_block = tl.load(input_ptr, mask=neigh_mask[:, None] & k_mask[None, :], other=0.0)
        weight_block = tl.load(weight_ptr, mask=k_mask[:, None], other=0.0)
        # Accumulate along the K dimension.
        accumulator = tl.dot(input_block, weight_block, accumulator,
                             input_precision='tf32' if allow_tf32 else 'ieee') # (B1, B2)
    c = accumulator.to(input.type.element_ty)
    # add bias (bias may be launched as None, which specializes the kernel)
    if bias is not None:
        bias_block = tl.load(bias + offset_co)
        c += bias_block[None, :]
    # Write back the block of the output matrix with masks (scatter back to
    # the unsorted voxel order via offset_sorted_n).
    out_offset_n = offset_sorted_n
    out_offset_co = block_id_co * B2 + tl.arange(0, B2)
    out_ptr = output + (out_offset_n[:, None] * Co + out_offset_co[None, :])
    out_mask = n_mask[:, None] & (out_offset_co[None, :] < Co)
    tl.store(out_ptr, c, mask=out_mask)
def sparse_submanifold_conv_fwd_masked_implicit_gemm_splitk(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    neighbor: torch.Tensor,
    sorted_idx: torch.Tensor,
    valid_kernel: Callable[[int], torch.Tensor],
    valid_kernel_seg: Callable[[int], torch.Tensor],
) -> torch.Tensor:
    """Launch the masked implicit-GEMM kernel for submanifold sparse conv.

    Args:
        input: (N, Ci) voxel features.
        weight: (Co, V, Ci) flattened conv weights.
        bias: (Co,) tensor or None.
        neighbor: (N, V) neighbor index map (0xFFFFFFFF = missing).
        sorted_idx: (N,) voxel permutation grouping similar tap patterns.
        valid_kernel / valid_kernel_seg: callables mapping the launch block
            size B1 to the active-tap tensors (invoked by triton.heuristics).

    Returns:
        (N, Co) output features in the original (unsorted) voxel order.
    """
    N, Ci, Co, V = neighbor.shape[0], input.shape[1], weight.shape[0], weight.shape[1]
    # LOGN is only an autotune key; clamp so an empty input doesn't raise
    LOGN = int(math.log2(N)) if N > 0 else 0
    output = torch.empty((N, Co), device=input.device, dtype=input.dtype)
    grid = lambda META: (triton.cdiv(Co, META['B2']) * triton.cdiv(N, META['B1']),)
    sparse_submanifold_conv_fwd_masked_implicit_gemm_kernel[grid](
        input, weight, bias, neighbor, sorted_idx, output,
        N, LOGN, Ci, Co, V,
        # Autotuning is commented out, so the constexpr block sizes must be
        # supplied explicitly (the original launch omitted them and failed).
        # B1 must match the block_size used when building valid_kernel /
        # valid_kernel_seg (128 in sparse_submanifold_conv3d).
        B1=128, B2=64, BK=32,
        valid_kernel=valid_kernel,
        valid_kernel_seg=valid_kernel_seg,
        # The original called torch.cuda.is_tf32_supported(), which is not a
        # public torch API; honor the global matmul TF32 setting instead.
        allow_tf32=torch.backends.cuda.matmul.allow_tf32,
    )
    return output
except:
NO_TRITION = True
def compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device):
    """Return a (Kw*Kh*Kd, 3) tensor of dilated tap offsets.

    Taps are enumerated x-major, then y, then z — the same order the CUDA
    kernel uses — with each axis scaled by its dilation.
    """
    taps = [
        (ix * Dw, iy * Dh, iz * Dd)
        for ix in range(Kw)
        for iy in range(Kh)
        for iz in range(Kd)
    ]
    return torch.tensor(taps, device=device)
def build_submanifold_neighbor_map(
    hashmap,
    coords: torch.Tensor,
    W, H, D,
    Kw, Kh, Kd,
    Dw, Dh, Dd,
):
    """Build the (M, V) neighbor-index map for a submanifold convolution.

    For each voxel and each kernel tap, stores the index of the voxel at that
    (dilated) offset, or `hashmap.default_value` when it does not exist. Only
    the first half of the taps (plus center) is looked up; the mirrored taps
    are filled in by symmetry.

    Args:
        hashmap: lookup object with `.default_value` and `.lookup_flat(keys)`.
        coords: (M, 4) integer (batch, x, y, z) coordinates.
        W, H, D: dense grid extents for bounds checking / key flattening.
        Kw, Kh, Kd: kernel size per axis; Dw, Dh, Dd: dilation per axis.

    Returns:
        (M, V) long tensor, V = Kw*Kh*Kd.
    """
    device = coords.device
    M = coords.shape[0]
    V = Kw * Kh * Kd
    half_V = V // 2 + 1
    INVALID = hashmap.default_value
    neighbor = torch.full((M, V), INVALID, device=device, dtype=torch.long)
    b = coords[:, 0]
    x = coords[:, 1]
    y = coords[:, 2]
    z = coords[:, 3]
    # tap offsets in the same x-major order as the CUDA kernel
    offsets = [(vx * Dw, vy * Dh, vz * Dd)
               for vx in range(Kw) for vy in range(Kh) for vz in range(Kd)]
    # 1-D base coordinate of tap 0 (kernel centered on the voxel).
    # The original built (M, 1) tensors and indexed column v, which raised
    # IndexError for every tap v >= 1 (i.e. any kernel with V > 3).
    ox = x - (Kw // 2) * Dw
    oy = y - (Kh // 2) * Dh
    oz = z - (Kd // 2) * Dd
    for v in range(half_V):
        if v == half_V - 1:
            # center tap: each voxel is its own neighbor
            neighbor[:, v] = torch.arange(M, device=device)
            continue
        dx, dy, dz = offsets[v]
        kx = ox + dx
        ky = oy + dy
        kz = oz + dz
        valid = (
            (kx >= 0) & (kx < W) &
            (ky >= 0) & (ky < H) &
            (kz >= 0) & (kz < D)
        )
        flat = (
            b * (W * H * D) +
            kx * (H * D) +
            ky * D +
            kz
        )
        flat = flat[valid]
        idx = torch.nonzero(valid, as_tuple=False).squeeze(1)
        found = hashmap.lookup_flat(flat)
        neighbor[idx, v] = found
        # symmetric write: if voxel j is at tap v of voxel i, then voxel i is
        # at the mirrored tap V-1-v of voxel j
        valid_found = found != INVALID
        neighbor[found[valid_found], V - 1 - v] = idx[valid_found]
    return neighbor
class TorchHashMap:
def __init__(self, keys: torch.Tensor, values: torch.Tensor, default_value: int):
@ -22,6 +206,207 @@ class TorchHashMap:
out[found] = self.sorted_vals[idx[found]]
return out
UINT32_SENTINEL = 0xFFFFFFFF  # "no neighbor" marker in the neighbor map

def neighbor_map_post_process_for_masked_implicit_gemm_1(neighbor_map):
    """Derive occupancy codes and gather lists from a (N, V) neighbor map.

    Args:
        neighbor_map: (N, V) integer tensor; entry [n, v] is the neighbor of
            voxel n under kernel tap v, or UINT32_SENTINEL when absent.

    Returns:
        gray_code:      (N,) int32 bitmask of occupied taps per voxel.
        sorted_idx:     (N,) long voxel order sorted by the gray->binary code,
                        clustering voxels with similar tap patterns.
        valid_signal_i: (K,) long source voxel index of each valid (tap, voxel).
        valid_signal_o: (K,) long neighbor voxel index of each valid pair.
        seg:            (V+1,) long prefix offsets of valid pairs per tap.
    """
    device = neighbor_map.device
    N, V = neighbor_map.shape
    neigh = neighbor_map.to(torch.long)
    sentinel = torch.tensor(UINT32_SENTINEL, dtype=torch.long, device=device)
    # tap-major flattening: all entries of one tap are contiguous
    neigh_map_T = neigh.t().reshape(-1)
    neigh_mask_T = (neigh_map_T != sentinel).to(torch.int32)
    mask = (neigh != sentinel).to(torch.long)
    # pack the per-voxel validity mask into a single integer per voxel
    powers = (1 << torch.arange(V, dtype=torch.long, device=device))
    gray_long = (mask * powers).sum(dim=1)
    gray_code = gray_long.to(torch.int32)
    # gray -> binary decode (b = g ^ (g>>1) ^ (g>>2) ^ ...) gives a sort key
    # under which similar occupancy patterns land next to each other
    binary_long = gray_long.clone()
    for v in range(1, V):
        binary_long ^= (gray_long >> v)
    binary_code = binary_long.to(torch.int32)
    sorted_idx = torch.argsort(binary_code)
    prefix_sum_neighbor_mask = torch.cumsum(neigh_mask_T.to(torch.int32), dim=0)  # (V*N,)
    total_valid_signal = int(prefix_sum_neighbor_mask[-1].item()) if prefix_sum_neighbor_mask.numel() > 0 else 0
    if total_valid_signal > 0:
        valid_signal_i = torch.empty((total_valid_signal,), dtype=torch.long, device=device)
        valid_signal_o = torch.empty((total_valid_signal,), dtype=torch.long, device=device)
        # scatter each valid (tap, voxel) pair to its rank among valid entries
        pos = torch.nonzero(neigh_mask_T, as_tuple=True)[0]
        to = (prefix_sum_neighbor_mask[pos] - 1).to(torch.long)
        valid_signal_i[to] = (pos % N).to(torch.long)
        valid_signal_o[to] = neigh_map_T[pos].to(torch.long)
    else:
        valid_signal_i = torch.empty((0,), dtype=torch.long, device=device)
        valid_signal_o = torch.empty((0,), dtype=torch.long, device=device)
    # per-tap segment offsets: seg[v+1] = number of valid entries in taps <= v
    # (the original had a dead `else: pass` and indexed with -1 when N == 0)
    seg = torch.empty((V + 1,), dtype=torch.long, device=device)
    seg[0] = 0
    if V > 0 and N > 0:
        idxs = (torch.arange(1, V + 1, device=device, dtype=torch.long) * N) - 1
        seg[1:] = prefix_sum_neighbor_mask[idxs].to(torch.long)
    else:
        seg[1:] = 0  # no voxels: every tap segment is empty
    return gray_code, sorted_idx, valid_signal_i, valid_signal_o, seg
def _popcount_int32_tensor(x: torch.Tensor) -> torch.Tensor:
x = x.to(torch.int64)
m1 = torch.tensor(0x5555555555555555, dtype=torch.int64, device=x.device)
m2 = torch.tensor(0x3333333333333333, dtype=torch.int64, device=x.device)
m4 = torch.tensor(0x0F0F0F0F0F0F0F0F, dtype=torch.int64, device=x.device)
h01 = torch.tensor(0x0101010101010101, dtype=torch.int64, device=x.device)
x = x - ((x >> 1) & m1)
x = (x & m2) + ((x >> 2) & m2)
x = (x + (x >> 4)) & m4
x = (x * h01) >> 56
return x.to(torch.int32)
def neighbor_map_post_process_for_masked_implicit_gemm_2(
    gray_code: torch.Tensor,   # [N] occupancy bitmask per voxel (non-negative)
    sorted_idx: torch.Tensor,  # [N] permutation grouping voxels into blocks
    block_size: int
):
    """Compute, per block of `block_size` sorted voxels, the union of their
    tap-occupancy bitmasks expanded into explicit tap indices.

    Returns:
        valid_kernel_idx: int32 tensor of ascending set-bit positions,
            concatenated block by block.
        seg: (num_blocks + 1,) int32 prefix offsets into valid_kernel_idx.
    """
    device = gray_code.device
    n_vox = gray_code.numel()
    num_blocks = (n_vox + block_size - 1) // block_size
    # Reorder codes into block order, pad the tail with zeros, OR within blocks.
    codes = gray_code.to(torch.int64)[sorted_idx.to(torch.long)]
    tail = num_blocks * block_size - n_vox
    if tail > 0:
        codes = torch.cat([codes, torch.zeros((tail,), dtype=torch.int64, device=device)], dim=0)
    per_block = codes.view(num_blocks, block_size)
    reduced_code = per_block[:, 0].clone()
    for col in range(1, block_size):
        reduced_code |= per_block[:, col]
    reduced_code = reduced_code.to(torch.int32)  # match the CUDA int32 layout
    # Popcount each block mask (SWAR, inlined) -> active taps per block.
    bits = reduced_code.to(torch.int64)
    bits = bits - ((bits >> 1) & 0x5555555555555555)
    bits = (bits & 0x3333333333333333) + ((bits >> 2) & 0x3333333333333333)
    bits = (bits + (bits >> 4)) & 0x0F0F0F0F0F0F0F0F
    seglen_counts = ((bits * 0x0101010101010101) >> 56).to(torch.int32)
    # Prefix offsets: seg[i+1] = total active taps in blocks 0..i.
    seg = torch.empty((num_blocks + 1,), dtype=torch.int32, device=device)
    seg[0] = 0
    if num_blocks > 0:
        seg[1:] = torch.cumsum(seglen_counts, dim=0)
    if int(seg[-1].item()) == 0:
        # No bits set anywhere (covers both "no voxels" and "all-zero codes").
        return torch.empty((0,), dtype=torch.int32, device=device), seg
    # Expand each block mask into its ascending set-bit positions.
    width = int(reduced_code.max().item()).bit_length()
    lanes = torch.arange(0, width, dtype=torch.int64, device=device)
    set_mask = ((reduced_code.to(torch.int64).unsqueeze(1) >> lanes.unsqueeze(0)) & 1).to(torch.bool)
    lane_grid = lanes.unsqueeze(0).expand(num_blocks, width)
    valid_kernel_idx = lane_grid[set_mask].to(torch.int32).contiguous()
    return valid_kernel_idx, seg
def sparse_submanifold_conv3d(feats, coords, shape, weight, bias, neighbor_cache, dilation):
    """Submanifold sparse 3-D convolution via the masked implicit-GEMM kernel.

    Args:
        feats: (M, Ci) features of the active voxels.
        coords: (M, 4) integer (batch, x, y, z) coordinates.
        shape: (N, C, W, H, D) dense shape or just (W, H, D).
        weight: (Co, Kw, Kh, Kd, Ci) conv weights (kernel dims before channels).
        bias: (Co,) tensor or None.
        neighbor_cache: previously built neighbor map, or None to build one.
        dilation: 3-sequence of per-axis dilations.

    Returns:
        (out_feats, neighbor_map) — the map is returned so callers can cache it.

    NOTE(review): depends on the Triton kernel path defined above; there is no
    fallback when Triton import failed (NO_TRITION).
    """
    if len(shape) == 5:
        N, C, W, H, D = shape
    else:
        W, H, D = shape
    # NOTE(review): sparse_conv3d_forward unpacks this layout as (Co, Kd, Kh,
    # Kw, Ci) — confirm which axis order the checkpoints actually use.
    Co, Kw, Kh, Kd, Ci = weight.shape
    # flatten (b, x, y, z) into one scalar key per voxel for the hashmap
    b_stride = W * H * D
    x_stride = H * D
    y_stride = D
    z_stride = 1
    flat_keys = (coords[:, 0].long() * b_stride +
                 coords[:, 1].long() * x_stride +
                 coords[:, 2].long() * y_stride +
                 coords[:, 3].long() * z_stride)
    vals = torch.arange(coords.shape[0], dtype=torch.int32, device=coords.device)
    # 0xFFFFFFFF is the "missing neighbor" sentinel shared with the GPU kernel
    hashmap = TorchHashMap(flat_keys, vals, 0xFFFFFFFF)
    if neighbor_cache is None:
        neighbor = build_submanifold_neighbor_map(
            hashmap, coords, W, H, D, Kw, Kh, Kd,
            dilation[0], dilation[1], dilation[2]
        )
    else:
        neighbor = neighbor_cache
    # must match the B1 block size the GEMM kernel is launched with
    block_size = 128
    # valid_signal_i/o are currently unused by the masked implicit-GEMM path
    gray_code, sorted_idx, valid_signal_i, valid_signal_o, valid_signal_seg = \
        neighbor_map_post_process_for_masked_implicit_gemm_1(neighbor)
    valid_kernel, valid_kernel_seg = \
        neighbor_map_post_process_for_masked_implicit_gemm_2(gray_code, sorted_idx, block_size)
    # the kernel's triton.heuristics call these with the launch's B1 value
    valid_kernel_fn = lambda b_size: valid_kernel
    valid_kernel_seg_fn = lambda b_size: valid_kernel_seg
    # (Co, V, Ci) view expected by the implicit GEMM
    weight_flat = weight.contiguous().view(Co, -1, Ci)
    out = sparse_submanifold_conv_fwd_masked_implicit_gemm_splitk(
        feats,
        weight_flat,
        bias,
        neighbor,
        sorted_idx,
        valid_kernel_fn,
        valid_kernel_seg_fn
    )
    return out, neighbor
class Voxel:
def __init__(
self,

View File

@ -408,7 +408,7 @@ class SLatFlowModel(nn.Module):
self.qk_rms_norm_cross = qk_rms_norm_cross
self.dtype = dtype
self.t_embedder = TimestepEmbedder(model_channels)
self.t_embedder = TimestepEmbedder(model_channels, device=device, dtype=dtype, operations=operations)
if share_mod:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
@ -485,15 +485,25 @@ class Trellis2(nn.Module):
qk_rms_norm = True,
qk_rms_norm_cross = True,
dtype=None, device=None, operations=None):
super().__init__()
args = {
"out_channels":out_channels, "num_blocks":num_blocks, "cond_channels" :cond_channels,
"model_channels":model_channels, "num_heads":num_heads, "mlp_ratio": mlp_ratio, "share_mod": share_mod,
"qk_rms_norm": qk_rms_norm, "qk_rms_norm_cross": qk_rms_norm_cross, "device": device, "dtype": dtype, "operations": operations
}
# TODO: update the names/checkpoints
self.img2shape = SLatFlowModel(resolution, in_channels=in_channels, *args)
self.shape2txt = SLatFlowModel(resolution, in_channels=in_channels*2, *args)
self.img2shape = SLatFlowModel(resolution=resolution, in_channels=in_channels, **args)
self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args)
self.shape_generation = True
def forward(self, x, timestep, context):
pass
def forward(self, x, timestep, context, **kwargs):
# TODO add mode
mode = kwargs.get("mode", "shape_generation")
mode = "texture_generation" if mode == 1 else "shape_generation"
if mode == "shape_generation":
out = self.img2shape(x, timestep, context)
if mode == "texture_generation":
out = self.shape2txt(x, timestep, context)
return out

View File

@ -1,3 +1,4 @@
import math
import torch
import torch.nn as nn
from typing import List, Any, Dict, Optional, overload, Union, Tuple
@ -5,12 +6,219 @@ from fractions import Fraction
import torch.nn.functional as F
from dataclasses import dataclass
import numpy as np
from cumesh import TorchHashMap, Mesh, MeshWithVoxel
from cumesh import TorchHashMap, Mesh, MeshWithVoxel, sparse_submanifold_conv3d
class SparseConv3d(nn.Module):
    """3-D submanifold sparse convolution layer.

    Thin nn.Module shell; construction and the forward pass are implemented by
    the module-level helpers sparse_conv3d_init / sparse_conv3d_forward.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
        super().__init__()
        sparse_conv3d_init(self, in_channels, out_channels, kernel_size, stride, dilation, padding, bias, indice_key)

    def forward(self, x):
        """Apply the convolution to a sparse-tensor input."""
        return sparse_conv3d_forward(self, x)
def sparse_conv3d_init(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
    """Initialize a SparseConv3d module in place.

    Scalar kernel_size/stride/dilation are broadcast to 3-tuples. The weight
    is allocated as (Co, Ci, K1, K2, K3) and stored channels-last as
    (Co, K1, K2, K3, Ci); the bias gets the usual fan-in uniform init.
    """
    def _triple(value):
        # broadcast a scalar to all three spatial axes
        return tuple(value) if isinstance(value, (list, tuple)) else (value,) * 3

    self.in_channels = in_channels
    self.out_channels = out_channels
    self.kernel_size = _triple(kernel_size)
    self.stride = _triple(stride)
    self.dilation = _triple(dilation)
    self.weight = nn.Parameter(torch.empty((out_channels, in_channels, *self.kernel_size)))
    if bias:
        self.bias = nn.Parameter(torch.empty(out_channels))
        fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
        if fan_in != 0:
            bound = 1 / math.sqrt(fan_in)
            torch.nn.init.uniform_(self.bias, -bound, bound)
    else:
        self.register_parameter("bias", None)
    # Permute weight (Co, Ci, Kd, Kh, Kw) -> (Co, Kd, Kh, Kw, Ci)
    self.weight = nn.Parameter(self.weight.permute(0, 2, 3, 4, 1).contiguous())
def sparse_conv3d_forward(self, x):
    """Forward pass for SparseConv3d: reuse (or build and cache) the neighbor
    map, then run the submanifold convolution on the sparse tensor `x`.

    `x` is expected to provide feats, coords, shape, spatial_shape and the
    get/register_spatial_cache + replace protocol of the project's SparseTensor.
    """
    # check if neighbor map is already computed for this kernel/dilation
    # NOTE(review): names here are (Co, Kd, Kh, Kw, Ci) but
    # sparse_submanifold_conv3d unpacks the same layout as (Co, Kw, Kh, Kd,
    # Ci) — only the cache-key string depends on the naming, but confirm the
    # intended axis order.
    Co, Kd, Kh, Kw, Ci = self.weight.shape
    neighbor_cache_key = f'SubMConv3d_neighbor_cache_{Kw}x{Kh}x{Kd}_dilation{self.dilation}'
    neighbor_cache = x.get_spatial_cache(neighbor_cache_key)
    out, neighbor_cache_ = sparse_submanifold_conv3d(
        x.feats,
        x.coords,
        torch.Size([*x.shape, *x.spatial_shape]),
        self.weight,
        self.bias,
        neighbor_cache,
        self.dilation
    )
    if neighbor_cache is None:
        # cache the freshly built neighbor map for later convs with the same
        # kernel size and dilation on this tensor
        x.register_spatial_cache(neighbor_cache_key, neighbor_cache_)
    out = x.replace(out)
    return out
class LayerNorm32(nn.LayerNorm):
    """LayerNorm that always normalizes in float32, then casts back.

    Keeps half-precision activations numerically stable without changing the
    caller-visible dtype.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_dtype = x.dtype
        normed = super().forward(x.to(torch.float32))
        return normed.to(orig_dtype)
class SparseConvNeXtBlock3d(nn.Module):
    """ConvNeXt-style residual block on sparse 3-D tensors:
    3x3x3 submanifold conv -> LayerNorm -> channel MLP, with a skip add.
    """

    def __init__(
        self,
        channels: int,
        mlp_ratio: float = 4.0,
        use_checkpoint: bool = False,
    ):
        super().__init__()
        self.channels = channels
        # kept for interface parity; checkpointing is not applied here
        self.use_checkpoint = use_checkpoint
        self.norm = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
        self.conv = SparseConv3d(channels, channels, 3)
        hidden = int(channels * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(channels, hidden),
            nn.SiLU(),
            nn.Linear(hidden, channels),
        )

    def _forward(self, x):
        # conv, then per-voxel norm + MLP on the feature matrix
        h = self.conv(x)
        h = h.replace(self.norm(h.feats))
        h = h.replace(self.mlp(h.feats))
        return h + x

    def forward(self, x):
        return self._forward(x)
class SparseSpatial2Channel(nn.Module):
    """Fold each factor^DIM spatial neighborhood of a sparse tensor into the
    channel dimension (sparse pixel-unshuffle). Caches the coordinate mapping
    on the input tensor so the inverse (SparseChannel2Spatial) can reuse it.
    """
    def __init__(self, factor: int = 2):
        super(SparseSpatial2Channel, self).__init__()
        self.factor = factor

    def forward(self, x):
        # number of spatial dims; coords are (batch, x, y, z, ...)
        DIM = x.coords.shape[-1] - 1
        cache = x.get_spatial_cache(f'spatial2channel_{self.factor}')
        if cache is None:
            # downscale spatial coordinates by the folding factor
            coord = list(x.coords.unbind(dim=-1))
            for i in range(DIM):
                coord[i+1] = coord[i+1] // self.factor
            # packed sub-voxel index of each voxel within its factor^DIM cell
            subidx = x.coords[:, 1:] % self.factor
            subidx = sum([subidx[..., i] * self.factor ** i for i in range(DIM)])
            # downscaled grid extents and row-major strides (batch-first)
            MAX = [(s + self.factor - 1) // self.factor for s in x.spatial_shape]
            OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1]
            # unique linearized codes -> one output voxel per occupied cell
            code = sum([c * o for c, o in zip(coord, OFFSET)])
            code, idx = code.unique(return_inverse=True)
            new_coords = torch.stack(
                [code // OFFSET[0]] +
                [(code // OFFSET[i+1]) % MAX[i] for i in range(DIM)],
                dim=-1
            )
        else:
            new_coords, idx, subidx = cache
        # scatter each voxel's features into its sub-voxel slot (missing
        # children stay zero), then flatten the factor^DIM slots into channels
        new_feats = torch.zeros(new_coords.shape[0] * self.factor ** DIM, x.feats.shape[1], device=x.feats.device, dtype=x.feats.dtype)
        new_feats[idx * self.factor ** DIM + subidx] = x.feats
        out = SparseTensor(new_feats.reshape(new_coords.shape[0], -1), new_coords, None if x._shape is None else torch.Size([x._shape[0], x._shape[1] * self.factor ** DIM]))
        out._scale = tuple([s * self.factor for s in x._scale])
        out._spatial_cache = x._spatial_cache
        if cache is None:
            # cache both directions so a later Channel2Spatial can invert this
            x.register_spatial_cache(f'spatial2channel_{self.factor}', (new_coords, idx, subidx))
            out.register_spatial_cache(f'channel2spatial_{self.factor}', (x.coords, idx, subidx))
            out.register_spatial_cache('shape', torch.Size(MAX))
        return out
class SparseChannel2Spatial(nn.Module):
    """Inverse of SparseSpatial2Channel: unfold factor^DIM channel groups back
    into spatial sub-voxels, driven either by the cached mapping or by an
    explicit per-voxel subdivision mask.
    """
    def __init__(self, factor: int = 2):
        super(SparseChannel2Spatial, self).__init__()
        self.factor = factor

    def forward(self, x, subdivision = None):
        # number of spatial dims; coords are (batch, x, y, z, ...)
        DIM = x.coords.shape[-1] - 1
        cache = x.get_spatial_cache(f'channel2spatial_{self.factor}')
        if cache is None:
            if subdivision is None:
                raise ValueError('Cache not found. Provide subdivision tensor or pair SparseChannel2Spatial with SparseSpatial2Channel.')
            else:
                # boolean mask selecting which of the factor**DIM children of
                # each parent voxel actually exist
                sub = subdivision.feats # [N, self.factor ** DIM]
                N_leaf = sub.sum(dim=-1) # [N] children per parent
                subidx = sub.nonzero()[:, -1]
                new_coords = x.coords.clone().detach()
                new_coords[:, 1:] *= self.factor
                # one output row per existing child
                new_coords = torch.repeat_interleave(new_coords, N_leaf, dim=0, output_size=subidx.shape[0])
                for i in range(DIM):
                    # decode child offset along axis i from the packed index
                    new_coords[:, i+1] += subidx // self.factor ** i % self.factor
                idx = torch.repeat_interleave(torch.arange(x.coords.shape[0], device=x.device), N_leaf, dim=0, output_size=subidx.shape[0])
        else:
            new_coords, idx, subidx = cache
        # split channels into factor**DIM sub-voxel rows, gather kept children
        x_feats = x.feats.reshape(x.feats.shape[0] * self.factor ** DIM, -1)
        new_feats = x_feats[idx * self.factor ** DIM + subidx]
        out = SparseTensor(new_feats, new_coords, None if x._shape is None else torch.Size([x._shape[0], x._shape[1] // self.factor ** DIM]))
        out._scale = tuple([s / self.factor for s in x._scale])
        if cache is not None: # only keep cache when subdiv following it
            out._spatial_cache = x._spatial_cache
        return out
class SparseResBlockC2S3d(nn.Module):
    """Residual channel-to-spatial (x2 upsampling) block with optional learned
    subdivision prediction.

    conv1 expands features to out_channels * 8 so SparseChannel2Spatial(2) can
    fold the 2x2x2 = 8 channel groups back into spatial sub-voxels; the skip
    path is upsampled the same way and channel-repeated to match.
    """
    def __init__(
        self,
        channels: int,
        out_channels: Optional[int] = None,
        use_checkpoint: bool = False,
        pred_subdiv: bool = True,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_checkpoint = use_checkpoint
        self.pred_subdiv = pred_subdiv
        self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
        self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6)
        self.conv1 = SparseConv3d(channels, self.out_channels * 8, 3)
        self.conv2 = SparseConv3d(self.out_channels, self.out_channels, 3)
        # After updown, the skip path carries channels // 8 features per voxel;
        # repeat them up to out_channels. Must use self.out_channels here: the
        # raw out_channels argument may be None (the default), which made the
        # original lambda raise a TypeError.
        rep = self.out_channels // (channels // 8)
        self.skip_connection = lambda x: x.replace(x.feats.repeat_interleave(rep, dim=1))
        if pred_subdiv:
            # predicts which of the 8 children of each voxel are occupied
            self.to_subdiv = SparseLinear(channels, 8)
        self.updown = SparseChannel2Spatial(2)

    def _forward(self, x, subdiv=None):
        if self.pred_subdiv:
            subdiv = self.to_subdiv(x)
        h = x.replace(self.norm1(x.feats))
        h = h.replace(F.silu(h.feats))
        h = self.conv1(h)
        # binarize logits into the occupancy mask used by channel->spatial
        subdiv_binarized = subdiv.replace(subdiv.feats > 0) if subdiv is not None else None
        h = self.updown(h, subdiv_binarized)
        x = self.updown(x, subdiv_binarized)
        h = h.replace(self.norm2(h.feats))
        h = h.replace(F.silu(h.feats))
        h = self.conv2(h)
        h = h + self.skip_connection(x)
        if self.pred_subdiv:
            return h, subdiv
        else:
            return h
# TODO: determine which conv they actually use
@dataclass
class config:
    """Global knobs for the sparse-convolution backend.

    Note: without type annotations these are plain class attributes, not
    dataclass fields, so @dataclass generates no init/eq for them.
    """
    # conv backend selector; the dead `CONV = "none"` assignment that was
    # immediately shadowed has been removed
    CONV = "flexgemm"
    FLEX_GEMM_HASHMAP_RATIO = 2.0
# TODO post processing
def simplify(self, target_num_faces: int, verbose: bool=False, options: dict={}):
@ -1131,6 +1339,7 @@ def flexible_dual_grid_to_mesh(
class Vae(nn.Module):
def __init__(self, config, operations=None):
super().__init__()
operations = operations or torch.nn
self.txt_dec = SparseUnetVaeDecoder(
@ -1139,7 +1348,8 @@ class Vae(nn.Module):
latent_channels=32,
num_blocks=[4, 16, 8, 4, 0],
block_type=["SparseConvNeXtBlock3d"] * 5,
up_block_type=["SparseResBlockS2C3d"] * 4,
up_block_type=["SparseResBlockC2S3d"] * 4,
block_args=[{}, {}, {}, {}, {}],
pred_subdiv=False
)
@ -1149,7 +1359,8 @@ class Vae(nn.Module):
latent_channels=32,
num_blocks=[4, 16, 8, 4, 0],
block_type=["SparseConvNeXtBlock3d"] * 5,
up_block_type=["SparseResBlockS2C3d"] * 4,
up_block_type=["SparseResBlockC2S3d"] * 4,
block_args=[{}, {}, {}, {}, {}],
)
def decode_shape_slat(self, slat, resolution: int):

View File

@ -50,6 +50,7 @@ import comfy.ldm.omnigen.omnigen2
import comfy.ldm.qwen_image.model
import comfy.ldm.kandinsky5.model
import comfy.ldm.anima.model
import comfy.ldm.trellis2.model
import comfy.model_management
import comfy.patcher_extension
@ -1455,6 +1456,13 @@ class WAN22(WAN21):
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
return latent_image
class Trellis2(BaseModel):
    """BaseModel wrapper routing diffusion through comfy.ldm.trellis2.model.Trellis2
    (flow-matching objective)."""
    def __init__(self, model_config, model_type=ModelType.FLOW, device=None, unet_model=comfy.ldm.trellis2.model.Trellis2):
        super().__init__(model_config, model_type, device, unet_model)

    def extra_conds(self, **kwargs):
        # no Trellis2-specific conditioning yet; defer entirely to BaseModel
        return super().extra_conds(**kwargs)
class Hunyuan3Dv2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)

View File

@ -1242,6 +1242,17 @@ class WAN22_T2V(WAN21_T2V):
out = model_base.WAN22(self, image_to_video=True, device=device)
return out
class Trellis2(supported_models_base.BASE):
    """Checkpoint detection/construction entry for Trellis2 models."""
    # matched against the detected unet config
    unet_config = {
        "image_model": "trellis2"
    }

    latent_format = latent_formats.Trellis2
    # VAE weights live under the "vae." prefix in the checkpoint
    vae_key_prefix = ["vae."]

    def get_model(self, state_dict, prefix="", device=None):
        # NOTE(review): no clip_target / state-dict processing overrides yet —
        # presumably still work in progress ("model init working")
        return model_base.Trellis2(self, device=device)
class Hunyuan3Dv2(supported_models_base.BASE):
unet_config = {
"image_model": "hunyuan3d2",
@ -1596,6 +1607,6 @@ class Kandinsky5Image(Kandinsky5):
return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, Trellis2]
models += [SVD_img2vid]