removed unnecessary code + optimizations + progress

This commit is contained in:
Yousef Rafat 2026-02-25 23:54:03 +02:00
parent f31c2e1d1d
commit 39270fdca9
3 changed files with 75 additions and 264 deletions

View File

@ -2,7 +2,7 @@
import math import math
import torch import torch
from typing import Dict, Callable from typing import Callable
import logging import logging
NO_TRITON = False NO_TRITON = False
@ -201,13 +201,13 @@ class TorchHashMap:
def __init__(self, keys: torch.Tensor, values: torch.Tensor, default_value: int): def __init__(self, keys: torch.Tensor, values: torch.Tensor, default_value: int):
device = keys.device device = keys.device
# use long for searchsorted # use long for searchsorted
self.sorted_keys, order = torch.sort(keys.long()) self.sorted_keys, order = torch.sort(keys.to(torch.long))
self.sorted_vals = values.long()[order] self.sorted_vals = values.to(torch.long)[order]
self.default_value = torch.tensor(default_value, dtype=torch.long, device=device) self.default_value = torch.tensor(default_value, dtype=torch.long, device=device)
self._n = self.sorted_keys.numel() self._n = self.sorted_keys.numel()
def lookup_flat(self, flat_keys: torch.Tensor) -> torch.Tensor: def lookup_flat(self, flat_keys: torch.Tensor) -> torch.Tensor:
flat = flat_keys.long() flat = flat_keys.to(torch.long)
if self._n == 0: if self._n == 0:
return torch.full((flat.shape[0],), self.default_value, device=flat.device, dtype=self.sorted_vals.dtype) return torch.full((flat.shape[0],), self.default_value, device=flat.device, dtype=self.sorted_vals.dtype)
idx = torch.searchsorted(self.sorted_keys, flat) idx = torch.searchsorted(self.sorted_keys, flat)
@ -225,44 +225,35 @@ def neighbor_map_post_process_for_masked_implicit_gemm_1(neighbor_map):
device = neighbor_map.device device = neighbor_map.device
N, V = neighbor_map.shape N, V = neighbor_map.shape
sentinel = UINT32_SENTINEL
neigh = neighbor_map.to(torch.long) neigh_map_T = neighbor_map.t().reshape(-1)
sentinel = torch.tensor(UINT32_SENTINEL, dtype=torch.long, device=device)
neigh_map_T = neigh.t().reshape(-1)
neigh_mask_T = (neigh_map_T != sentinel).to(torch.int32) neigh_mask_T = (neigh_map_T != sentinel).to(torch.int32)
mask = (neigh != sentinel).to(torch.long) mask = (neighbor_map != sentinel).to(torch.long)
gray_code = torch.zeros(N, dtype=torch.long, device=device)
powers = (1 << torch.arange(V, dtype=torch.long, device=device)) for v in range(V):
gray_code |= (mask[:, v] << v)
gray_long = (mask * powers).sum(dim=1) binary_code = gray_code.clone()
gray_code = gray_long.to(torch.int32)
binary_long = gray_long.clone()
for v in range(1, V): for v in range(1, V):
binary_long ^= (gray_long >> v) binary_code ^= (gray_code >> v)
binary_code = binary_long.to(torch.int32)
sorted_idx = torch.argsort(binary_code) sorted_idx = torch.argsort(binary_code)
prefix_sum_neighbor_mask = torch.cumsum(neigh_mask_T.to(torch.int32), dim=0) # (V*N,) prefix_sum_neighbor_mask = torch.cumsum(neigh_mask_T, dim=0)
total_valid_signal = int(prefix_sum_neighbor_mask[-1].item()) if prefix_sum_neighbor_mask.numel() > 0 else 0 total_valid_signal = int(prefix_sum_neighbor_mask[-1].item()) if prefix_sum_neighbor_mask.numel() > 0 else 0
if total_valid_signal > 0: if total_valid_signal > 0:
pos = torch.nonzero(neigh_mask_T, as_tuple=True)[0]
to = (prefix_sum_neighbor_mask[pos] - 1).long()
valid_signal_i = torch.empty((total_valid_signal,), dtype=torch.long, device=device) valid_signal_i = torch.empty((total_valid_signal,), dtype=torch.long, device=device)
valid_signal_o = torch.empty((total_valid_signal,), dtype=torch.long, device=device) valid_signal_o = torch.empty((total_valid_signal,), dtype=torch.long, device=device)
pos = torch.nonzero(neigh_mask_T, as_tuple=True)[0]
to = (prefix_sum_neighbor_mask[pos] - 1).to(torch.long)
valid_signal_i[to] = (pos % N).to(torch.long) valid_signal_i[to] = (pos % N).to(torch.long)
valid_signal_o[to] = neigh_map_T[pos].to(torch.long) valid_signal_o[to] = neigh_map_T[pos].to(torch.long)
else: else:
valid_signal_i = torch.empty((0,), dtype=torch.long, device=device) valid_signal_i = torch.empty((0,), dtype=torch.long, device=device)
@ -272,9 +263,7 @@ def neighbor_map_post_process_for_masked_implicit_gemm_1(neighbor_map):
seg[0] = 0 seg[0] = 0
if V > 0: if V > 0:
idxs = (torch.arange(1, V + 1, device=device, dtype=torch.long) * N) - 1 idxs = (torch.arange(1, V + 1, device=device, dtype=torch.long) * N) - 1
seg[1:] = prefix_sum_neighbor_mask[idxs].to(torch.long) seg[1:] = prefix_sum_neighbor_mask[idxs]
else:
pass
return gray_code, sorted_idx, valid_signal_i, valid_signal_o, seg return gray_code, sorted_idx, valid_signal_i, valid_signal_o, seg
@ -295,40 +284,41 @@ def _popcount_int32_tensor(x: torch.Tensor) -> torch.Tensor:
def neighbor_map_post_process_for_masked_implicit_gemm_2( def neighbor_map_post_process_for_masked_implicit_gemm_2(
gray_code: torch.Tensor, # [N], int32-like (non-negative) gray_code: torch.Tensor,
sorted_idx: torch.Tensor, # [N], long (indexing into gray_code) sorted_idx: torch.Tensor,
block_size: int block_size: int
): ):
device = gray_code.device device = gray_code.device
N = gray_code.numel() N = gray_code.numel()
# num of blocks (same as CUDA)
num_blocks = (N + block_size - 1) // block_size num_blocks = (N + block_size - 1) // block_size
# Ensure dtypes
gray_long = gray_code.to(torch.int64) # safer to OR in 64-bit then cast
sorted_idx = sorted_idx.to(torch.long)
# 1) Group gray_code by blocks and compute OR across each block
# pad the last block with zeros if necessary so we can reshape
pad = num_blocks * block_size - N pad = num_blocks * block_size - N
if pad > 0: if pad > 0:
pad_vals = torch.zeros((pad,), dtype=torch.int64, device=device) pad_vals = torch.zeros((pad,), dtype=torch.int32, device=device)
gray_padded = torch.cat([gray_long[sorted_idx], pad_vals], dim=0) gray_padded = torch.cat([gray_code[sorted_idx], pad_vals], dim=0)
else: else:
gray_padded = gray_long[sorted_idx] gray_padded = gray_code[sorted_idx]
# reshape to (num_blocks, block_size) and compute bitwise_or across dim=1 gray_blocks = gray_padded.view(num_blocks, block_size)
gray_blocks = gray_padded.view(num_blocks, block_size) # each row = block entries
# reduce with bitwise_or reduced_code = gray_blocks
reduced_code = gray_blocks[:, 0].clone() while reduced_code.shape[1] > 1:
for i in range(1, block_size): half = reduced_code.shape[1] // 2
reduced_code |= gray_blocks[:, i] remainder = reduced_code.shape[1] % 2
reduced_code = reduced_code.to(torch.int32) # match CUDA int32
left = reduced_code[:, :half * 2:2]
right = reduced_code[:, 1:half * 2:2]
merged = left | right
if remainder:
reduced_code = torch.cat([merged, reduced_code[:, -1:]], dim=1)
else:
reduced_code = merged
reduced_code = reduced_code.squeeze(1)
seglen_counts = _popcount_int32_tensor(reduced_code).to(torch.int32)
# 2) compute seglen (popcount per reduced_code) and seg (prefix sum)
seglen_counts = _popcount_int32_tensor(reduced_code.to(torch.int64)).to(torch.int32) # [num_blocks]
# seg: length num_blocks+1, seg[0]=0, seg[i+1]=cumsum(seglen_counts) up to i
seg = torch.empty((num_blocks + 1,), dtype=torch.int32, device=device) seg = torch.empty((num_blocks + 1,), dtype=torch.int32, device=device)
seg[0] = 0 seg[0] = 0
if num_blocks > 0: if num_blocks > 0:
@ -336,30 +326,20 @@ def neighbor_map_post_process_for_masked_implicit_gemm_2(
total = int(seg[-1].item()) total = int(seg[-1].item())
# 3) scatter — produce valid_kernel_idx as concatenated ascending set-bit positions for each reduced_code row
if total == 0: if total == 0:
valid_kernel_idx = torch.empty((0,), dtype=torch.int32, device=device) return torch.empty((0,), dtype=torch.int32, device=device), seg
return valid_kernel_idx, seg
max_val = int(reduced_code.max().item()) V = int(reduced_code.max().item()).bit_length() if reduced_code.max() > 0 else 0
V = max_val.bit_length() if max_val > 0 else 0
# If you know V externally, pass it instead or set here explicitly.
if V == 0: if V == 0:
# no bits set anywhere return torch.empty((0,), dtype=torch.int32, device=device), seg
valid_kernel_idx = torch.empty((0,), dtype=torch.int32, device=device)
return valid_kernel_idx, seg
# build mask of shape (num_blocks, V): True where bit is set bit_pos = torch.arange(0, V, dtype=torch.int32, device=device)
bit_pos = torch.arange(0, V, dtype=torch.int64, device=device) # [V] shifted = reduced_code.unsqueeze(1) >> bit_pos.unsqueeze(0)
# shifted = reduced_code[:, None] >> bit_pos[None, :] bits = (shifted & 1).to(torch.bool)
shifted = reduced_code.to(torch.int64).unsqueeze(1) >> bit_pos.unsqueeze(0)
bits = (shifted & 1).to(torch.bool) # (num_blocks, V)
positions = bit_pos.unsqueeze(0).expand(num_blocks, V) positions = bit_pos.unsqueeze(0).expand(num_blocks, V)
valid_kernel_idx = positions[bits].to(torch.int32).contiguous()
valid_positions = positions[bits]
valid_kernel_idx = valid_positions.to(torch.int32).contiguous()
return valid_kernel_idx, seg return valid_kernel_idx, seg
@ -425,35 +405,6 @@ def sparse_submanifold_conv3d(feats, coords, shape, weight, bias, neighbor_cache
return out, neighbor return out, neighbor
class Voxel:
    """Container for a set of voxels: grid coordinates plus per-voxel attributes.

    Attributes:
        origin: float32 tensor built from ``origin`` on ``device``.
        voxel_size: scale factor applied to coordinates in ``position``.
        coords: voxel coordinates (presumably ``(N, 3)`` integer grid
            indices -- TODO confirm against callers).
        attrs: per-voxel attribute tensor, indexed column-wise by ``layout``.
        layout: mapping from attribute name to a column index/slice of
            ``attrs``.
        device: the device argument as passed in (may be ``None``).
    """

    def __init__(
        self,
        origin: list,
        voxel_size: float,
        coords: torch.Tensor = None,
        attrs: torch.Tensor = None,
        layout = None,
        device: torch.device = None
    ):
        """Store the voxel data; ``layout`` defaults to a fresh empty dict."""
        if layout is None:
            layout = {}
        self.origin = torch.tensor(origin, dtype=torch.float32, device=device)
        self.voxel_size = voxel_size
        self.coords = coords
        self.attrs = attrs
        self.layout = layout
        self.device = device

    @property
    def position(self):
        """Return ``(coords + 0.5) * voxel_size + origin``.

        Presumably the voxel centres in world space. The origin is moved to
        the device of ``coords`` first, so a voxel built with ``device=None``
        around accelerator-resident coordinates does not fail with a device
        mismatch.
        """
        origin = self.origin.to(self.coords.device)
        return (self.coords + 0.5) * self.voxel_size + origin[None, :]

    def split_attrs(self):
        """Return ``{name: attrs[:, layout[name]]}`` for every name in ``layout``."""
        return {name: self.attrs[:, self.layout[name]] for name in self.layout}
class Mesh: class Mesh:
def __init__(self, def __init__(self,
vertices, vertices,
@ -480,35 +431,3 @@ class Mesh:
def cpu(self): def cpu(self):
return self.to('cpu') return self.to('cpu')
class MeshWithVoxel(Mesh, Voxel):
    """A mesh (vertices/faces) bundled with the voxel attributes that go with it.

    Neither base-class ``__init__`` is invoked; every attribute is assigned
    directly here.
    """

    def __init__(self,
            vertices: torch.Tensor,
            faces: torch.Tensor,
            origin: list,
            voxel_size: float,
            coords: torch.Tensor,
            attrs: torch.Tensor,
            voxel_shape: torch.Size,
            layout: Dict = None,
        ):
        """Store mesh and voxel data.

        ``vertices`` is cast to float and ``faces`` to int. ``layout``
        defaults to a new empty dict per instance (a ``{}`` default in the
        signature would be one dict shared by every instance).
        """
        self.vertices = vertices.float()
        self.faces = faces.int()
        # NOTE(review): self.device is read before this class assigns it;
        # presumably Mesh exposes it as a property -- confirm.
        self.origin = torch.tensor(origin, dtype=torch.float32, device=self.device)
        self.voxel_size = voxel_size
        self.coords = coords
        self.attrs = attrs
        self.voxel_shape = voxel_shape
        self.layout = {} if layout is None else layout

    def to(self, device, non_blocking=False):
        """Return a new ``MeshWithVoxel`` with all tensors moved to ``device``.

        The origin is round-tripped through a Python list and rebuilt by the
        constructor; ``voxel_size``, ``voxel_shape`` and ``layout`` are passed
        through unchanged (the layout dict is shared, not copied).
        """
        return MeshWithVoxel(
            self.vertices.to(device, non_blocking=non_blocking),
            self.faces.to(device, non_blocking=non_blocking),
            self.origin.tolist(),
            self.voxel_size,
            self.coords.to(device, non_blocking=non_blocking),
            self.attrs.to(device, non_blocking=non_blocking),
            self.voxel_shape,
            self.layout,
        )

View File

@ -7,7 +7,7 @@ import torch.nn.functional as F
from fractions import Fraction from fractions import Fraction
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Any, Dict, Optional, overload, Union, Tuple from typing import List, Any, Dict, Optional, overload, Union, Tuple
from comfy.ldm.trellis2.cumesh import TorchHashMap, Mesh, MeshWithVoxel, sparse_submanifold_conv3d from comfy.ldm.trellis2.cumesh import TorchHashMap, Mesh, sparse_submanifold_conv3d
def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor: def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor:
@ -210,6 +210,8 @@ class SparseResBlockC2S3d(nn.Module):
def forward(self, x, subdiv = None): def forward(self, x, subdiv = None):
if self.pred_subdiv: if self.pred_subdiv:
dtype = next(self.to_subdiv.parameters()).dtype
x = x.to(dtype)
subdiv = self.to_subdiv(x) subdiv = self.to_subdiv(x)
norm1 = self.norm1.to(torch.float32) norm1 = self.norm1.to(torch.float32)
norm2 = self.norm2.to(torch.float32) norm2 = self.norm2.to(torch.float32)
@ -987,114 +989,7 @@ def convert_module_to_f16(l):
for p in l.parameters(): for p in l.parameters():
p.data = p.data.half() p.data = p.data.half()
class SparseUnetVaeEncoder(nn.Module):
    """
    Encoder of the sparse U-Net VAE.

    An input projection feeds a sequence of stages; each stage holds
    ``num_blocks[i]`` blocks of type ``block_type[i]`` and, for every stage
    but the last, one ``down_block_type[i]`` block mapping
    ``model_channels[i]`` to ``model_channels[i + 1]``. The final projection
    emits ``2 * latent_channels`` features that are split into mean and
    log-variance.
    """

    def __init__(
        self,
        in_channels: int,
        model_channels: List[int],
        latent_channels: int,
        num_blocks: List[int],
        block_type: List[str],
        down_block_type: List[str],
        block_args: List[Dict[str, Any]],
        use_fp16: bool = False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.num_blocks = num_blocks
        self.dtype = torch.float16 if use_fp16 else torch.float32

        self.input_layer = SparseLinear(in_channels, model_channels[0])
        self.to_latent = SparseLinear(model_channels[-1], 2 * latent_channels)

        self.blocks = nn.ModuleList([])
        final_stage = len(num_blocks) - 1
        for stage_idx, depth in enumerate(num_blocks):
            stage = nn.ModuleList([])
            self.blocks.append(stage)
            width = model_channels[stage_idx]
            for _ in range(depth):
                # Block classes are resolved by name from this module's globals.
                stage.append(
                    globals()[block_type[stage_idx]](
                        width,
                        **block_args[stage_idx],
                    )
                )
            if stage_idx < final_stage:
                # Every stage except the last ends with a channel-changing block.
                stage.append(
                    globals()[down_block_type[stage_idx]](
                        width,
                        model_channels[stage_idx + 1],
                        **block_args[stage_idx],
                    )
                )

    @property
    def device(self) -> torch.device:
        """Device of the first parameter of this module."""
        return next(self.parameters()).device

    def forward(self, x: SparseTensor, sample_posterior=False, return_raw=False):
        """Encode ``x`` to a latent sparse tensor.

        Returns the latent ``z`` (the mean, or a reparameterised sample when
        ``sample_posterior`` is true); with ``return_raw`` the tuple
        ``(z, mean, logvar)`` is returned instead.
        """
        h = self.input_layer(x).type(self.dtype)
        for stage in self.blocks:
            for block in stage:
                h = block(h)
        # Back to the input dtype before the parameter-free layer norm.
        h = h.type(x.dtype)
        h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
        h = self.to_latent(h)

        mean, logvar = h.feats.chunk(2, dim=-1)
        z = mean
        if sample_posterior:
            # Reparameterisation: mean + std * eps, eps ~ N(0, 1).
            std = torch.exp(0.5 * logvar)
            z = mean + std * torch.randn_like(std)
        z = h.replace(z)

        if return_raw:
            return z, mean, logvar
        return z
class FlexiDualGridVaeEncoder(SparseUnetVaeEncoder):
    """Sparse U-Net VAE encoder with a fixed 6-channel input.

    The input is the concatenation of ``vertices.feats`` and
    ``intersected.feats`` (presumably 3 channels each -- TODO confirm).
    """

    def __init__(
        self,
        model_channels: List[int],
        latent_channels: int,
        num_blocks: List[int],
        block_type: List[str],
        down_block_type: List[str],
        block_args: List[Dict[str, Any]],
        use_fp16: bool = False,
    ):
        super().__init__(
            in_channels=6,
            model_channels=model_channels,
            latent_channels=latent_channels,
            num_blocks=num_blocks,
            block_type=block_type,
            down_block_type=down_block_type,
            block_args=block_args,
            use_fp16=use_fp16,
        )

    def forward(self, vertices: SparseTensor, intersected: SparseTensor, sample_posterior=False, return_raw=False):
        """Shift both feature sets by -0.5, concatenate along dim 1, and encode."""
        vertex_feats = vertices.feats - 0.5
        intersect_feats = intersected.feats.float() - 0.5
        x = vertices.replace(torch.cat([vertex_feats, intersect_feats], dim=1))
        return super().forward(x, sample_posterior, return_raw)
class SparseUnetVaeDecoder(nn.Module): class SparseUnetVaeDecoder(nn.Module):
"""
Sparse Swin Transformer Unet VAE model.
"""
def __init__( def __init__(
self, self,
out_channels: int, out_channels: int,
@ -1218,10 +1113,10 @@ class FlexiDualGridVaeDecoder(SparseUnetVaeDecoder):
N = coords.shape[0] N = coords.shape[0]
# compute flat keys for all coords (prepend batch 0 same as original code) # compute flat keys for all coords (prepend batch 0 same as original code)
b = torch.zeros((N,), dtype=torch.long, device=device) b = torch.zeros((N,), dtype=torch.long, device=device)
x, y, z = coords[:, 0].long(), coords[:, 1].long(), coords[:, 2].long() x, y, z = coords[:, 0].to(torch.int32), coords[:, 1].to(torch.int32), coords[:, 2].to(torch.int32)
W, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item()) W, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item())
flat_keys = b * (W * H * D) + x * (H * D) + y * D + z flat_keys = b * (W * H * D) + x * (H * D) + y * D + z
values = torch.arange(N, dtype=torch.long, device=device) values = torch.arange(N, dtype=torch.int32, device=device)
DEFAULT_VAL = 0xffffffff # sentinel used in original code DEFAULT_VAL = 0xffffffff # sentinel used in original code
return TorchHashMap(flat_keys, values, DEFAULT_VAL) return TorchHashMap(flat_keys, values, DEFAULT_VAL)
@ -1295,13 +1190,12 @@ def flexible_dual_grid_to_mesh(
# Extract mesh # Extract mesh
N = dual_vertices.shape[0] N = dual_vertices.shape[0]
mesh_vertices = (coords.float() + dual_vertices) / (2 * N) - 0.5
if hashmap_builder is None: if hashmap_builder is None:
# build local TorchHashMap # build local TorchHashMap
device = coords.device device = coords.device
b = torch.zeros((N,), dtype=torch.long, device=device) b = torch.zeros((N,), dtype=torch.long, device=device)
x, y, z = coords[:, 0].long(), coords[:, 1].long(), coords[:, 2].long() x, y, z = coords[:, 0].to(torch.int32), coords[:, 1].to(torch.int32), coords[:, 2].to(torch.int32)
W, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item()) W, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item())
flat_keys = b * (W * H * D) + x * (H * D) + y * D + z flat_keys = b * (W * H * D) + x * (H * D) + y * D + z
values = torch.arange(N, dtype=torch.long, device=device) values = torch.arange(N, dtype=torch.long, device=device)
@ -1316,9 +1210,9 @@ def flexible_dual_grid_to_mesh(
M = connected_voxel.shape[0] M = connected_voxel.shape[0]
# flatten connected voxel coords and lookup # flatten connected voxel coords and lookup
conn_flat_b = torch.zeros((M * 4,), dtype=torch.long, device=coords.device) conn_flat_b = torch.zeros((M * 4,), dtype=torch.long, device=coords.device)
conn_x = connected_voxel.reshape(-1, 3)[:, 0].long() conn_x = connected_voxel.reshape(-1, 3)[:, 0].to(torch.int32)
conn_y = connected_voxel.reshape(-1, 3)[:, 1].long() conn_y = connected_voxel.reshape(-1, 3)[:, 1].to(torch.int32)
conn_z = connected_voxel.reshape(-1, 3)[:, 2].long() conn_z = connected_voxel.reshape(-1, 3)[:, 2].to(torch.int32)
W, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item()) W, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item())
conn_flat = conn_flat_b * (W * H * D) + conn_x * (H * D) + conn_y * D + conn_z conn_flat = conn_flat_b * (W * H * D) + conn_x * (H * D) + conn_y * D + conn_z
@ -1526,17 +1420,18 @@ class Vae(nn.Module):
channels=[512, 128, 32], channels=[512, 128, 32],
) )
@torch.no_grad()
def decode_shape_slat(self, slat, resolution: int): def decode_shape_slat(self, slat, resolution: int):
self.shape_dec.set_resolution(resolution) self.shape_dec.set_resolution(resolution)
device = comfy.model_management.get_torch_device()
self.shape_dec = self.shape_dec.to(device)
return self.shape_dec(slat, return_subs=True) return self.shape_dec(slat, return_subs=True)
@torch.no_grad()
def decode_tex_slat(self, slat, subs): def decode_tex_slat(self, slat, subs):
if self.txt_dec is None: if self.txt_dec is None:
raise ValueError("Checkpoint doesn't include texture model") raise ValueError("Checkpoint doesn't include texture model")
return self.txt_dec(slat, guide_subs=subs) * 0.5 + 0.5 return self.txt_dec(slat, guide_subs=subs) * 0.5 + 0.5
# shouldn't be called (placeholder)
@torch.no_grad() @torch.no_grad()
def decode( def decode(
self, self,
@ -1546,17 +1441,4 @@ class Vae(nn.Module):
): ):
meshes, subs = self.decode_shape_slat(shape_slat, resolution) meshes, subs = self.decode_shape_slat(shape_slat, resolution)
tex_voxels = self.decode_tex_slat(tex_slat, subs) tex_voxels = self.decode_tex_slat(tex_slat, subs)
out_mesh = [] return tex_voxels
for m, v in zip(meshes, tex_voxels):
out_mesh.append(
MeshWithVoxel(
m.vertices, m.faces,
origin = [-0.5, -0.5, -0.5],
voxel_size = 1 / resolution,
coords = v.coords[:, 1:],
attrs = v.feats,
voxel_shape = torch.Size([*v.shape, *v.spatial_shape]),
layout=self.pbr_attr_layout
)
)
return out_mesh

View File

@ -1,7 +1,7 @@
from typing_extensions import override from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO, Types from comfy_api.latest import ComfyExtension, IO, Types
from comfy.ldm.trellis2.vae import SparseTensor from comfy.ldm.trellis2.vae import SparseTensor
from comfy.utils import ProgressBar from comfy.utils import ProgressBar, lanczos
import torch.nn.functional as TF import torch.nn.functional as TF
import comfy.model_management import comfy.model_management
from PIL import Image from PIL import Image
@ -102,9 +102,7 @@ def run_conditioning(model, image, mask, include_1024 = True, background_color =
cropped_img = smart_crop_square(img_t, mask_t, bg_color=bg_rgb) cropped_img = smart_crop_square(img_t, mask_t, bg_color=bg_rgb)
def prepare_tensor(img, size): def prepare_tensor(img, size):
resized = torch.nn.functional.interpolate( resized = lanczos(img.unsqueeze(0), size, size)
img.unsqueeze(0), size=(size, size), mode='bicubic', align_corners=False
)
return (resized - dino_mean.to(torch_device)) / dino_std.to(torch_device) return (resized - dino_mean.to(torch_device)) / dino_std.to(torch_device)
model_internal.image_size = 512 model_internal.image_size = 512
@ -148,10 +146,16 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
@classmethod @classmethod
def execute(cls, samples, structure_output, vae, resolution): def execute(cls, samples, structure_output, vae, resolution):
patcher = vae.patcher
device = comfy.model_management.get_torch_device()
comfy.model_management.load_model_gpu(patcher)
vae = vae.first_stage_model vae = vae.first_stage_model
decoded = structure_output.data.unsqueeze(1) decoded = structure_output.data.unsqueeze(1)
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
samples = samples["samples"] samples = samples["samples"]
samples = samples.squeeze(-1).transpose(1, 2).to(device)
std = shape_slat_normalization["std"].to(samples) std = shape_slat_normalization["std"].to(samples)
mean = shape_slat_normalization["mean"].to(samples) mean = shape_slat_normalization["mean"].to(samples)
samples = SparseTensor(feats = samples, coords=coords) samples = SparseTensor(feats = samples, coords=coords)
@ -179,10 +183,16 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
@classmethod @classmethod
def execute(cls, samples, structure_output, vae, shape_subs): def execute(cls, samples, structure_output, vae, shape_subs):
patcher = vae.patcher
device = comfy.model_management.get_torch_device()
comfy.model_management.load_model_gpu(patcher)
vae = vae.first_stage_model vae = vae.first_stage_model
decoded = structure_output.data.unsqueeze(1) decoded = structure_output.data.unsqueeze(1)
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
samples = samples["samples"] samples = samples["samples"]
samples = samples.squeeze(-1).transpose(1, 2).to(device)
std = tex_slat_normalization["std"].to(samples) std = tex_slat_normalization["std"].to(samples)
mean = tex_slat_normalization["mean"].to(samples) mean = tex_slat_normalization["mean"].to(samples)
samples = SparseTensor(feats = samples, coords=coords) samples = SparseTensor(feats = samples, coords=coords)