mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-24 09:22:32 +08:00
removed unnecessary code + optimizations + progres
This commit is contained in:
parent
f31c2e1d1d
commit
39270fdca9
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import math
|
import math
|
||||||
import torch
|
import torch
|
||||||
from typing import Dict, Callable
|
from typing import Callable
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
NO_TRITON = False
|
NO_TRITON = False
|
||||||
@ -201,13 +201,13 @@ class TorchHashMap:
|
|||||||
def __init__(self, keys: torch.Tensor, values: torch.Tensor, default_value: int):
|
def __init__(self, keys: torch.Tensor, values: torch.Tensor, default_value: int):
|
||||||
device = keys.device
|
device = keys.device
|
||||||
# use long for searchsorted
|
# use long for searchsorted
|
||||||
self.sorted_keys, order = torch.sort(keys.long())
|
self.sorted_keys, order = torch.sort(keys.to(torch.long))
|
||||||
self.sorted_vals = values.long()[order]
|
self.sorted_vals = values.to(torch.long)[order]
|
||||||
self.default_value = torch.tensor(default_value, dtype=torch.long, device=device)
|
self.default_value = torch.tensor(default_value, dtype=torch.long, device=device)
|
||||||
self._n = self.sorted_keys.numel()
|
self._n = self.sorted_keys.numel()
|
||||||
|
|
||||||
def lookup_flat(self, flat_keys: torch.Tensor) -> torch.Tensor:
|
def lookup_flat(self, flat_keys: torch.Tensor) -> torch.Tensor:
|
||||||
flat = flat_keys.long()
|
flat = flat_keys.to(torch.long)
|
||||||
if self._n == 0:
|
if self._n == 0:
|
||||||
return torch.full((flat.shape[0],), self.default_value, device=flat.device, dtype=self.sorted_vals.dtype)
|
return torch.full((flat.shape[0],), self.default_value, device=flat.device, dtype=self.sorted_vals.dtype)
|
||||||
idx = torch.searchsorted(self.sorted_keys, flat)
|
idx = torch.searchsorted(self.sorted_keys, flat)
|
||||||
@ -225,44 +225,35 @@ def neighbor_map_post_process_for_masked_implicit_gemm_1(neighbor_map):
|
|||||||
device = neighbor_map.device
|
device = neighbor_map.device
|
||||||
N, V = neighbor_map.shape
|
N, V = neighbor_map.shape
|
||||||
|
|
||||||
|
sentinel = UINT32_SENTINEL
|
||||||
|
|
||||||
neigh = neighbor_map.to(torch.long)
|
neigh_map_T = neighbor_map.t().reshape(-1)
|
||||||
sentinel = torch.tensor(UINT32_SENTINEL, dtype=torch.long, device=device)
|
|
||||||
|
|
||||||
|
|
||||||
neigh_map_T = neigh.t().reshape(-1)
|
|
||||||
|
|
||||||
neigh_mask_T = (neigh_map_T != sentinel).to(torch.int32)
|
neigh_mask_T = (neigh_map_T != sentinel).to(torch.int32)
|
||||||
|
|
||||||
mask = (neigh != sentinel).to(torch.long)
|
mask = (neighbor_map != sentinel).to(torch.long)
|
||||||
|
gray_code = torch.zeros(N, dtype=torch.long, device=device)
|
||||||
|
|
||||||
powers = (1 << torch.arange(V, dtype=torch.long, device=device))
|
for v in range(V):
|
||||||
|
gray_code |= (mask[:, v] << v)
|
||||||
|
|
||||||
gray_long = (mask * powers).sum(dim=1)
|
binary_code = gray_code.clone()
|
||||||
|
|
||||||
gray_code = gray_long.to(torch.int32)
|
|
||||||
|
|
||||||
binary_long = gray_long.clone()
|
|
||||||
for v in range(1, V):
|
for v in range(1, V):
|
||||||
binary_long ^= (gray_long >> v)
|
binary_code ^= (gray_code >> v)
|
||||||
binary_code = binary_long.to(torch.int32)
|
|
||||||
|
|
||||||
sorted_idx = torch.argsort(binary_code)
|
sorted_idx = torch.argsort(binary_code)
|
||||||
|
|
||||||
prefix_sum_neighbor_mask = torch.cumsum(neigh_mask_T.to(torch.int32), dim=0) # (V*N,)
|
prefix_sum_neighbor_mask = torch.cumsum(neigh_mask_T, dim=0)
|
||||||
|
|
||||||
total_valid_signal = int(prefix_sum_neighbor_mask[-1].item()) if prefix_sum_neighbor_mask.numel() > 0 else 0
|
total_valid_signal = int(prefix_sum_neighbor_mask[-1].item()) if prefix_sum_neighbor_mask.numel() > 0 else 0
|
||||||
|
|
||||||
if total_valid_signal > 0:
|
if total_valid_signal > 0:
|
||||||
|
pos = torch.nonzero(neigh_mask_T, as_tuple=True)[0]
|
||||||
|
to = (prefix_sum_neighbor_mask[pos] - 1).long()
|
||||||
|
|
||||||
valid_signal_i = torch.empty((total_valid_signal,), dtype=torch.long, device=device)
|
valid_signal_i = torch.empty((total_valid_signal,), dtype=torch.long, device=device)
|
||||||
valid_signal_o = torch.empty((total_valid_signal,), dtype=torch.long, device=device)
|
valid_signal_o = torch.empty((total_valid_signal,), dtype=torch.long, device=device)
|
||||||
|
|
||||||
pos = torch.nonzero(neigh_mask_T, as_tuple=True)[0]
|
|
||||||
|
|
||||||
to = (prefix_sum_neighbor_mask[pos] - 1).to(torch.long)
|
|
||||||
|
|
||||||
valid_signal_i[to] = (pos % N).to(torch.long)
|
valid_signal_i[to] = (pos % N).to(torch.long)
|
||||||
|
|
||||||
valid_signal_o[to] = neigh_map_T[pos].to(torch.long)
|
valid_signal_o[to] = neigh_map_T[pos].to(torch.long)
|
||||||
else:
|
else:
|
||||||
valid_signal_i = torch.empty((0,), dtype=torch.long, device=device)
|
valid_signal_i = torch.empty((0,), dtype=torch.long, device=device)
|
||||||
@ -272,9 +263,7 @@ def neighbor_map_post_process_for_masked_implicit_gemm_1(neighbor_map):
|
|||||||
seg[0] = 0
|
seg[0] = 0
|
||||||
if V > 0:
|
if V > 0:
|
||||||
idxs = (torch.arange(1, V + 1, device=device, dtype=torch.long) * N) - 1
|
idxs = (torch.arange(1, V + 1, device=device, dtype=torch.long) * N) - 1
|
||||||
seg[1:] = prefix_sum_neighbor_mask[idxs].to(torch.long)
|
seg[1:] = prefix_sum_neighbor_mask[idxs]
|
||||||
else:
|
|
||||||
pass
|
|
||||||
|
|
||||||
return gray_code, sorted_idx, valid_signal_i, valid_signal_o, seg
|
return gray_code, sorted_idx, valid_signal_i, valid_signal_o, seg
|
||||||
|
|
||||||
@ -295,40 +284,41 @@ def _popcount_int32_tensor(x: torch.Tensor) -> torch.Tensor:
|
|||||||
|
|
||||||
|
|
||||||
def neighbor_map_post_process_for_masked_implicit_gemm_2(
|
def neighbor_map_post_process_for_masked_implicit_gemm_2(
|
||||||
gray_code: torch.Tensor, # [N], int32-like (non-negative)
|
gray_code: torch.Tensor,
|
||||||
sorted_idx: torch.Tensor, # [N], long (indexing into gray_code)
|
sorted_idx: torch.Tensor,
|
||||||
block_size: int
|
block_size: int
|
||||||
):
|
):
|
||||||
device = gray_code.device
|
device = gray_code.device
|
||||||
N = gray_code.numel()
|
N = gray_code.numel()
|
||||||
|
|
||||||
# num of blocks (same as CUDA)
|
|
||||||
num_blocks = (N + block_size - 1) // block_size
|
num_blocks = (N + block_size - 1) // block_size
|
||||||
|
|
||||||
# Ensure dtypes
|
|
||||||
gray_long = gray_code.to(torch.int64) # safer to OR in 64-bit then cast
|
|
||||||
sorted_idx = sorted_idx.to(torch.long)
|
|
||||||
|
|
||||||
# 1) Group gray_code by blocks and compute OR across each block
|
|
||||||
# pad the last block with zeros if necessary so we can reshape
|
|
||||||
pad = num_blocks * block_size - N
|
pad = num_blocks * block_size - N
|
||||||
if pad > 0:
|
if pad > 0:
|
||||||
pad_vals = torch.zeros((pad,), dtype=torch.int64, device=device)
|
pad_vals = torch.zeros((pad,), dtype=torch.int32, device=device)
|
||||||
gray_padded = torch.cat([gray_long[sorted_idx], pad_vals], dim=0)
|
gray_padded = torch.cat([gray_code[sorted_idx], pad_vals], dim=0)
|
||||||
else:
|
else:
|
||||||
gray_padded = gray_long[sorted_idx]
|
gray_padded = gray_code[sorted_idx]
|
||||||
|
|
||||||
# reshape to (num_blocks, block_size) and compute bitwise_or across dim=1
|
gray_blocks = gray_padded.view(num_blocks, block_size)
|
||||||
gray_blocks = gray_padded.view(num_blocks, block_size) # each row = block entries
|
|
||||||
# reduce with bitwise_or
|
reduced_code = gray_blocks
|
||||||
reduced_code = gray_blocks[:, 0].clone()
|
while reduced_code.shape[1] > 1:
|
||||||
for i in range(1, block_size):
|
half = reduced_code.shape[1] // 2
|
||||||
reduced_code |= gray_blocks[:, i]
|
remainder = reduced_code.shape[1] % 2
|
||||||
reduced_code = reduced_code.to(torch.int32) # match CUDA int32
|
|
||||||
|
left = reduced_code[:, :half * 2:2]
|
||||||
|
right = reduced_code[:, 1:half * 2:2]
|
||||||
|
merged = left | right
|
||||||
|
|
||||||
|
if remainder:
|
||||||
|
reduced_code = torch.cat([merged, reduced_code[:, -1:]], dim=1)
|
||||||
|
else:
|
||||||
|
reduced_code = merged
|
||||||
|
|
||||||
|
reduced_code = reduced_code.squeeze(1)
|
||||||
|
|
||||||
|
seglen_counts = _popcount_int32_tensor(reduced_code).to(torch.int32)
|
||||||
|
|
||||||
# 2) compute seglen (popcount per reduced_code) and seg (prefix sum)
|
|
||||||
seglen_counts = _popcount_int32_tensor(reduced_code.to(torch.int64)).to(torch.int32) # [num_blocks]
|
|
||||||
# seg: length num_blocks+1, seg[0]=0, seg[i+1]=cumsum(seglen_counts) up to i
|
|
||||||
seg = torch.empty((num_blocks + 1,), dtype=torch.int32, device=device)
|
seg = torch.empty((num_blocks + 1,), dtype=torch.int32, device=device)
|
||||||
seg[0] = 0
|
seg[0] = 0
|
||||||
if num_blocks > 0:
|
if num_blocks > 0:
|
||||||
@ -336,30 +326,20 @@ def neighbor_map_post_process_for_masked_implicit_gemm_2(
|
|||||||
|
|
||||||
total = int(seg[-1].item())
|
total = int(seg[-1].item())
|
||||||
|
|
||||||
# 3) scatter — produce valid_kernel_idx as concatenated ascending set-bit positions for each reduced_code row
|
|
||||||
if total == 0:
|
if total == 0:
|
||||||
valid_kernel_idx = torch.empty((0,), dtype=torch.int32, device=device)
|
return torch.empty((0,), dtype=torch.int32, device=device), seg
|
||||||
return valid_kernel_idx, seg
|
|
||||||
|
|
||||||
max_val = int(reduced_code.max().item())
|
V = int(reduced_code.max().item()).bit_length() if reduced_code.max() > 0 else 0
|
||||||
V = max_val.bit_length() if max_val > 0 else 0
|
|
||||||
# If you know V externally, pass it instead or set here explicitly.
|
|
||||||
|
|
||||||
if V == 0:
|
if V == 0:
|
||||||
# no bits set anywhere
|
return torch.empty((0,), dtype=torch.int32, device=device), seg
|
||||||
valid_kernel_idx = torch.empty((0,), dtype=torch.int32, device=device)
|
|
||||||
return valid_kernel_idx, seg
|
|
||||||
|
|
||||||
# build mask of shape (num_blocks, V): True where bit is set
|
bit_pos = torch.arange(0, V, dtype=torch.int32, device=device)
|
||||||
bit_pos = torch.arange(0, V, dtype=torch.int64, device=device) # [V]
|
shifted = reduced_code.unsqueeze(1) >> bit_pos.unsqueeze(0)
|
||||||
# shifted = reduced_code[:, None] >> bit_pos[None, :]
|
bits = (shifted & 1).to(torch.bool)
|
||||||
shifted = reduced_code.to(torch.int64).unsqueeze(1) >> bit_pos.unsqueeze(0)
|
|
||||||
bits = (shifted & 1).to(torch.bool) # (num_blocks, V)
|
|
||||||
|
|
||||||
positions = bit_pos.unsqueeze(0).expand(num_blocks, V)
|
positions = bit_pos.unsqueeze(0).expand(num_blocks, V)
|
||||||
|
valid_kernel_idx = positions[bits].to(torch.int32).contiguous()
|
||||||
valid_positions = positions[bits]
|
|
||||||
valid_kernel_idx = valid_positions.to(torch.int32).contiguous()
|
|
||||||
|
|
||||||
return valid_kernel_idx, seg
|
return valid_kernel_idx, seg
|
||||||
|
|
||||||
@ -425,35 +405,6 @@ def sparse_submanifold_conv3d(feats, coords, shape, weight, bias, neighbor_cache
|
|||||||
|
|
||||||
return out, neighbor
|
return out, neighbor
|
||||||
|
|
||||||
class Voxel:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
origin: list,
|
|
||||||
voxel_size: float,
|
|
||||||
coords: torch.Tensor = None,
|
|
||||||
attrs: torch.Tensor = None,
|
|
||||||
layout = None,
|
|
||||||
device: torch.device = None
|
|
||||||
):
|
|
||||||
if layout is None:
|
|
||||||
layout = {}
|
|
||||||
self.origin = torch.tensor(origin, dtype=torch.float32, device=device)
|
|
||||||
self.voxel_size = voxel_size
|
|
||||||
self.coords = coords
|
|
||||||
self.attrs = attrs
|
|
||||||
self.layout = layout
|
|
||||||
self.device = device
|
|
||||||
|
|
||||||
@property
|
|
||||||
def position(self):
|
|
||||||
return (self.coords + 0.5) * self.voxel_size + self.origin[None, :]
|
|
||||||
|
|
||||||
def split_attrs(self):
|
|
||||||
return {
|
|
||||||
k: self.attrs[:, self.layout[k]]
|
|
||||||
for k in self.layout
|
|
||||||
}
|
|
||||||
|
|
||||||
class Mesh:
|
class Mesh:
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
vertices,
|
vertices,
|
||||||
@ -480,35 +431,3 @@ class Mesh:
|
|||||||
|
|
||||||
def cpu(self):
|
def cpu(self):
|
||||||
return self.to('cpu')
|
return self.to('cpu')
|
||||||
|
|
||||||
class MeshWithVoxel(Mesh, Voxel):
|
|
||||||
def __init__(self,
|
|
||||||
vertices: torch.Tensor,
|
|
||||||
faces: torch.Tensor,
|
|
||||||
origin: list,
|
|
||||||
voxel_size: float,
|
|
||||||
coords: torch.Tensor,
|
|
||||||
attrs: torch.Tensor,
|
|
||||||
voxel_shape: torch.Size,
|
|
||||||
layout: Dict = {},
|
|
||||||
):
|
|
||||||
self.vertices = vertices.float()
|
|
||||||
self.faces = faces.int()
|
|
||||||
self.origin = torch.tensor(origin, dtype=torch.float32, device=self.device)
|
|
||||||
self.voxel_size = voxel_size
|
|
||||||
self.coords = coords
|
|
||||||
self.attrs = attrs
|
|
||||||
self.voxel_shape = voxel_shape
|
|
||||||
self.layout = layout
|
|
||||||
|
|
||||||
def to(self, device, non_blocking=False):
|
|
||||||
return MeshWithVoxel(
|
|
||||||
self.vertices.to(device, non_blocking=non_blocking),
|
|
||||||
self.faces.to(device, non_blocking=non_blocking),
|
|
||||||
self.origin.tolist(),
|
|
||||||
self.voxel_size,
|
|
||||||
self.coords.to(device, non_blocking=non_blocking),
|
|
||||||
self.attrs.to(device, non_blocking=non_blocking),
|
|
||||||
self.voxel_shape,
|
|
||||||
self.layout,
|
|
||||||
)
|
|
||||||
|
|||||||
@ -7,7 +7,7 @@ import torch.nn.functional as F
|
|||||||
from fractions import Fraction
|
from fractions import Fraction
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Any, Dict, Optional, overload, Union, Tuple
|
from typing import List, Any, Dict, Optional, overload, Union, Tuple
|
||||||
from comfy.ldm.trellis2.cumesh import TorchHashMap, Mesh, MeshWithVoxel, sparse_submanifold_conv3d
|
from comfy.ldm.trellis2.cumesh import TorchHashMap, Mesh, sparse_submanifold_conv3d
|
||||||
|
|
||||||
|
|
||||||
def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor:
|
def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor:
|
||||||
@ -210,6 +210,8 @@ class SparseResBlockC2S3d(nn.Module):
|
|||||||
|
|
||||||
def forward(self, x, subdiv = None):
|
def forward(self, x, subdiv = None):
|
||||||
if self.pred_subdiv:
|
if self.pred_subdiv:
|
||||||
|
dtype = next(self.to_subdiv.parameters()).dtype
|
||||||
|
x = x.to(dtype)
|
||||||
subdiv = self.to_subdiv(x)
|
subdiv = self.to_subdiv(x)
|
||||||
norm1 = self.norm1.to(torch.float32)
|
norm1 = self.norm1.to(torch.float32)
|
||||||
norm2 = self.norm2.to(torch.float32)
|
norm2 = self.norm2.to(torch.float32)
|
||||||
@ -987,114 +989,7 @@ def convert_module_to_f16(l):
|
|||||||
for p in l.parameters():
|
for p in l.parameters():
|
||||||
p.data = p.data.half()
|
p.data = p.data.half()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class SparseUnetVaeEncoder(nn.Module):
|
|
||||||
"""
|
|
||||||
Sparse Swin Transformer Unet VAE model.
|
|
||||||
"""
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
in_channels: int,
|
|
||||||
model_channels: List[int],
|
|
||||||
latent_channels: int,
|
|
||||||
num_blocks: List[int],
|
|
||||||
block_type: List[str],
|
|
||||||
down_block_type: List[str],
|
|
||||||
block_args: List[Dict[str, Any]],
|
|
||||||
use_fp16: bool = False,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.in_channels = in_channels
|
|
||||||
self.model_channels = model_channels
|
|
||||||
self.num_blocks = num_blocks
|
|
||||||
self.dtype = torch.float16 if use_fp16 else torch.float32
|
|
||||||
|
|
||||||
self.input_layer = SparseLinear(in_channels, model_channels[0])
|
|
||||||
self.to_latent = SparseLinear(model_channels[-1], 2 * latent_channels)
|
|
||||||
|
|
||||||
self.blocks = nn.ModuleList([])
|
|
||||||
for i in range(len(num_blocks)):
|
|
||||||
self.blocks.append(nn.ModuleList([]))
|
|
||||||
for j in range(num_blocks[i]):
|
|
||||||
self.blocks[-1].append(
|
|
||||||
globals()[block_type[i]](
|
|
||||||
model_channels[i],
|
|
||||||
**block_args[i],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if i < len(num_blocks) - 1:
|
|
||||||
self.blocks[-1].append(
|
|
||||||
globals()[down_block_type[i]](
|
|
||||||
model_channels[i],
|
|
||||||
model_channels[i+1],
|
|
||||||
**block_args[i],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def device(self) -> torch.device:
|
|
||||||
return next(self.parameters()).device
|
|
||||||
|
|
||||||
def forward(self, x: SparseTensor, sample_posterior=False, return_raw=False):
|
|
||||||
h = self.input_layer(x)
|
|
||||||
h = h.type(self.dtype)
|
|
||||||
for i, res in enumerate(self.blocks):
|
|
||||||
for j, block in enumerate(res):
|
|
||||||
h = block(h)
|
|
||||||
h = h.type(x.dtype)
|
|
||||||
h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
|
|
||||||
h = self.to_latent(h)
|
|
||||||
|
|
||||||
# Sample from the posterior distribution
|
|
||||||
mean, logvar = h.feats.chunk(2, dim=-1)
|
|
||||||
if sample_posterior:
|
|
||||||
std = torch.exp(0.5 * logvar)
|
|
||||||
z = mean + std * torch.randn_like(std)
|
|
||||||
else:
|
|
||||||
z = mean
|
|
||||||
z = h.replace(z)
|
|
||||||
|
|
||||||
if return_raw:
|
|
||||||
return z, mean, logvar
|
|
||||||
else:
|
|
||||||
return z
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class FlexiDualGridVaeEncoder(SparseUnetVaeEncoder):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model_channels: List[int],
|
|
||||||
latent_channels: int,
|
|
||||||
num_blocks: List[int],
|
|
||||||
block_type: List[str],
|
|
||||||
down_block_type: List[str],
|
|
||||||
block_args: List[Dict[str, Any]],
|
|
||||||
use_fp16: bool = False,
|
|
||||||
):
|
|
||||||
super().__init__(
|
|
||||||
6,
|
|
||||||
model_channels,
|
|
||||||
latent_channels,
|
|
||||||
num_blocks,
|
|
||||||
block_type,
|
|
||||||
down_block_type,
|
|
||||||
block_args,
|
|
||||||
use_fp16,
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, vertices: SparseTensor, intersected: SparseTensor, sample_posterior=False, return_raw=False):
|
|
||||||
x = vertices.replace(torch.cat([
|
|
||||||
vertices.feats - 0.5,
|
|
||||||
intersected.feats.float() - 0.5,
|
|
||||||
], dim=1))
|
|
||||||
return super().forward(x, sample_posterior, return_raw)
|
|
||||||
|
|
||||||
class SparseUnetVaeDecoder(nn.Module):
|
class SparseUnetVaeDecoder(nn.Module):
|
||||||
"""
|
|
||||||
Sparse Swin Transformer Unet VAE model.
|
|
||||||
"""
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
out_channels: int,
|
out_channels: int,
|
||||||
@ -1218,10 +1113,10 @@ class FlexiDualGridVaeDecoder(SparseUnetVaeDecoder):
|
|||||||
N = coords.shape[0]
|
N = coords.shape[0]
|
||||||
# compute flat keys for all coords (prepend batch 0 same as original code)
|
# compute flat keys for all coords (prepend batch 0 same as original code)
|
||||||
b = torch.zeros((N,), dtype=torch.long, device=device)
|
b = torch.zeros((N,), dtype=torch.long, device=device)
|
||||||
x, y, z = coords[:, 0].long(), coords[:, 1].long(), coords[:, 2].long()
|
x, y, z = coords[:, 0].to(torch.int32), coords[:, 1].to(torch.int32), coords[:, 2].to(torch.int32)
|
||||||
W, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item())
|
W, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item())
|
||||||
flat_keys = b * (W * H * D) + x * (H * D) + y * D + z
|
flat_keys = b * (W * H * D) + x * (H * D) + y * D + z
|
||||||
values = torch.arange(N, dtype=torch.long, device=device)
|
values = torch.arange(N, dtype=torch.int32, device=device)
|
||||||
DEFAULT_VAL = 0xffffffff # sentinel used in original code
|
DEFAULT_VAL = 0xffffffff # sentinel used in original code
|
||||||
return TorchHashMap(flat_keys, values, DEFAULT_VAL)
|
return TorchHashMap(flat_keys, values, DEFAULT_VAL)
|
||||||
|
|
||||||
@ -1295,13 +1190,12 @@ def flexible_dual_grid_to_mesh(
|
|||||||
|
|
||||||
# Extract mesh
|
# Extract mesh
|
||||||
N = dual_vertices.shape[0]
|
N = dual_vertices.shape[0]
|
||||||
mesh_vertices = (coords.float() + dual_vertices) / (2 * N) - 0.5
|
|
||||||
|
|
||||||
if hashmap_builder is None:
|
if hashmap_builder is None:
|
||||||
# build local TorchHashMap
|
# build local TorchHashMap
|
||||||
device = coords.device
|
device = coords.device
|
||||||
b = torch.zeros((N,), dtype=torch.long, device=device)
|
b = torch.zeros((N,), dtype=torch.long, device=device)
|
||||||
x, y, z = coords[:, 0].long(), coords[:, 1].long(), coords[:, 2].long()
|
x, y, z = coords[:, 0].to(torch.int32), coords[:, 1].to(torch.int32), coords[:, 2].to(torch.int32)
|
||||||
W, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item())
|
W, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item())
|
||||||
flat_keys = b * (W * H * D) + x * (H * D) + y * D + z
|
flat_keys = b * (W * H * D) + x * (H * D) + y * D + z
|
||||||
values = torch.arange(N, dtype=torch.long, device=device)
|
values = torch.arange(N, dtype=torch.long, device=device)
|
||||||
@ -1316,9 +1210,9 @@ def flexible_dual_grid_to_mesh(
|
|||||||
M = connected_voxel.shape[0]
|
M = connected_voxel.shape[0]
|
||||||
# flatten connected voxel coords and lookup
|
# flatten connected voxel coords and lookup
|
||||||
conn_flat_b = torch.zeros((M * 4,), dtype=torch.long, device=coords.device)
|
conn_flat_b = torch.zeros((M * 4,), dtype=torch.long, device=coords.device)
|
||||||
conn_x = connected_voxel.reshape(-1, 3)[:, 0].long()
|
conn_x = connected_voxel.reshape(-1, 3)[:, 0].to(torch.int32)
|
||||||
conn_y = connected_voxel.reshape(-1, 3)[:, 1].long()
|
conn_y = connected_voxel.reshape(-1, 3)[:, 1].to(torch.int32)
|
||||||
conn_z = connected_voxel.reshape(-1, 3)[:, 2].long()
|
conn_z = connected_voxel.reshape(-1, 3)[:, 2].to(torch.int32)
|
||||||
W, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item())
|
W, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item())
|
||||||
conn_flat = conn_flat_b * (W * H * D) + conn_x * (H * D) + conn_y * D + conn_z
|
conn_flat = conn_flat_b * (W * H * D) + conn_x * (H * D) + conn_y * D + conn_z
|
||||||
|
|
||||||
@ -1526,17 +1420,18 @@ class Vae(nn.Module):
|
|||||||
channels=[512, 128, 32],
|
channels=[512, 128, 32],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def decode_shape_slat(self, slat, resolution: int):
|
def decode_shape_slat(self, slat, resolution: int):
|
||||||
self.shape_dec.set_resolution(resolution)
|
self.shape_dec.set_resolution(resolution)
|
||||||
device = comfy.model_management.get_torch_device()
|
|
||||||
self.shape_dec = self.shape_dec.to(device)
|
|
||||||
return self.shape_dec(slat, return_subs=True)
|
return self.shape_dec(slat, return_subs=True)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def decode_tex_slat(self, slat, subs):
|
def decode_tex_slat(self, slat, subs):
|
||||||
if self.txt_dec is None:
|
if self.txt_dec is None:
|
||||||
raise ValueError("Checkpoint doesn't include texture model")
|
raise ValueError("Checkpoint doesn't include texture model")
|
||||||
return self.txt_dec(slat, guide_subs=subs) * 0.5 + 0.5
|
return self.txt_dec(slat, guide_subs=subs) * 0.5 + 0.5
|
||||||
|
|
||||||
|
# shouldn't be called (placeholder)
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def decode(
|
def decode(
|
||||||
self,
|
self,
|
||||||
@ -1546,17 +1441,4 @@ class Vae(nn.Module):
|
|||||||
):
|
):
|
||||||
meshes, subs = self.decode_shape_slat(shape_slat, resolution)
|
meshes, subs = self.decode_shape_slat(shape_slat, resolution)
|
||||||
tex_voxels = self.decode_tex_slat(tex_slat, subs)
|
tex_voxels = self.decode_tex_slat(tex_slat, subs)
|
||||||
out_mesh = []
|
return tex_voxels
|
||||||
for m, v in zip(meshes, tex_voxels):
|
|
||||||
out_mesh.append(
|
|
||||||
MeshWithVoxel(
|
|
||||||
m.vertices, m.faces,
|
|
||||||
origin = [-0.5, -0.5, -0.5],
|
|
||||||
voxel_size = 1 / resolution,
|
|
||||||
coords = v.coords[:, 1:],
|
|
||||||
attrs = v.feats,
|
|
||||||
voxel_shape = torch.Size([*v.shape, *v.spatial_shape]),
|
|
||||||
layout=self.pbr_attr_layout
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return out_mesh
|
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
from comfy_api.latest import ComfyExtension, IO, Types
|
from comfy_api.latest import ComfyExtension, IO, Types
|
||||||
from comfy.ldm.trellis2.vae import SparseTensor
|
from comfy.ldm.trellis2.vae import SparseTensor
|
||||||
from comfy.utils import ProgressBar
|
from comfy.utils import ProgressBar, lanczos
|
||||||
import torch.nn.functional as TF
|
import torch.nn.functional as TF
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
@ -102,9 +102,7 @@ def run_conditioning(model, image, mask, include_1024 = True, background_color =
|
|||||||
cropped_img = smart_crop_square(img_t, mask_t, bg_color=bg_rgb)
|
cropped_img = smart_crop_square(img_t, mask_t, bg_color=bg_rgb)
|
||||||
|
|
||||||
def prepare_tensor(img, size):
|
def prepare_tensor(img, size):
|
||||||
resized = torch.nn.functional.interpolate(
|
resized = lanczos(img.unsqueeze(0), size, size)
|
||||||
img.unsqueeze(0), size=(size, size), mode='bicubic', align_corners=False
|
|
||||||
)
|
|
||||||
return (resized - dino_mean.to(torch_device)) / dino_std.to(torch_device)
|
return (resized - dino_mean.to(torch_device)) / dino_std.to(torch_device)
|
||||||
|
|
||||||
model_internal.image_size = 512
|
model_internal.image_size = 512
|
||||||
@ -148,10 +146,16 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, samples, structure_output, vae, resolution):
|
def execute(cls, samples, structure_output, vae, resolution):
|
||||||
|
|
||||||
|
patcher = vae.patcher
|
||||||
|
device = comfy.model_management.get_torch_device()
|
||||||
|
comfy.model_management.load_model_gpu(patcher)
|
||||||
|
|
||||||
vae = vae.first_stage_model
|
vae = vae.first_stage_model
|
||||||
decoded = structure_output.data.unsqueeze(1)
|
decoded = structure_output.data.unsqueeze(1)
|
||||||
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
||||||
samples = samples["samples"]
|
samples = samples["samples"]
|
||||||
|
samples = samples.squeeze(-1).transpose(1, 2).to(device)
|
||||||
std = shape_slat_normalization["std"].to(samples)
|
std = shape_slat_normalization["std"].to(samples)
|
||||||
mean = shape_slat_normalization["mean"].to(samples)
|
mean = shape_slat_normalization["mean"].to(samples)
|
||||||
samples = SparseTensor(feats = samples, coords=coords)
|
samples = SparseTensor(feats = samples, coords=coords)
|
||||||
@ -179,10 +183,16 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, samples, structure_output, vae, shape_subs):
|
def execute(cls, samples, structure_output, vae, shape_subs):
|
||||||
|
|
||||||
|
patcher = vae.patcher
|
||||||
|
device = comfy.model_management.get_torch_device()
|
||||||
|
comfy.model_management.load_model_gpu(patcher)
|
||||||
|
|
||||||
vae = vae.first_stage_model
|
vae = vae.first_stage_model
|
||||||
decoded = structure_output.data.unsqueeze(1)
|
decoded = structure_output.data.unsqueeze(1)
|
||||||
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
||||||
samples = samples["samples"]
|
samples = samples["samples"]
|
||||||
|
samples = samples.squeeze(-1).transpose(1, 2).to(device)
|
||||||
std = tex_slat_normalization["std"].to(samples)
|
std = tex_slat_normalization["std"].to(samples)
|
||||||
mean = tex_slat_normalization["mean"].to(samples)
|
mean = tex_slat_normalization["mean"].to(samples)
|
||||||
samples = SparseTensor(feats = samples, coords=coords)
|
samples = SparseTensor(feats = samples, coords=coords)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user