mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 13:02:35 +08:00
model init working
This commit is contained in:
parent
f76e3a11b5
commit
d6573fd26d
@ -746,6 +746,8 @@ class Hunyuan3Dv2_1(LatentFormat):
|
|||||||
latent_channels = 64
|
latent_channels = 64
|
||||||
latent_dimensions = 1
|
latent_dimensions = 1
|
||||||
|
|
||||||
|
class Trellis2(LatentFormat): # TODO
|
||||||
|
latent_channels = 32
|
||||||
class Hunyuan3Dv2mini(LatentFormat):
|
class Hunyuan3Dv2mini(LatentFormat):
|
||||||
latent_channels = 64
|
latent_channels = 64
|
||||||
latent_dimensions = 1
|
latent_dimensions = 1
|
||||||
|
|||||||
@ -1,8 +1,192 @@
|
|||||||
# will contain every cuda -> pytorch operation
|
# will contain every cuda -> pytorch operation
|
||||||
|
|
||||||
|
import math
|
||||||
import torch
|
import torch
|
||||||
from typing import Dict
|
from typing import Dict, Callable
|
||||||
|
|
||||||
|
NO_TRITION = False
|
||||||
|
try:
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
heuristics = {
|
||||||
|
'valid_kernel': lambda args: args['valid_kernel'](args['B1']),
|
||||||
|
'valid_kernel_seg': lambda args: args['valid_kernel_seg'](args['B1']),
|
||||||
|
}
|
||||||
|
|
||||||
|
#@triton_autotune(
|
||||||
|
# configs=config.autotune_config,
|
||||||
|
# key=['LOGN', 'Ci', 'Co', 'V', 'allow_tf32'],
|
||||||
|
#)
|
||||||
|
@triton.heuristics(heuristics)
|
||||||
|
@triton.jit
|
||||||
|
def sparse_submanifold_conv_fwd_masked_implicit_gemm_kernel(
|
||||||
|
input,
|
||||||
|
weight,
|
||||||
|
bias,
|
||||||
|
neighbor,
|
||||||
|
sorted_idx,
|
||||||
|
output,
|
||||||
|
# Tensor dimensions
|
||||||
|
N, LOGN, Ci, Co, V: tl.constexpr,
|
||||||
|
# Meta-parameters
|
||||||
|
B1: tl.constexpr, # Block size for N dimension
|
||||||
|
B2: tl.constexpr, # Block size for Co dimension
|
||||||
|
BK: tl.constexpr, # Block size for K dimension (V * Ci)
|
||||||
|
allow_tf32: tl.constexpr, # Allow TF32 precision for matmuls
|
||||||
|
# Huristic parameters
|
||||||
|
valid_kernel,
|
||||||
|
valid_kernel_seg,
|
||||||
|
):
|
||||||
|
|
||||||
|
block_id = tl.program_id(axis=0)
|
||||||
|
block_dim_co = tl.cdiv(Co, B2)
|
||||||
|
block_id_co = block_id % block_dim_co
|
||||||
|
block_id_n = block_id // block_dim_co
|
||||||
|
|
||||||
|
# Create pointers for submatrices of A and B.
|
||||||
|
num_k = tl.cdiv(Ci, BK) # Number of blocks in K dimension
|
||||||
|
valid_kernel_start = tl.load(valid_kernel_seg + block_id_n)
|
||||||
|
valid_kernel_seglen = tl.load(valid_kernel_seg + block_id_n + 1) - valid_kernel_start
|
||||||
|
offset_n = block_id_n * B1 + tl.arange(0, B1)
|
||||||
|
n_mask = offset_n < N
|
||||||
|
offset_sorted_n = tl.load(sorted_idx + offset_n, mask=n_mask, other=0) # (B1,)
|
||||||
|
offset_co = (block_id_co * B2 + tl.arange(0, B2)) % Co # (B2,)
|
||||||
|
offset_k = tl.arange(0, BK) # (BK,)
|
||||||
|
|
||||||
|
# Create a block of the output matrix C.
|
||||||
|
accumulator = tl.zeros((B1, B2), dtype=tl.float32)
|
||||||
|
|
||||||
|
# Iterate along V*Ci dimension.
|
||||||
|
for k in range(num_k * valid_kernel_seglen):
|
||||||
|
v = k // num_k
|
||||||
|
bk = k % num_k
|
||||||
|
v = tl.load(valid_kernel + valid_kernel_start + v)
|
||||||
|
# Calculate pointers to input matrix.
|
||||||
|
neighbor_offset_n = tl.load(neighbor + offset_sorted_n * V + v) # (B1,)
|
||||||
|
input_ptr = input + bk * BK + (neighbor_offset_n[:, None].to(tl.int64) * Ci + offset_k[None, :]) # (B1, BK)
|
||||||
|
# Calculate pointers to weight matrix.
|
||||||
|
weight_ptr = weight + v * Ci + bk * BK + (offset_co[None, :] * V * Ci + offset_k[:, None]) # (BK, B2)
|
||||||
|
# Load the next block of input and weight.
|
||||||
|
neigh_mask = neighbor_offset_n != 0xffffffff
|
||||||
|
k_mask = offset_k < Ci - bk * BK
|
||||||
|
input_block = tl.load(input_ptr, mask=neigh_mask[:, None] & k_mask[None, :], other=0.0)
|
||||||
|
weight_block = tl.load(weight_ptr, mask=k_mask[:, None], other=0.0)
|
||||||
|
# Accumulate along the K dimension.
|
||||||
|
accumulator = tl.dot(input_block, weight_block, accumulator,
|
||||||
|
input_precision='tf32' if allow_tf32 else 'ieee') # (B1, B2)
|
||||||
|
c = accumulator.to(input.type.element_ty)
|
||||||
|
|
||||||
|
# add bias
|
||||||
|
if bias is not None:
|
||||||
|
bias_block = tl.load(bias + offset_co)
|
||||||
|
c += bias_block[None, :]
|
||||||
|
|
||||||
|
# Write back the block of the output matrix with masks.
|
||||||
|
out_offset_n = offset_sorted_n
|
||||||
|
out_offset_co = block_id_co * B2 + tl.arange(0, B2)
|
||||||
|
out_ptr = output + (out_offset_n[:, None] * Co + out_offset_co[None, :])
|
||||||
|
out_mask = n_mask[:, None] & (out_offset_co[None, :] < Co)
|
||||||
|
tl.store(out_ptr, c, mask=out_mask)
|
||||||
|
def sparse_submanifold_conv_fwd_masked_implicit_gemm_splitk(
|
||||||
|
input: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
bias: torch.Tensor,
|
||||||
|
neighbor: torch.Tensor,
|
||||||
|
sorted_idx: torch.Tensor,
|
||||||
|
valid_kernel: Callable[[int], torch.Tensor],
|
||||||
|
valid_kernel_seg: Callable[[int], torch.Tensor],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
N, Ci, Co, V = neighbor.shape[0], input.shape[1], weight.shape[0], weight.shape[1]
|
||||||
|
LOGN = int(math.log2(N))
|
||||||
|
output = torch.empty((N, Co), device=input.device, dtype=input.dtype)
|
||||||
|
grid = lambda META: (triton.cdiv(Co, META['B2']) * triton.cdiv(N, META['B1']),)
|
||||||
|
sparse_submanifold_conv_fwd_masked_implicit_gemm_kernel[grid](
|
||||||
|
input, weight, bias, neighbor, sorted_idx, output,
|
||||||
|
N, LOGN, Ci, Co, V, #
|
||||||
|
valid_kernel=valid_kernel,
|
||||||
|
valid_kernel_seg=valid_kernel_seg,
|
||||||
|
allow_tf32=torch.cuda.is_tf32_supported(),
|
||||||
|
)
|
||||||
|
return output
|
||||||
|
except:
|
||||||
|
NO_TRITION = True
|
||||||
|
|
||||||
|
def compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device):
|
||||||
|
# offsets in same order as CUDA kernel
|
||||||
|
offsets = []
|
||||||
|
for vx in range(Kw):
|
||||||
|
for vy in range(Kh):
|
||||||
|
for vz in range(Kd):
|
||||||
|
offsets.append((
|
||||||
|
vx * Dw,
|
||||||
|
vy * Dh,
|
||||||
|
vz * Dd
|
||||||
|
))
|
||||||
|
return torch.tensor(offsets, device=device)
|
||||||
|
|
||||||
|
def build_submanifold_neighbor_map(
|
||||||
|
hashmap,
|
||||||
|
coords: torch.Tensor,
|
||||||
|
W, H, D,
|
||||||
|
Kw, Kh, Kd,
|
||||||
|
Dw, Dh, Dd,
|
||||||
|
):
|
||||||
|
device = coords.device
|
||||||
|
M = coords.shape[0]
|
||||||
|
V = Kw * Kh * Kd
|
||||||
|
half_V = V // 2 + 1
|
||||||
|
|
||||||
|
INVALID = hashmap.default_value
|
||||||
|
|
||||||
|
neighbor = torch.full((M, V), INVALID, device=device, dtype=torch.long)
|
||||||
|
|
||||||
|
b = coords[:, 0]
|
||||||
|
x = coords[:, 1]
|
||||||
|
y = coords[:, 2]
|
||||||
|
z = coords[:, 3]
|
||||||
|
|
||||||
|
offsets = compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device)
|
||||||
|
|
||||||
|
ox = x[:, None] - (Kw // 2) * Dw
|
||||||
|
oy = y[:, None] - (Kh // 2) * Dh
|
||||||
|
oz = z[:, None] - (Kd // 2) * Dd
|
||||||
|
|
||||||
|
for v in range(half_V):
|
||||||
|
if v == half_V - 1:
|
||||||
|
neighbor[:, v] = torch.arange(M, device=device)
|
||||||
|
continue
|
||||||
|
|
||||||
|
dx, dy, dz = offsets[v]
|
||||||
|
|
||||||
|
kx = ox[:, v] + dx
|
||||||
|
ky = oy[:, v] + dy
|
||||||
|
kz = oz[:, v] + dz
|
||||||
|
|
||||||
|
valid = (
|
||||||
|
(kx >= 0) & (kx < W) &
|
||||||
|
(ky >= 0) & (ky < H) &
|
||||||
|
(kz >= 0) & (kz < D)
|
||||||
|
)
|
||||||
|
|
||||||
|
flat = (
|
||||||
|
b * (W * H * D) +
|
||||||
|
kx * (H * D) +
|
||||||
|
ky * D +
|
||||||
|
kz
|
||||||
|
)
|
||||||
|
|
||||||
|
flat = flat[valid]
|
||||||
|
idx = torch.nonzero(valid, as_tuple=False).squeeze(1)
|
||||||
|
|
||||||
|
found = hashmap.lookup_flat(flat)
|
||||||
|
|
||||||
|
neighbor[idx, v] = found
|
||||||
|
|
||||||
|
# symmetric write
|
||||||
|
valid_found = found != INVALID
|
||||||
|
neighbor[found[valid_found], V - 1 - v] = idx[valid_found]
|
||||||
|
|
||||||
|
return neighbor
|
||||||
|
|
||||||
class TorchHashMap:
|
class TorchHashMap:
|
||||||
def __init__(self, keys: torch.Tensor, values: torch.Tensor, default_value: int):
|
def __init__(self, keys: torch.Tensor, values: torch.Tensor, default_value: int):
|
||||||
@ -22,6 +206,207 @@ class TorchHashMap:
|
|||||||
out[found] = self.sorted_vals[idx[found]]
|
out[found] = self.sorted_vals[idx[found]]
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
UINT32_SENTINEL = 0xFFFFFFFF
|
||||||
|
|
||||||
|
def neighbor_map_post_process_for_masked_implicit_gemm_1(neighbor_map):
|
||||||
|
device = neighbor_map.device
|
||||||
|
N, V = neighbor_map.shape
|
||||||
|
|
||||||
|
|
||||||
|
neigh = neighbor_map.to(torch.long)
|
||||||
|
sentinel = torch.tensor(UINT32_SENTINEL, dtype=torch.long, device=device)
|
||||||
|
|
||||||
|
|
||||||
|
neigh_map_T = neigh.t().reshape(-1)
|
||||||
|
|
||||||
|
neigh_mask_T = (neigh_map_T != sentinel).to(torch.int32)
|
||||||
|
|
||||||
|
mask = (neigh != sentinel).to(torch.long)
|
||||||
|
|
||||||
|
powers = (1 << torch.arange(V, dtype=torch.long, device=device))
|
||||||
|
|
||||||
|
gray_long = (mask * powers).sum(dim=1)
|
||||||
|
|
||||||
|
gray_code = gray_long.to(torch.int32)
|
||||||
|
|
||||||
|
binary_long = gray_long.clone()
|
||||||
|
for v in range(1, V):
|
||||||
|
binary_long ^= (gray_long >> v)
|
||||||
|
binary_code = binary_long.to(torch.int32)
|
||||||
|
|
||||||
|
sorted_idx = torch.argsort(binary_code)
|
||||||
|
|
||||||
|
prefix_sum_neighbor_mask = torch.cumsum(neigh_mask_T.to(torch.int32), dim=0) # (V*N,)
|
||||||
|
|
||||||
|
total_valid_signal = int(prefix_sum_neighbor_mask[-1].item()) if prefix_sum_neighbor_mask.numel() > 0 else 0
|
||||||
|
|
||||||
|
if total_valid_signal > 0:
|
||||||
|
valid_signal_i = torch.empty((total_valid_signal,), dtype=torch.long, device=device)
|
||||||
|
valid_signal_o = torch.empty((total_valid_signal,), dtype=torch.long, device=device)
|
||||||
|
|
||||||
|
pos = torch.nonzero(neigh_mask_T, as_tuple=True)[0]
|
||||||
|
|
||||||
|
to = (prefix_sum_neighbor_mask[pos] - 1).to(torch.long)
|
||||||
|
|
||||||
|
valid_signal_i[to] = (pos % N).to(torch.long)
|
||||||
|
|
||||||
|
valid_signal_o[to] = neigh_map_T[pos].to(torch.long)
|
||||||
|
else:
|
||||||
|
valid_signal_i = torch.empty((0,), dtype=torch.long, device=device)
|
||||||
|
valid_signal_o = torch.empty((0,), dtype=torch.long, device=device)
|
||||||
|
|
||||||
|
seg = torch.empty((V + 1,), dtype=torch.long, device=device)
|
||||||
|
seg[0] = 0
|
||||||
|
if V > 0:
|
||||||
|
idxs = (torch.arange(1, V + 1, device=device, dtype=torch.long) * N) - 1
|
||||||
|
seg[1:] = prefix_sum_neighbor_mask[idxs].to(torch.long)
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return gray_code, sorted_idx, valid_signal_i, valid_signal_o, seg
|
||||||
|
|
||||||
|
def _popcount_int32_tensor(x: torch.Tensor) -> torch.Tensor:
|
||||||
|
|
||||||
|
x = x.to(torch.int64)
|
||||||
|
|
||||||
|
m1 = torch.tensor(0x5555555555555555, dtype=torch.int64, device=x.device)
|
||||||
|
m2 = torch.tensor(0x3333333333333333, dtype=torch.int64, device=x.device)
|
||||||
|
m4 = torch.tensor(0x0F0F0F0F0F0F0F0F, dtype=torch.int64, device=x.device)
|
||||||
|
h01 = torch.tensor(0x0101010101010101, dtype=torch.int64, device=x.device)
|
||||||
|
|
||||||
|
x = x - ((x >> 1) & m1)
|
||||||
|
x = (x & m2) + ((x >> 2) & m2)
|
||||||
|
x = (x + (x >> 4)) & m4
|
||||||
|
x = (x * h01) >> 56
|
||||||
|
return x.to(torch.int32)
|
||||||
|
|
||||||
|
|
||||||
|
def neighbor_map_post_process_for_masked_implicit_gemm_2(
|
||||||
|
gray_code: torch.Tensor, # [N], int32-like (non-negative)
|
||||||
|
sorted_idx: torch.Tensor, # [N], long (indexing into gray_code)
|
||||||
|
block_size: int
|
||||||
|
):
|
||||||
|
device = gray_code.device
|
||||||
|
N = gray_code.numel()
|
||||||
|
|
||||||
|
# num of blocks (same as CUDA)
|
||||||
|
num_blocks = (N + block_size - 1) // block_size
|
||||||
|
|
||||||
|
# Ensure dtypes
|
||||||
|
gray_long = gray_code.to(torch.int64) # safer to OR in 64-bit then cast
|
||||||
|
sorted_idx = sorted_idx.to(torch.long)
|
||||||
|
|
||||||
|
# 1) Group gray_code by blocks and compute OR across each block
|
||||||
|
# pad the last block with zeros if necessary so we can reshape
|
||||||
|
pad = num_blocks * block_size - N
|
||||||
|
if pad > 0:
|
||||||
|
pad_vals = torch.zeros((pad,), dtype=torch.int64, device=device)
|
||||||
|
gray_padded = torch.cat([gray_long[sorted_idx], pad_vals], dim=0)
|
||||||
|
else:
|
||||||
|
gray_padded = gray_long[sorted_idx]
|
||||||
|
|
||||||
|
# reshape to (num_blocks, block_size) and compute bitwise_or across dim=1
|
||||||
|
gray_blocks = gray_padded.view(num_blocks, block_size) # each row = block entries
|
||||||
|
# reduce with bitwise_or
|
||||||
|
reduced_code = gray_blocks[:, 0].clone()
|
||||||
|
for i in range(1, block_size):
|
||||||
|
reduced_code |= gray_blocks[:, i]
|
||||||
|
reduced_code = reduced_code.to(torch.int32) # match CUDA int32
|
||||||
|
|
||||||
|
# 2) compute seglen (popcount per reduced_code) and seg (prefix sum)
|
||||||
|
seglen_counts = _popcount_int32_tensor(reduced_code.to(torch.int64)).to(torch.int32) # [num_blocks]
|
||||||
|
# seg: length num_blocks+1, seg[0]=0, seg[i+1]=cumsum(seglen_counts) up to i
|
||||||
|
seg = torch.empty((num_blocks + 1,), dtype=torch.int32, device=device)
|
||||||
|
seg[0] = 0
|
||||||
|
if num_blocks > 0:
|
||||||
|
seg[1:] = torch.cumsum(seglen_counts, dim=0)
|
||||||
|
|
||||||
|
total = int(seg[-1].item())
|
||||||
|
|
||||||
|
# 3) scatter — produce valid_kernel_idx as concatenated ascending set-bit positions for each reduced_code row
|
||||||
|
if total == 0:
|
||||||
|
valid_kernel_idx = torch.empty((0,), dtype=torch.int32, device=device)
|
||||||
|
return valid_kernel_idx, seg
|
||||||
|
|
||||||
|
max_val = int(reduced_code.max().item())
|
||||||
|
V = max_val.bit_length() if max_val > 0 else 0
|
||||||
|
# If you know V externally, pass it instead or set here explicitly.
|
||||||
|
|
||||||
|
if V == 0:
|
||||||
|
# no bits set anywhere
|
||||||
|
valid_kernel_idx = torch.empty((0,), dtype=torch.int32, device=device)
|
||||||
|
return valid_kernel_idx, seg
|
||||||
|
|
||||||
|
# build mask of shape (num_blocks, V): True where bit is set
|
||||||
|
bit_pos = torch.arange(0, V, dtype=torch.int64, device=device) # [V]
|
||||||
|
# shifted = reduced_code[:, None] >> bit_pos[None, :]
|
||||||
|
shifted = reduced_code.to(torch.int64).unsqueeze(1) >> bit_pos.unsqueeze(0)
|
||||||
|
bits = (shifted & 1).to(torch.bool) # (num_blocks, V)
|
||||||
|
|
||||||
|
positions = bit_pos.unsqueeze(0).expand(num_blocks, V)
|
||||||
|
|
||||||
|
valid_positions = positions[bits]
|
||||||
|
valid_kernel_idx = valid_positions.to(torch.int32).contiguous()
|
||||||
|
|
||||||
|
return valid_kernel_idx, seg
|
||||||
|
|
||||||
|
|
||||||
|
def sparse_submanifold_conv3d(feats, coords, shape, weight, bias, neighbor_cache, dilation):
|
||||||
|
if len(shape) == 5:
|
||||||
|
N, C, W, H, D = shape
|
||||||
|
else:
|
||||||
|
W, H, D = shape
|
||||||
|
|
||||||
|
Co, Kw, Kh, Kd, Ci = weight.shape
|
||||||
|
|
||||||
|
b_stride = W * H * D
|
||||||
|
x_stride = H * D
|
||||||
|
y_stride = D
|
||||||
|
z_stride = 1
|
||||||
|
|
||||||
|
flat_keys = (coords[:, 0].long() * b_stride +
|
||||||
|
coords[:, 1].long() * x_stride +
|
||||||
|
coords[:, 2].long() * y_stride +
|
||||||
|
coords[:, 3].long() * z_stride)
|
||||||
|
|
||||||
|
vals = torch.arange(coords.shape[0], dtype=torch.int32, device=coords.device)
|
||||||
|
|
||||||
|
hashmap = TorchHashMap(flat_keys, vals, 0xFFFFFFFF)
|
||||||
|
|
||||||
|
if neighbor_cache is None:
|
||||||
|
neighbor = build_submanifold_neighbor_map(
|
||||||
|
hashmap, coords, W, H, D, Kw, Kh, Kd,
|
||||||
|
dilation[0], dilation[1], dilation[2]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
neighbor = neighbor_cache
|
||||||
|
|
||||||
|
block_size = 128
|
||||||
|
|
||||||
|
gray_code, sorted_idx, valid_signal_i, valid_signal_o, valid_signal_seg = \
|
||||||
|
neighbor_map_post_process_for_masked_implicit_gemm_1(neighbor)
|
||||||
|
|
||||||
|
valid_kernel, valid_kernel_seg = \
|
||||||
|
neighbor_map_post_process_for_masked_implicit_gemm_2(gray_code, sorted_idx, block_size)
|
||||||
|
|
||||||
|
valid_kernel_fn = lambda b_size: valid_kernel
|
||||||
|
valid_kernel_seg_fn = lambda b_size: valid_kernel_seg
|
||||||
|
|
||||||
|
weight_flat = weight.contiguous().view(Co, -1, Ci)
|
||||||
|
|
||||||
|
out = sparse_submanifold_conv_fwd_masked_implicit_gemm_splitk(
|
||||||
|
feats,
|
||||||
|
weight_flat,
|
||||||
|
bias,
|
||||||
|
neighbor,
|
||||||
|
sorted_idx,
|
||||||
|
valid_kernel_fn,
|
||||||
|
valid_kernel_seg_fn
|
||||||
|
)
|
||||||
|
|
||||||
|
return out, neighbor
|
||||||
|
|
||||||
class Voxel:
|
class Voxel:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -408,7 +408,7 @@ class SLatFlowModel(nn.Module):
|
|||||||
self.qk_rms_norm_cross = qk_rms_norm_cross
|
self.qk_rms_norm_cross = qk_rms_norm_cross
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
|
|
||||||
self.t_embedder = TimestepEmbedder(model_channels)
|
self.t_embedder = TimestepEmbedder(model_channels, device=device, dtype=dtype, operations=operations)
|
||||||
if share_mod:
|
if share_mod:
|
||||||
self.adaLN_modulation = nn.Sequential(
|
self.adaLN_modulation = nn.Sequential(
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
@ -485,15 +485,25 @@ class Trellis2(nn.Module):
|
|||||||
qk_rms_norm = True,
|
qk_rms_norm = True,
|
||||||
qk_rms_norm_cross = True,
|
qk_rms_norm_cross = True,
|
||||||
dtype=None, device=None, operations=None):
|
dtype=None, device=None, operations=None):
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
args = {
|
args = {
|
||||||
"out_channels":out_channels, "num_blocks":num_blocks, "cond_channels" :cond_channels,
|
"out_channels":out_channels, "num_blocks":num_blocks, "cond_channels" :cond_channels,
|
||||||
"model_channels":model_channels, "num_heads":num_heads, "mlp_ratio": mlp_ratio, "share_mod": share_mod,
|
"model_channels":model_channels, "num_heads":num_heads, "mlp_ratio": mlp_ratio, "share_mod": share_mod,
|
||||||
"qk_rms_norm": qk_rms_norm, "qk_rms_norm_cross": qk_rms_norm_cross, "device": device, "dtype": dtype, "operations": operations
|
"qk_rms_norm": qk_rms_norm, "qk_rms_norm_cross": qk_rms_norm_cross, "device": device, "dtype": dtype, "operations": operations
|
||||||
}
|
}
|
||||||
# TODO: update the names/checkpoints
|
# TODO: update the names/checkpoints
|
||||||
self.img2shape = SLatFlowModel(resolution, in_channels=in_channels, *args)
|
self.img2shape = SLatFlowModel(resolution=resolution, in_channels=in_channels, **args)
|
||||||
self.shape2txt = SLatFlowModel(resolution, in_channels=in_channels*2, *args)
|
self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args)
|
||||||
self.shape_generation = True
|
self.shape_generation = True
|
||||||
|
|
||||||
def forward(self, x, timestep, context):
|
def forward(self, x, timestep, context, **kwargs):
|
||||||
pass
|
# TODO add mode
|
||||||
|
mode = kwargs.get("mode", "shape_generation")
|
||||||
|
mode = "texture_generation" if mode == 1 else "shape_generation"
|
||||||
|
if mode == "shape_generation":
|
||||||
|
out = self.img2shape(x, timestep, context)
|
||||||
|
if mode == "texture_generation":
|
||||||
|
out = self.shape2txt(x, timestep, context)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
import math
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from typing import List, Any, Dict, Optional, overload, Union, Tuple
|
from typing import List, Any, Dict, Optional, overload, Union, Tuple
|
||||||
@ -5,12 +6,219 @@ from fractions import Fraction
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from cumesh import TorchHashMap, Mesh, MeshWithVoxel
|
from cumesh import TorchHashMap, Mesh, MeshWithVoxel, sparse_submanifold_conv3d
|
||||||
|
|
||||||
|
|
||||||
|
class SparseConv3d(nn.Module):
|
||||||
|
def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
|
||||||
|
super(SparseConv3d, self).__init__()
|
||||||
|
sparse_conv3d_init(self, in_channels, out_channels, kernel_size, stride, dilation, padding, bias, indice_key)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return sparse_conv3d_forward(self, x)
|
||||||
|
|
||||||
|
|
||||||
|
def sparse_conv3d_init(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
|
||||||
|
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.kernel_size = tuple(kernel_size) if isinstance(kernel_size, (list, tuple)) else (kernel_size, ) * 3
|
||||||
|
self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, ) * 3
|
||||||
|
self.dilation = tuple(dilation) if isinstance(dilation, (list, tuple)) else (dilation, ) * 3
|
||||||
|
|
||||||
|
self.weight = nn.Parameter(torch.empty((out_channels, in_channels, *self.kernel_size)))
|
||||||
|
if bias:
|
||||||
|
self.bias = nn.Parameter(torch.empty(out_channels))
|
||||||
|
else:
|
||||||
|
self.register_parameter("bias", None)
|
||||||
|
|
||||||
|
if self.bias is not None:
|
||||||
|
fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
|
||||||
|
if fan_in != 0:
|
||||||
|
bound = 1 / math.sqrt(fan_in)
|
||||||
|
torch.nn.init.uniform_(self.bias, -bound, bound)
|
||||||
|
|
||||||
|
# Permute weight (Co, Ci, Kd, Kh, Kw) -> (Co, Kd, Kh, Kw, Ci)
|
||||||
|
self.weight = nn.Parameter(self.weight.permute(0, 2, 3, 4, 1).contiguous())
|
||||||
|
|
||||||
|
|
||||||
|
def sparse_conv3d_forward(self, x):
|
||||||
|
# check if neighbor map is already computed
|
||||||
|
Co, Kd, Kh, Kw, Ci = self.weight.shape
|
||||||
|
neighbor_cache_key = f'SubMConv3d_neighbor_cache_{Kw}x{Kh}x{Kd}_dilation{self.dilation}'
|
||||||
|
neighbor_cache = x.get_spatial_cache(neighbor_cache_key)
|
||||||
|
|
||||||
|
out, neighbor_cache_ = sparse_submanifold_conv3d(
|
||||||
|
x.feats,
|
||||||
|
x.coords,
|
||||||
|
torch.Size([*x.shape, *x.spatial_shape]),
|
||||||
|
self.weight,
|
||||||
|
self.bias,
|
||||||
|
neighbor_cache,
|
||||||
|
self.dilation
|
||||||
|
)
|
||||||
|
|
||||||
|
if neighbor_cache is None:
|
||||||
|
x.register_spatial_cache(neighbor_cache_key, neighbor_cache_)
|
||||||
|
|
||||||
|
out = x.replace(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
class LayerNorm32(nn.LayerNorm):
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
x_dtype = x.dtype
|
||||||
|
x = x.to(torch.float32)
|
||||||
|
o = super().forward(x)
|
||||||
|
return o.to(x_dtype)
|
||||||
|
|
||||||
|
class SparseConvNeXtBlock3d(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channels: int,
|
||||||
|
mlp_ratio: float = 4.0,
|
||||||
|
use_checkpoint: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.channels = channels
|
||||||
|
self.use_checkpoint = use_checkpoint
|
||||||
|
|
||||||
|
self.norm = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
|
||||||
|
self.conv = SparseConv3d(channels, channels, 3)
|
||||||
|
self.mlp = nn.Sequential(
|
||||||
|
nn.Linear(channels, int(channels * mlp_ratio)),
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(int(channels * mlp_ratio), channels),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _forward(self, x):
|
||||||
|
h = self.conv(x)
|
||||||
|
h = h.replace(self.norm(h.feats))
|
||||||
|
h = h.replace(self.mlp(h.feats))
|
||||||
|
return h + x
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self._forward(x)
|
||||||
|
|
||||||
|
class SparseSpatial2Channel(nn.Module):
|
||||||
|
def __init__(self, factor: int = 2):
|
||||||
|
super(SparseSpatial2Channel, self).__init__()
|
||||||
|
self.factor = factor
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
DIM = x.coords.shape[-1] - 1
|
||||||
|
cache = x.get_spatial_cache(f'spatial2channel_{self.factor}')
|
||||||
|
if cache is None:
|
||||||
|
coord = list(x.coords.unbind(dim=-1))
|
||||||
|
for i in range(DIM):
|
||||||
|
coord[i+1] = coord[i+1] // self.factor
|
||||||
|
subidx = x.coords[:, 1:] % self.factor
|
||||||
|
subidx = sum([subidx[..., i] * self.factor ** i for i in range(DIM)])
|
||||||
|
|
||||||
|
MAX = [(s + self.factor - 1) // self.factor for s in x.spatial_shape]
|
||||||
|
OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1]
|
||||||
|
code = sum([c * o for c, o in zip(coord, OFFSET)])
|
||||||
|
code, idx = code.unique(return_inverse=True)
|
||||||
|
|
||||||
|
new_coords = torch.stack(
|
||||||
|
[code // OFFSET[0]] +
|
||||||
|
[(code // OFFSET[i+1]) % MAX[i] for i in range(DIM)],
|
||||||
|
dim=-1
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
new_coords, idx, subidx = cache
|
||||||
|
|
||||||
|
new_feats = torch.zeros(new_coords.shape[0] * self.factor ** DIM, x.feats.shape[1], device=x.feats.device, dtype=x.feats.dtype)
|
||||||
|
new_feats[idx * self.factor ** DIM + subidx] = x.feats
|
||||||
|
|
||||||
|
out = SparseTensor(new_feats.reshape(new_coords.shape[0], -1), new_coords, None if x._shape is None else torch.Size([x._shape[0], x._shape[1] * self.factor ** DIM]))
|
||||||
|
out._scale = tuple([s * self.factor for s in x._scale])
|
||||||
|
out._spatial_cache = x._spatial_cache
|
||||||
|
|
||||||
|
if cache is None:
|
||||||
|
x.register_spatial_cache(f'spatial2channel_{self.factor}', (new_coords, idx, subidx))
|
||||||
|
out.register_spatial_cache(f'channel2spatial_{self.factor}', (x.coords, idx, subidx))
|
||||||
|
out.register_spatial_cache('shape', torch.Size(MAX))
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
class SparseChannel2Spatial(nn.Module):
|
||||||
|
def __init__(self, factor: int = 2):
|
||||||
|
super(SparseChannel2Spatial, self).__init__()
|
||||||
|
self.factor = factor
|
||||||
|
|
||||||
|
def forward(self, x, subdivision = None):
|
||||||
|
DIM = x.coords.shape[-1] - 1
|
||||||
|
|
||||||
|
cache = x.get_spatial_cache(f'channel2spatial_{self.factor}')
|
||||||
|
if cache is None:
|
||||||
|
if subdivision is None:
|
||||||
|
raise ValueError('Cache not found. Provide subdivision tensor or pair SparseChannel2Spatial with SparseSpatial2Channel.')
|
||||||
|
else:
|
||||||
|
sub = subdivision.feats # [N, self.factor ** DIM]
|
||||||
|
N_leaf = sub.sum(dim=-1) # [N]
|
||||||
|
subidx = sub.nonzero()[:, -1]
|
||||||
|
new_coords = x.coords.clone().detach()
|
||||||
|
new_coords[:, 1:] *= self.factor
|
||||||
|
new_coords = torch.repeat_interleave(new_coords, N_leaf, dim=0, output_size=subidx.shape[0])
|
||||||
|
for i in range(DIM):
|
||||||
|
new_coords[:, i+1] += subidx // self.factor ** i % self.factor
|
||||||
|
idx = torch.repeat_interleave(torch.arange(x.coords.shape[0], device=x.device), N_leaf, dim=0, output_size=subidx.shape[0])
|
||||||
|
else:
|
||||||
|
new_coords, idx, subidx = cache
|
||||||
|
|
||||||
|
x_feats = x.feats.reshape(x.feats.shape[0] * self.factor ** DIM, -1)
|
||||||
|
new_feats = x_feats[idx * self.factor ** DIM + subidx]
|
||||||
|
out = SparseTensor(new_feats, new_coords, None if x._shape is None else torch.Size([x._shape[0], x._shape[1] // self.factor ** DIM]))
|
||||||
|
out._scale = tuple([s / self.factor for s in x._scale])
|
||||||
|
if cache is not None: # only keep cache when subdiv following it
|
||||||
|
out._spatial_cache = x._spatial_cache
|
||||||
|
return out
|
||||||
|
|
||||||
|
class SparseResBlockC2S3d(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channels: int,
|
||||||
|
out_channels: Optional[int] = None,
|
||||||
|
use_checkpoint: bool = False,
|
||||||
|
pred_subdiv: bool = True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.channels = channels
|
||||||
|
self.out_channels = out_channels or channels
|
||||||
|
self.use_checkpoint = use_checkpoint
|
||||||
|
self.pred_subdiv = pred_subdiv
|
||||||
|
|
||||||
|
self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
|
||||||
|
self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6)
|
||||||
|
self.conv1 = SparseConv3d(channels, self.out_channels * 8, 3)
|
||||||
|
self.conv2 = SparseConv3d(self.out_channels, self.out_channels, 3)
|
||||||
|
self.skip_connection = lambda x: x.replace(x.feats.repeat_interleave(out_channels // (channels // 8), dim=1))
|
||||||
|
if pred_subdiv:
|
||||||
|
self.to_subdiv = SparseLinear(channels, 8)
|
||||||
|
self.updown = SparseChannel2Spatial(2)
|
||||||
|
|
||||||
|
def _forward(self, x, subdiv = None):
|
||||||
|
if self.pred_subdiv:
|
||||||
|
subdiv = self.to_subdiv(x)
|
||||||
|
h = x.replace(self.norm1(x.feats))
|
||||||
|
h = h.replace(F.silu(h.feats))
|
||||||
|
h = self.conv1(h)
|
||||||
|
subdiv_binarized = subdiv.replace(subdiv.feats > 0) if subdiv is not None else None
|
||||||
|
h = self.updown(h, subdiv_binarized)
|
||||||
|
x = self.updown(x, subdiv_binarized)
|
||||||
|
h = h.replace(self.norm2(h.feats))
|
||||||
|
h = h.replace(F.silu(h.feats))
|
||||||
|
h = self.conv2(h)
|
||||||
|
h = h + self.skip_connection(x)
|
||||||
|
if self.pred_subdiv:
|
||||||
|
return h, subdiv
|
||||||
|
else:
|
||||||
|
return h
|
||||||
|
|
||||||
# TODO: determine which conv they actually use
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class config:
|
class config:
|
||||||
CONV = "none"
|
CONV = "flexgemm"
|
||||||
|
FLEX_GEMM_HASHMAP_RATIO = 2.0
|
||||||
|
|
||||||
# TODO post processing
|
# TODO post processing
|
||||||
def simplify(self, target_num_faces: int, verbose: bool=False, options: dict={}):
|
def simplify(self, target_num_faces: int, verbose: bool=False, options: dict={}):
|
||||||
@ -1131,6 +1339,7 @@ def flexible_dual_grid_to_mesh(
|
|||||||
|
|
||||||
class Vae(nn.Module):
|
class Vae(nn.Module):
|
||||||
def __init__(self, config, operations=None):
|
def __init__(self, config, operations=None):
|
||||||
|
super().__init__()
|
||||||
operations = operations or torch.nn
|
operations = operations or torch.nn
|
||||||
|
|
||||||
self.txt_dec = SparseUnetVaeDecoder(
|
self.txt_dec = SparseUnetVaeDecoder(
|
||||||
@ -1139,7 +1348,8 @@ class Vae(nn.Module):
|
|||||||
latent_channels=32,
|
latent_channels=32,
|
||||||
num_blocks=[4, 16, 8, 4, 0],
|
num_blocks=[4, 16, 8, 4, 0],
|
||||||
block_type=["SparseConvNeXtBlock3d"] * 5,
|
block_type=["SparseConvNeXtBlock3d"] * 5,
|
||||||
up_block_type=["SparseResBlockS2C3d"] * 4,
|
up_block_type=["SparseResBlockC2S3d"] * 4,
|
||||||
|
block_args=[{}, {}, {}, {}, {}],
|
||||||
pred_subdiv=False
|
pred_subdiv=False
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1149,7 +1359,8 @@ class Vae(nn.Module):
|
|||||||
latent_channels=32,
|
latent_channels=32,
|
||||||
num_blocks=[4, 16, 8, 4, 0],
|
num_blocks=[4, 16, 8, 4, 0],
|
||||||
block_type=["SparseConvNeXtBlock3d"] * 5,
|
block_type=["SparseConvNeXtBlock3d"] * 5,
|
||||||
up_block_type=["SparseResBlockS2C3d"] * 4,
|
up_block_type=["SparseResBlockC2S3d"] * 4,
|
||||||
|
block_args=[{}, {}, {}, {}, {}],
|
||||||
)
|
)
|
||||||
|
|
||||||
def decode_shape_slat(self, slat, resolution: int):
|
def decode_shape_slat(self, slat, resolution: int):
|
||||||
|
|||||||
@ -50,6 +50,7 @@ import comfy.ldm.omnigen.omnigen2
|
|||||||
import comfy.ldm.qwen_image.model
|
import comfy.ldm.qwen_image.model
|
||||||
import comfy.ldm.kandinsky5.model
|
import comfy.ldm.kandinsky5.model
|
||||||
import comfy.ldm.anima.model
|
import comfy.ldm.anima.model
|
||||||
|
import comfy.ldm.trellis2.model
|
||||||
|
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.patcher_extension
|
import comfy.patcher_extension
|
||||||
@ -1455,6 +1456,13 @@ class WAN22(WAN21):
|
|||||||
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
||||||
return latent_image
|
return latent_image
|
||||||
|
|
||||||
|
class Trellis2(BaseModel):
|
||||||
|
def __init__(self, model_config, model_type=ModelType.FLOW, device=None, unet_model=comfy.ldm.trellis2.model.Trellis2):
|
||||||
|
super().__init__(model_config, model_type, device, unet_model)
|
||||||
|
|
||||||
|
def extra_conds(self, **kwargs):
|
||||||
|
return super().extra_conds(**kwargs)
|
||||||
|
|
||||||
class Hunyuan3Dv2(BaseModel):
|
class Hunyuan3Dv2(BaseModel):
|
||||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)
|
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)
|
||||||
|
|||||||
@ -1242,6 +1242,17 @@ class WAN22_T2V(WAN21_T2V):
|
|||||||
out = model_base.WAN22(self, image_to_video=True, device=device)
|
out = model_base.WAN22(self, image_to_video=True, device=device)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
class Trellis2(supported_models_base.BASE):
|
||||||
|
unet_config = {
|
||||||
|
"image_model": "trellis2"
|
||||||
|
}
|
||||||
|
|
||||||
|
latent_format = latent_formats.Trellis2
|
||||||
|
vae_key_prefix = ["vae."]
|
||||||
|
|
||||||
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
|
return model_base.Trellis2(self, device=device)
|
||||||
|
|
||||||
class Hunyuan3Dv2(supported_models_base.BASE):
|
class Hunyuan3Dv2(supported_models_base.BASE):
|
||||||
unet_config = {
|
unet_config = {
|
||||||
"image_model": "hunyuan3d2",
|
"image_model": "hunyuan3d2",
|
||||||
@ -1596,6 +1607,6 @@ class Kandinsky5Image(Kandinsky5):
|
|||||||
return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
|
return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
|
||||||
|
|
||||||
|
|
||||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
|
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, Trellis2]
|
||||||
|
|
||||||
models += [SVD_img2vid]
|
models += [SVD_img2vid]
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user