mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 04:52:31 +08:00
model init working
This commit is contained in:
parent
f76e3a11b5
commit
d6573fd26d
@ -746,6 +746,8 @@ class Hunyuan3Dv2_1(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 1
|
||||
|
||||
class Trellis2(LatentFormat): # TODO
|
||||
latent_channels = 32
|
||||
class Hunyuan3Dv2mini(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 1
|
||||
|
||||
@ -1,8 +1,192 @@
|
||||
# will contain every cuda -> pytorch operation
|
||||
|
||||
import math
|
||||
import torch
|
||||
from typing import Dict
|
||||
from typing import Dict, Callable
|
||||
|
||||
# Flag consulted by callers to decide whether the triton fast path exists.
# NOTE(review): name is misspelled ("TRITION" [sic]) but kept as-is since other
# modules may already reference it.
NO_TRITION = False
try:
    import triton
    import triton.language as tl

    # triton.heuristics callbacks: resolve the valid-kernel index tensors by
    # invoking the user-supplied callables with the chosen block size B1.
    heuristics = {
        'valid_kernel': lambda args: args['valid_kernel'](args['B1']),
        'valid_kernel_seg': lambda args: args['valid_kernel_seg'](args['B1']),
    }

    #@triton_autotune(
    #    configs=config.autotune_config,
    #    key=['LOGN', 'Ci', 'Co', 'V', 'allow_tf32'],
    #)
    @triton.heuristics(heuristics)
    @triton.jit
    def sparse_submanifold_conv_fwd_masked_implicit_gemm_kernel(
        input,
        weight,
        bias,
        neighbor,
        sorted_idx,
        output,
        # Tensor dimensions
        N, LOGN, Ci, Co, V: tl.constexpr,
        # Meta-parameters
        B1: tl.constexpr,  # Block size for N dimension
        B2: tl.constexpr,  # Block size for Co dimension
        BK: tl.constexpr,  # Block size for K dimension (V * Ci)
        allow_tf32: tl.constexpr,  # Allow TF32 precision for matmuls
        # Heuristic parameters
        valid_kernel,
        valid_kernel_seg,
    ):
        # Masked implicit-GEMM forward for submanifold sparse convolution.
        # Each program computes a (B1, B2) tile of the (N, Co) output,
        # iterating only over the kernel taps listed in `valid_kernel` for its
        # row block (rows are pre-sorted so blocks share tap masks).
        block_id = tl.program_id(axis=0)
        block_dim_co = tl.cdiv(Co, B2)
        block_id_co = block_id % block_dim_co
        block_id_n = block_id // block_dim_co

        # Create pointers for submatrices of A and B.
        num_k = tl.cdiv(Ci, BK)  # Number of blocks in K dimension
        valid_kernel_start = tl.load(valid_kernel_seg + block_id_n)
        valid_kernel_seglen = tl.load(valid_kernel_seg + block_id_n + 1) - valid_kernel_start
        offset_n = block_id_n * B1 + tl.arange(0, B1)
        n_mask = offset_n < N
        offset_sorted_n = tl.load(sorted_idx + offset_n, mask=n_mask, other=0)  # (B1,)
        offset_co = (block_id_co * B2 + tl.arange(0, B2)) % Co  # (B2,)
        offset_k = tl.arange(0, BK)  # (BK,)

        # Create a block of the output matrix C.
        accumulator = tl.zeros((B1, B2), dtype=tl.float32)

        # Iterate along V*Ci dimension (only the valid taps of this block).
        for k in range(num_k * valid_kernel_seglen):
            v = k // num_k
            bk = k % num_k
            v = tl.load(valid_kernel + valid_kernel_start + v)
            # Calculate pointers to input matrix.
            neighbor_offset_n = tl.load(neighbor + offset_sorted_n * V + v)  # (B1,)
            input_ptr = input + bk * BK + (neighbor_offset_n[:, None].to(tl.int64) * Ci + offset_k[None, :])  # (B1, BK)
            # Calculate pointers to weight matrix.
            weight_ptr = weight + v * Ci + bk * BK + (offset_co[None, :] * V * Ci + offset_k[:, None])  # (BK, B2)
            # Load the next block of input and weight; 0xffffffff marks
            # "no neighbor" for this tap, loaded as 0 contribution.
            neigh_mask = neighbor_offset_n != 0xffffffff
            k_mask = offset_k < Ci - bk * BK
            input_block = tl.load(input_ptr, mask=neigh_mask[:, None] & k_mask[None, :], other=0.0)
            weight_block = tl.load(weight_ptr, mask=k_mask[:, None], other=0.0)
            # Accumulate along the K dimension.
            accumulator = tl.dot(input_block, weight_block, accumulator,
                                 input_precision='tf32' if allow_tf32 else 'ieee')  # (B1, B2)
        c = accumulator.to(input.type.element_ty)

        # add bias
        if bias is not None:
            bias_block = tl.load(bias + offset_co)
            c += bias_block[None, :]

        # Write back the block of the output matrix with masks.
        out_offset_n = offset_sorted_n
        out_offset_co = block_id_co * B2 + tl.arange(0, B2)
        out_ptr = output + (out_offset_n[:, None] * Co + out_offset_co[None, :])
        out_mask = n_mask[:, None] & (out_offset_co[None, :] < Co)
        tl.store(out_ptr, c, mask=out_mask)

    def sparse_submanifold_conv_fwd_masked_implicit_gemm_splitk(
        input: torch.Tensor,
        weight: torch.Tensor,
        bias: torch.Tensor,
        neighbor: torch.Tensor,
        sorted_idx: torch.Tensor,
        valid_kernel: Callable[[int], torch.Tensor],
        valid_kernel_seg: Callable[[int], torch.Tensor],
    ) -> torch.Tensor:
        """Launch the masked implicit-GEMM kernel and return (N, Co) output.

        `weight` is the flattened (Co, V, Ci) layout; `valid_kernel` /
        `valid_kernel_seg` are callables of the launch block size (resolved by
        the triton heuristics above).
        """
        N, Ci, Co, V = neighbor.shape[0], input.shape[1], weight.shape[0], weight.shape[1]
        LOGN = int(math.log2(N))
        output = torch.empty((N, Co), device=input.device, dtype=input.dtype)
        grid = lambda META: (triton.cdiv(Co, META['B2']) * triton.cdiv(N, META['B1']),)
        sparse_submanifold_conv_fwd_masked_implicit_gemm_kernel[grid](
            input, weight, bias, neighbor, sorted_idx, output,
            N, LOGN, Ci, Co, V,  #
            valid_kernel=valid_kernel,
            valid_kernel_seg=valid_kernel_seg,
            # NOTE(review): torch.cuda.is_tf32_supported requires a recent
            # torch release — confirm against the project's minimum version.
            allow_tf32=torch.cuda.is_tf32_supported(),
        )
        return output
except Exception:
    # Was a bare `except:`, which would also swallow KeyboardInterrupt and
    # SystemExit; any import/JIT failure still falls back to the torch path.
    NO_TRITION = True
|
||||
|
||||
def compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device):
    """Enumerate dilated kernel-tap offsets, x-major then y then z.

    The nesting order (vx outermost, vz innermost) mirrors the CUDA kernel's
    tap enumeration so tap indices agree between implementations.
    """
    taps = [
        (vx * Dw, vy * Dh, vz * Dd)
        for vx in range(Kw)
        for vy in range(Kh)
        for vz in range(Kd)
    ]
    return torch.tensor(taps, device=device)
|
||||
|
||||
def build_submanifold_neighbor_map(
    hashmap,
    coords: torch.Tensor,
    W, H, D,
    Kw, Kh, Kd,
    Dw, Dh, Dd,
):
    """Build the (M, V) neighbor table for a submanifold convolution.

    neighbor[m, v] holds the row index of the voxel reached from voxel m via
    kernel tap v (with dilation (Dw, Dh, Dd)), or ``hashmap.default_value``
    when that voxel is absent. Only the first half of the taps plus the center
    is probed; the mirrored half is filled by symmetry: if n is m's neighbor
    through tap v, then m is n's neighbor through tap V-1-v.

    Args:
        hashmap: voxel lookup exposing ``default_value`` and ``lookup_flat``
            (flattened (b, x, y, z) key -> row index).
        coords: (M, 4) integer coordinates as (batch, x, y, z).
        W, H, D: spatial extents; Kw, Kh, Kd: kernel size; Dw, Dh, Dd: dilation.
    """
    device = coords.device
    M = coords.shape[0]
    V = Kw * Kh * Kd
    half_V = V // 2 + 1  # first half of the taps, center included

    INVALID = hashmap.default_value

    neighbor = torch.full((M, V), INVALID, device=device, dtype=torch.long)

    b = coords[:, 0]
    x = coords[:, 1]
    y = coords[:, 2]
    z = coords[:, 3]

    offsets = compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device)

    # Window origin per voxel (kernel centered on the voxel).
    # BUGFIX: these were built as (M, 1) tensors and indexed `ox[:, v]`, which
    # raised IndexError for every tap v >= 1 — the origin does not depend on v.
    ox = x - (Kw // 2) * Dw
    oy = y - (Kh // 2) * Dh
    oz = z - (Kd // 2) * Dd

    for v in range(half_V):
        if v == half_V - 1:
            # Center tap: every voxel is its own neighbor.
            neighbor[:, v] = torch.arange(M, device=device)
            continue

        dx, dy, dz = offsets[v]

        kx = ox + dx
        ky = oy + dy
        kz = oz + dz

        # Discard taps that fall outside the volume.
        valid = (
            (kx >= 0) & (kx < W) &
            (ky >= 0) & (ky < H) &
            (kz >= 0) & (kz < D)
        )

        # Row-major flattened key matching the hashmap's key layout.
        flat = (
            b * (W * H * D) +
            kx * (H * D) +
            ky * D +
            kz
        )

        flat = flat[valid]
        idx = torch.nonzero(valid, as_tuple=False).squeeze(1)

        found = hashmap.lookup_flat(flat)

        neighbor[idx, v] = found

        # Symmetric write for the mirrored tap.
        valid_found = found != INVALID
        neighbor[found[valid_found], V - 1 - v] = idx[valid_found]

    return neighbor
|
||||
|
||||
class TorchHashMap:
|
||||
def __init__(self, keys: torch.Tensor, values: torch.Tensor, default_value: int):
|
||||
@ -22,6 +206,207 @@ class TorchHashMap:
|
||||
out[found] = self.sorted_vals[idx[found]]
|
||||
return out
|
||||
|
||||
|
||||
UINT32_SENTINEL = 0xFFFFFFFF


def neighbor_map_post_process_for_masked_implicit_gemm_1(neighbor_map):
    """Stage 1 of neighbor-map post-processing for the masked implicit GEMM.

    Returns:
        gray_code: int32 [N] — per row, a bitmask over the V taps with bit v
            set when the row has a valid neighbor for tap v.
        sorted_idx: long [N] — row order after sorting by the binary decoding
            of the Gray code, grouping rows with similar tap masks.
        valid_signal_i / valid_signal_o: long [K] — paired (row, neighbor)
            indices of every valid signal, laid out tap-major.
        seg: long [V+1] — prefix offsets of the valid signals per tap.
    """
    device = neighbor_map.device
    N, V = neighbor_map.shape

    table = neighbor_map.to(torch.long)
    sentinel = torch.tensor(UINT32_SENTINEL, dtype=torch.long, device=device)

    # Tap-major flattening: column v of the table occupies rows [v*N, (v+1)*N).
    tap_major = table.t().reshape(-1)
    tap_major_valid = (tap_major != sentinel).to(torch.int32)

    # Bitmask of valid taps per row, interpreted as a Gray code.
    row_valid = (table != sentinel).to(torch.long)
    weights = 1 << torch.arange(V, dtype=torch.long, device=device)
    gray_long = (row_valid * weights).sum(dim=1)
    gray_code = gray_long.to(torch.int32)

    # Gray -> binary decode: b = g ^ (g >> 1) ^ (g >> 2) ^ ...
    binary_long = gray_long.clone()
    for shift in range(1, V):
        binary_long ^= gray_long >> shift
    sorted_idx = torch.argsort(binary_long.to(torch.int32))

    # Running count of valid signals over the tap-major mask.
    running = torch.cumsum(tap_major_valid.to(torch.int32), dim=0)  # (V*N,)
    total = int(running[-1].item()) if running.numel() > 0 else 0

    if total > 0:
        valid_signal_i = torch.empty((total,), dtype=torch.long, device=device)
        valid_signal_o = torch.empty((total,), dtype=torch.long, device=device)

        where = torch.nonzero(tap_major_valid, as_tuple=True)[0]
        slot = (running[where] - 1).to(torch.long)

        valid_signal_i[slot] = (where % N).to(torch.long)
        valid_signal_o[slot] = tap_major[where].to(torch.long)
    else:
        valid_signal_i = torch.empty((0,), dtype=torch.long, device=device)
        valid_signal_o = torch.empty((0,), dtype=torch.long, device=device)

    # Per-tap segment boundaries: seg[v+1] = number of valid signals in taps <= v.
    seg = torch.empty((V + 1,), dtype=torch.long, device=device)
    seg[0] = 0
    if V > 0:
        ends = (torch.arange(1, V + 1, device=device, dtype=torch.long) * N) - 1
        seg[1:] = running[ends].to(torch.long)

    return gray_code, sorted_idx, valid_signal_i, valid_signal_o, seg
|
||||
|
||||
def _popcount_int32_tensor(x: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
x = x.to(torch.int64)
|
||||
|
||||
m1 = torch.tensor(0x5555555555555555, dtype=torch.int64, device=x.device)
|
||||
m2 = torch.tensor(0x3333333333333333, dtype=torch.int64, device=x.device)
|
||||
m4 = torch.tensor(0x0F0F0F0F0F0F0F0F, dtype=torch.int64, device=x.device)
|
||||
h01 = torch.tensor(0x0101010101010101, dtype=torch.int64, device=x.device)
|
||||
|
||||
x = x - ((x >> 1) & m1)
|
||||
x = (x & m2) + ((x >> 2) & m2)
|
||||
x = (x + (x >> 4)) & m4
|
||||
x = (x * h01) >> 56
|
||||
return x.to(torch.int32)
|
||||
|
||||
|
||||
def neighbor_map_post_process_for_masked_implicit_gemm_2(
|
||||
gray_code: torch.Tensor, # [N], int32-like (non-negative)
|
||||
sorted_idx: torch.Tensor, # [N], long (indexing into gray_code)
|
||||
block_size: int
|
||||
):
|
||||
device = gray_code.device
|
||||
N = gray_code.numel()
|
||||
|
||||
# num of blocks (same as CUDA)
|
||||
num_blocks = (N + block_size - 1) // block_size
|
||||
|
||||
# Ensure dtypes
|
||||
gray_long = gray_code.to(torch.int64) # safer to OR in 64-bit then cast
|
||||
sorted_idx = sorted_idx.to(torch.long)
|
||||
|
||||
# 1) Group gray_code by blocks and compute OR across each block
|
||||
# pad the last block with zeros if necessary so we can reshape
|
||||
pad = num_blocks * block_size - N
|
||||
if pad > 0:
|
||||
pad_vals = torch.zeros((pad,), dtype=torch.int64, device=device)
|
||||
gray_padded = torch.cat([gray_long[sorted_idx], pad_vals], dim=0)
|
||||
else:
|
||||
gray_padded = gray_long[sorted_idx]
|
||||
|
||||
# reshape to (num_blocks, block_size) and compute bitwise_or across dim=1
|
||||
gray_blocks = gray_padded.view(num_blocks, block_size) # each row = block entries
|
||||
# reduce with bitwise_or
|
||||
reduced_code = gray_blocks[:, 0].clone()
|
||||
for i in range(1, block_size):
|
||||
reduced_code |= gray_blocks[:, i]
|
||||
reduced_code = reduced_code.to(torch.int32) # match CUDA int32
|
||||
|
||||
# 2) compute seglen (popcount per reduced_code) and seg (prefix sum)
|
||||
seglen_counts = _popcount_int32_tensor(reduced_code.to(torch.int64)).to(torch.int32) # [num_blocks]
|
||||
# seg: length num_blocks+1, seg[0]=0, seg[i+1]=cumsum(seglen_counts) up to i
|
||||
seg = torch.empty((num_blocks + 1,), dtype=torch.int32, device=device)
|
||||
seg[0] = 0
|
||||
if num_blocks > 0:
|
||||
seg[1:] = torch.cumsum(seglen_counts, dim=0)
|
||||
|
||||
total = int(seg[-1].item())
|
||||
|
||||
# 3) scatter — produce valid_kernel_idx as concatenated ascending set-bit positions for each reduced_code row
|
||||
if total == 0:
|
||||
valid_kernel_idx = torch.empty((0,), dtype=torch.int32, device=device)
|
||||
return valid_kernel_idx, seg
|
||||
|
||||
max_val = int(reduced_code.max().item())
|
||||
V = max_val.bit_length() if max_val > 0 else 0
|
||||
# If you know V externally, pass it instead or set here explicitly.
|
||||
|
||||
if V == 0:
|
||||
# no bits set anywhere
|
||||
valid_kernel_idx = torch.empty((0,), dtype=torch.int32, device=device)
|
||||
return valid_kernel_idx, seg
|
||||
|
||||
# build mask of shape (num_blocks, V): True where bit is set
|
||||
bit_pos = torch.arange(0, V, dtype=torch.int64, device=device) # [V]
|
||||
# shifted = reduced_code[:, None] >> bit_pos[None, :]
|
||||
shifted = reduced_code.to(torch.int64).unsqueeze(1) >> bit_pos.unsqueeze(0)
|
||||
bits = (shifted & 1).to(torch.bool) # (num_blocks, V)
|
||||
|
||||
positions = bit_pos.unsqueeze(0).expand(num_blocks, V)
|
||||
|
||||
valid_positions = positions[bits]
|
||||
valid_kernel_idx = valid_positions.to(torch.int32).contiguous()
|
||||
|
||||
return valid_kernel_idx, seg
|
||||
|
||||
|
||||
def sparse_submanifold_conv3d(feats, coords, shape, weight, bias, neighbor_cache, dilation):
    """Submanifold sparse 3D convolution (torch/triton path).

    Args:
        feats: (M, Ci) per-voxel features.
        coords: (M, 4) integer coords as (batch, x, y, z).
        shape: (N, C, W, H, D) or (W, H, D) spatial shape.
        weight: (Co, Kw, Kh, Kd, Ci) kernel (channels-last kernel layout).
        bias: (Co,) or None.
        neighbor_cache: previously computed (M, V) neighbor table, or None.
        dilation: 3-sequence of per-axis dilations.

    Returns:
        (out, neighbor): (M, Co) dense output and the neighbor table used
        (so callers can cache it for subsequent layers).
    """
    if len(shape) == 5:
        _, _, W, H, D = shape  # batch/channel dims are not needed here
    elif len(shape) == 3:
        W, H, D = shape
    else:
        # Was a bare 3-tuple unpack that failed with an opaque ValueError.
        raise ValueError(f"expected a 3- or 5-element shape, got {len(shape)}")

    Co, Kw, Kh, Kd, Ci = weight.shape

    # Row-major flattening strides for (b, x, y, z) keys.
    b_stride = W * H * D
    x_stride = H * D
    y_stride = D

    flat_keys = (coords[:, 0].long() * b_stride +
                 coords[:, 1].long() * x_stride +
                 coords[:, 2].long() * y_stride +
                 coords[:, 3].long())

    vals = torch.arange(coords.shape[0], dtype=torch.int32, device=coords.device)

    # 0xFFFFFFFF doubles as the "missing neighbor" sentinel throughout.
    hashmap = TorchHashMap(flat_keys, vals, 0xFFFFFFFF)

    if neighbor_cache is None:
        neighbor = build_submanifold_neighbor_map(
            hashmap, coords, W, H, D, Kw, Kh, Kd,
            dilation[0], dilation[1], dilation[2]
        )
    else:
        neighbor = neighbor_cache

    block_size = 128  # rows per implicit-GEMM block; must match kernel B1 heuristics

    gray_code, sorted_idx, valid_signal_i, valid_signal_o, valid_signal_seg = \
        neighbor_map_post_process_for_masked_implicit_gemm_1(neighbor)

    valid_kernel, valid_kernel_seg = \
        neighbor_map_post_process_for_masked_implicit_gemm_2(gray_code, sorted_idx, block_size)

    # The triton heuristics expect callables of the launch block size.
    valid_kernel_fn = lambda b_size: valid_kernel
    valid_kernel_seg_fn = lambda b_size: valid_kernel_seg

    # (Co, Kw, Kh, Kd, Ci) -> (Co, V, Ci)
    weight_flat = weight.contiguous().view(Co, -1, Ci)

    out = sparse_submanifold_conv_fwd_masked_implicit_gemm_splitk(
        feats,
        weight_flat,
        bias,
        neighbor,
        sorted_idx,
        valid_kernel_fn,
        valid_kernel_seg_fn
    )

    return out, neighbor
|
||||
|
||||
class Voxel:
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@ -408,7 +408,7 @@ class SLatFlowModel(nn.Module):
|
||||
self.qk_rms_norm_cross = qk_rms_norm_cross
|
||||
self.dtype = dtype
|
||||
|
||||
self.t_embedder = TimestepEmbedder(model_channels)
|
||||
self.t_embedder = TimestepEmbedder(model_channels, device=device, dtype=dtype, operations=operations)
|
||||
if share_mod:
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
@ -485,15 +485,25 @@ class Trellis2(nn.Module):
|
||||
qk_rms_norm = True,
|
||||
qk_rms_norm_cross = True,
|
||||
dtype=None, device=None, operations=None):
|
||||
|
||||
super().__init__()
|
||||
args = {
|
||||
"out_channels":out_channels, "num_blocks":num_blocks, "cond_channels" :cond_channels,
|
||||
"model_channels":model_channels, "num_heads":num_heads, "mlp_ratio": mlp_ratio, "share_mod": share_mod,
|
||||
"qk_rms_norm": qk_rms_norm, "qk_rms_norm_cross": qk_rms_norm_cross, "device": device, "dtype": dtype, "operations": operations
|
||||
}
|
||||
# TODO: update the names/checkpoints
|
||||
self.img2shape = SLatFlowModel(resolution, in_channels=in_channels, *args)
|
||||
self.shape2txt = SLatFlowModel(resolution, in_channels=in_channels*2, *args)
|
||||
self.img2shape = SLatFlowModel(resolution=resolution, in_channels=in_channels, **args)
|
||||
self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args)
|
||||
self.shape_generation = True
|
||||
|
||||
def forward(self, x, timestep, context):
|
||||
pass
|
||||
def forward(self, x, timestep, context, **kwargs):
|
||||
# TODO add mode
|
||||
mode = kwargs.get("mode", "shape_generation")
|
||||
mode = "texture_generation" if mode == 1 else "shape_generation"
|
||||
if mode == "shape_generation":
|
||||
out = self.img2shape(x, timestep, context)
|
||||
if mode == "texture_generation":
|
||||
out = self.shape2txt(x, timestep, context)
|
||||
|
||||
return out
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import List, Any, Dict, Optional, overload, Union, Tuple
|
||||
@ -5,12 +6,219 @@ from fractions import Fraction
|
||||
import torch.nn.functional as F
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
from cumesh import TorchHashMap, Mesh, MeshWithVoxel
|
||||
from cumesh import TorchHashMap, Mesh, MeshWithVoxel, sparse_submanifold_conv3d
|
||||
|
||||
|
||||
class SparseConv3d(nn.Module):
    """Submanifold sparse 3D convolution; thin module wrapper over the
    functional helpers ``sparse_conv3d_init`` / ``sparse_conv3d_forward``."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
        super().__init__()
        sparse_conv3d_init(self, in_channels, out_channels, kernel_size,
                           stride, dilation, padding, bias, indice_key)

    def forward(self, x):
        return sparse_conv3d_forward(self, x)
|
||||
|
||||
|
||||
def sparse_conv3d_init(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
    """Populate `self` with SparseConv3d hyper-parameters and parameters.

    NOTE(review): the weight is allocated with torch.empty and never filled
    here — presumably checkpoint loading overwrites it; confirm before
    training from scratch. Only the bias gets the standard fan-in init.
    """
    def _triple(v):
        # Accept a scalar or a 3-sequence for size/stride/dilation.
        return tuple(v) if isinstance(v, (list, tuple)) else (v, ) * 3

    self.in_channels = in_channels
    self.out_channels = out_channels
    self.kernel_size = _triple(kernel_size)
    self.stride = _triple(stride)
    self.dilation = _triple(dilation)

    self.weight = nn.Parameter(torch.empty((out_channels, in_channels, *self.kernel_size)))
    if bias:
        self.bias = nn.Parameter(torch.empty(out_channels))
    else:
        self.register_parameter("bias", None)

    if self.bias is not None:
        # Standard uniform(-1/sqrt(fan_in), 1/sqrt(fan_in)) bias init.
        fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
        if fan_in != 0:
            bound = 1 / math.sqrt(fan_in)
            torch.nn.init.uniform_(self.bias, -bound, bound)

    # Store the kernel channels-last: (Co, Ci, K1, K2, K3) -> (Co, K1, K2, K3, Ci).
    self.weight = nn.Parameter(self.weight.permute(0, 2, 3, 4, 1).contiguous())
|
||||
|
||||
|
||||
def sparse_conv3d_forward(self, x):
    # Functional forward for SparseConv3d.
    # x: sparse tensor exposing .feats (M, Ci), .coords (M, 4), .shape,
    # .spatial_shape and the get/register_spatial_cache API — presumably the
    # project's SparseTensor; confirm.
    # check if neighbor map is already computed
    Co, Kd, Kh, Kw, Ci = self.weight.shape
    neighbor_cache_key = f'SubMConv3d_neighbor_cache_{Kw}x{Kh}x{Kd}_dilation{self.dilation}'
    neighbor_cache = x.get_spatial_cache(neighbor_cache_key)

    out, neighbor_cache_ = sparse_submanifold_conv3d(
        x.feats,
        x.coords,
        torch.Size([*x.shape, *x.spatial_shape]),
        self.weight,
        self.bias,
        neighbor_cache,
        self.dilation
    )

    # First use of this key: cache the neighbor map so later layers with the
    # same kernel/dilation on the same coords skip the hashmap build.
    if neighbor_cache is None:
        x.register_spatial_cache(neighbor_cache_key, neighbor_cache_)

    # Wrap the dense (M, Co) features back into the sparse container.
    out = x.replace(out)
    return out
|
||||
|
||||
class LayerNorm32(nn.LayerNorm):
    """LayerNorm computed in float32; result cast back to the input dtype.

    Keeps normalization numerically stable for half/bfloat16 activations.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        original_dtype = x.dtype
        normalized = super().forward(x.to(torch.float32))
        return normalized.to(original_dtype)
|
||||
|
||||
class SparseConvNeXtBlock3d(nn.Module):
    """ConvNeXt-style residual block for sparse tensors: 3x3x3 sparse conv,
    float32 LayerNorm, pointwise MLP, identity skip."""

    def __init__(
        self,
        channels: int,
        mlp_ratio: float = 4.0,
        use_checkpoint: bool = False,
    ):
        super().__init__()
        self.channels = channels
        self.use_checkpoint = use_checkpoint

        hidden = int(channels * mlp_ratio)
        self.norm = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
        self.conv = SparseConv3d(channels, channels, 3)
        self.mlp = nn.Sequential(
            nn.Linear(channels, hidden),
            nn.SiLU(),
            nn.Linear(hidden, channels),
        )

    def _forward(self, x):
        residual = x
        out = self.conv(x)
        out = out.replace(self.norm(out.feats))
        out = out.replace(self.mlp(out.feats))
        return out + residual

    def forward(self, x):
        # NOTE: use_checkpoint is stored but checkpointing is not applied here.
        return self._forward(x)
|
||||
|
||||
class SparseSpatial2Channel(nn.Module):
    # Downscale a sparse tensor by `factor` per spatial axis, folding the
    # factor**DIM sub-voxel features of every coarse cell into the channel
    # dimension (a sparse pixel-unshuffle). The coordinate mapping is cached
    # on the tensors so the inverse module can reuse it.
    def __init__(self, factor: int = 2):
        super(SparseSpatial2Channel, self).__init__()
        self.factor = factor

    def forward(self, x):
        # x.coords: (M, 1+DIM) as (batch, spatial...); x.feats: (M, C).
        DIM = x.coords.shape[-1] - 1
        cache = x.get_spatial_cache(f'spatial2channel_{self.factor}')
        if cache is None:
            # Coarse coordinate of each voxel (spatial columns floored).
            coord = list(x.coords.unbind(dim=-1))
            for i in range(DIM):
                coord[i+1] = coord[i+1] // self.factor
            # Linear slot of the voxel inside its factor**DIM cell.
            subidx = x.coords[:, 1:] % self.factor
            subidx = sum([subidx[..., i] * self.factor ** i for i in range(DIM)])

            # Unique linear code per coarse cell (batch-major layout).
            MAX = [(s + self.factor - 1) // self.factor for s in x.spatial_shape]
            OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1]
            code = sum([c * o for c, o in zip(coord, OFFSET)])
            code, idx = code.unique(return_inverse=True)

            # Decode the unique codes back into coarse (batch, spatial...) coords.
            new_coords = torch.stack(
                [code // OFFSET[0]] +
                [(code // OFFSET[i+1]) % MAX[i] for i in range(DIM)],
                dim=-1
            )
        else:
            new_coords, idx, subidx = cache

        # Scatter each voxel's features into its slot within its coarse cell;
        # slots with no source voxel stay zero.
        new_feats = torch.zeros(new_coords.shape[0] * self.factor ** DIM, x.feats.shape[1], device=x.feats.device, dtype=x.feats.dtype)
        new_feats[idx * self.factor ** DIM + subidx] = x.feats

        out = SparseTensor(new_feats.reshape(new_coords.shape[0], -1), new_coords, None if x._shape is None else torch.Size([x._shape[0], x._shape[1] * self.factor ** DIM]))
        out._scale = tuple([s * self.factor for s in x._scale])
        out._spatial_cache = x._spatial_cache

        if cache is None:
            # Cache the forward mapping on the input and the inverse mapping
            # on the output for the paired SparseChannel2Spatial.
            x.register_spatial_cache(f'spatial2channel_{self.factor}', (new_coords, idx, subidx))
            out.register_spatial_cache(f'channel2spatial_{self.factor}', (x.coords, idx, subidx))
            out.register_spatial_cache('shape', torch.Size(MAX))

        return out
|
||||
|
||||
class SparseChannel2Spatial(nn.Module):
    # Inverse of SparseSpatial2Channel: unfold channel groups back into up to
    # factor**DIM sub-voxels. Uses the cached mapping when available,
    # otherwise derives the expansion from an explicit `subdivision` mask.
    def __init__(self, factor: int = 2):
        super(SparseChannel2Spatial, self).__init__()
        self.factor = factor

    def forward(self, x, subdivision = None):
        DIM = x.coords.shape[-1] - 1

        cache = x.get_spatial_cache(f'channel2spatial_{self.factor}')
        if cache is None:
            if subdivision is None:
                raise ValueError('Cache not found. Provide subdivision tensor or pair SparseChannel2Spatial with SparseSpatial2Channel.')
            else:
                sub = subdivision.feats  # [N, self.factor ** DIM] mask of kept children
                N_leaf = sub.sum(dim=-1)  # [N] children per parent voxel
                subidx = sub.nonzero()[:, -1]  # child slot within each parent
                new_coords = x.coords.clone().detach()
                new_coords[:, 1:] *= self.factor
                new_coords = torch.repeat_interleave(new_coords, N_leaf, dim=0, output_size=subidx.shape[0])
                # Decode the slot index into per-axis offsets (base-`factor` digits).
                for i in range(DIM):
                    new_coords[:, i+1] += subidx // self.factor ** i % self.factor
                idx = torch.repeat_interleave(torch.arange(x.coords.shape[0], device=x.device), N_leaf, dim=0, output_size=subidx.shape[0])
        else:
            new_coords, idx, subidx = cache

        # Gather each child's feature slice from its parent's channel block.
        x_feats = x.feats.reshape(x.feats.shape[0] * self.factor ** DIM, -1)
        new_feats = x_feats[idx * self.factor ** DIM + subidx]
        out = SparseTensor(new_feats, new_coords, None if x._shape is None else torch.Size([x._shape[0], x._shape[1] // self.factor ** DIM]))
        out._scale = tuple([s / self.factor for s in x._scale])
        if cache is not None: # only keep cache when subdiv following it
            out._spatial_cache = x._spatial_cache
        return out
|
||||
|
||||
class SparseResBlockC2S3d(nn.Module):
    # Residual upsampling block: conv -> channel-to-spatial (x2 per axis),
    # with an optional learned 8-way subdivision predictor and a
    # repeat-interleave channel-matching skip connection.
    def __init__(
        self,
        channels: int,
        out_channels: Optional[int] = None,
        use_checkpoint: bool = False,
        pred_subdiv: bool = True,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_checkpoint = use_checkpoint
        self.pred_subdiv = pred_subdiv

        self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
        self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6)
        # conv1 emits 8x channels so Channel2Spatial(2) can unfold 2**3 children.
        self.conv1 = SparseConv3d(channels, self.out_channels * 8, 3)
        self.conv2 = SparseConv3d(self.out_channels, self.out_channels, 3)
        # NOTE(review): uses the raw `out_channels` argument (may be None)
        # rather than self.out_channels, and assumes channels divisible by 8 —
        # confirm the intended repeat ratio.
        self.skip_connection = lambda x: x.replace(x.feats.repeat_interleave(out_channels // (channels // 8), dim=1))
        if pred_subdiv:
            self.to_subdiv = SparseLinear(channels, 8)
        self.updown = SparseChannel2Spatial(2)

    def _forward(self, x, subdiv = None):
        if self.pred_subdiv:
            # Predict the 8-way subdivision logits from the input features.
            subdiv = self.to_subdiv(x)
        h = x.replace(self.norm1(x.feats))
        h = h.replace(F.silu(h.feats))
        h = self.conv1(h)
        # Binarize logits (> 0) to decide which children exist.
        subdiv_binarized = subdiv.replace(subdiv.feats > 0) if subdiv is not None else None
        h = self.updown(h, subdiv_binarized)
        x = self.updown(x, subdiv_binarized)
        h = h.replace(self.norm2(h.feats))
        h = h.replace(F.silu(h.feats))
        h = self.conv2(h)
        h = h + self.skip_connection(x)
        if self.pred_subdiv:
            return h, subdiv
        else:
            return h
|
||||
|
||||
# TODO: determine which conv they actually use
|
||||
@dataclass
class config:
    # Conv backend selector; the second assignment deliberately overrides the
    # first, so "flexgemm" is the effective value (first kept as the
    # documented alternative).
    CONV = "none"
    CONV = "flexgemm"
    # Hashmap over-allocation ratio for the flexgemm backend.
    FLEX_GEMM_HASHMAP_RATIO = 2.0
|
||||
|
||||
# TODO post processing
|
||||
def simplify(self, target_num_faces: int, verbose: bool=False, options: dict={}):
|
||||
@ -1131,6 +1339,7 @@ def flexible_dual_grid_to_mesh(
|
||||
|
||||
class Vae(nn.Module):
|
||||
def __init__(self, config, operations=None):
|
||||
super().__init__()
|
||||
operations = operations or torch.nn
|
||||
|
||||
self.txt_dec = SparseUnetVaeDecoder(
|
||||
@ -1139,7 +1348,8 @@ class Vae(nn.Module):
|
||||
latent_channels=32,
|
||||
num_blocks=[4, 16, 8, 4, 0],
|
||||
block_type=["SparseConvNeXtBlock3d"] * 5,
|
||||
up_block_type=["SparseResBlockS2C3d"] * 4,
|
||||
up_block_type=["SparseResBlockC2S3d"] * 4,
|
||||
block_args=[{}, {}, {}, {}, {}],
|
||||
pred_subdiv=False
|
||||
)
|
||||
|
||||
@ -1149,7 +1359,8 @@ class Vae(nn.Module):
|
||||
latent_channels=32,
|
||||
num_blocks=[4, 16, 8, 4, 0],
|
||||
block_type=["SparseConvNeXtBlock3d"] * 5,
|
||||
up_block_type=["SparseResBlockS2C3d"] * 4,
|
||||
up_block_type=["SparseResBlockC2S3d"] * 4,
|
||||
block_args=[{}, {}, {}, {}, {}],
|
||||
)
|
||||
|
||||
def decode_shape_slat(self, slat, resolution: int):
|
||||
|
||||
@ -50,6 +50,7 @@ import comfy.ldm.omnigen.omnigen2
|
||||
import comfy.ldm.qwen_image.model
|
||||
import comfy.ldm.kandinsky5.model
|
||||
import comfy.ldm.anima.model
|
||||
import comfy.ldm.trellis2.model
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
@ -1455,6 +1456,13 @@ class WAN22(WAN21):
|
||||
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
||||
return latent_image
|
||||
|
||||
class Trellis2(BaseModel):
    """BaseModel wrapper around the Trellis2 diffusion network (flow type)."""

    def __init__(self, model_config, model_type=ModelType.FLOW, device=None, unet_model=comfy.ldm.trellis2.model.Trellis2):
        super().__init__(model_config, model_type, device, unet_model)

    def extra_conds(self, **kwargs):
        # No extra conditioning beyond the base implementation yet.
        return super().extra_conds(**kwargs)
|
||||
|
||||
class Hunyuan3Dv2(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)
|
||||
|
||||
@ -1242,6 +1242,17 @@ class WAN22_T2V(WAN21_T2V):
|
||||
out = model_base.WAN22(self, image_to_video=True, device=device)
|
||||
return out
|
||||
|
||||
class Trellis2(supported_models_base.BASE):
    # Detection config matched against the unet config inferred from the
    # checkpoint's state dict.
    unet_config = {
        "image_model": "trellis2"
    }

    latent_format = latent_formats.Trellis2
    vae_key_prefix = ["vae."]

    def get_model(self, state_dict, prefix="", device=None):
        # Instantiate the Trellis2 BaseModel; state_dict/prefix unused here.
        return model_base.Trellis2(self, device=device)
|
||||
|
||||
class Hunyuan3Dv2(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "hunyuan3d2",
|
||||
@ -1596,6 +1607,6 @@ class Kandinsky5Image(Kandinsky5):
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
|
||||
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, Trellis2]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user