vae shape decode fixes

This commit is contained in:
Yousef Rafat 2026-02-25 04:26:33 +02:00
parent a2c8a7aab5
commit f31c2e1d1d
4 changed files with 35 additions and 16 deletions

View File

@ -802,7 +802,7 @@ class Trellis2(nn.Module):
# may remove the else if texture doesn't require special handling
batched_coords = coords
feats_flat = x
x = SparseTensor(feats=feats_flat, coords=batched_coords)
x = SparseTensor(feats=feats_flat, coords=batched_coords.to(torch.int32))
if mode == "shape_generation":
# TODO

View File

@ -1,11 +1,12 @@
import math
import torch
import torch.nn as nn
from typing import List, Any, Dict, Optional, overload, Union, Tuple
from fractions import Fraction
import torch.nn.functional as F
from dataclasses import dataclass
import numpy as np
import torch.nn as nn
import comfy.model_management
import torch.nn.functional as F
from fractions import Fraction
from dataclasses import dataclass
from typing import List, Any, Dict, Optional, overload, Union, Tuple
from comfy.ldm.trellis2.cumesh import TorchHashMap, Mesh, MeshWithVoxel, sparse_submanifold_conv3d
@ -55,11 +56,12 @@ def sparse_conv3d_forward(self, x):
Co, Kd, Kh, Kw, Ci = self.weight.shape
neighbor_cache_key = f'SubMConv3d_neighbor_cache_{Kw}x{Kh}x{Kd}_dilation{self.dilation}'
neighbor_cache = x.get_spatial_cache(neighbor_cache_key)
x = x.to(self.weight.dtype).to(self.weight.device)
out, neighbor_cache_ = sparse_submanifold_conv3d(
x.feats,
x.coords,
torch.Size([*x.shape, *x.spatial_shape]),
x.spatial_shape,
self.weight,
self.bias,
neighbor_cache,
@ -100,7 +102,8 @@ class SparseConvNeXtBlock3d(nn.Module):
def _forward(self, x):
h = self.conv(x)
h = h.replace(self.norm(h.feats))
norm = self.norm.to(torch.float32)
h = h.replace(norm(h.feats))
h = h.replace(self.mlp(h.feats))
return h + x
@ -208,13 +211,15 @@ class SparseResBlockC2S3d(nn.Module):
def forward(self, x, subdiv = None):
if self.pred_subdiv:
subdiv = self.to_subdiv(x)
h = x.replace(self.norm1(x.feats))
norm1 = self.norm1.to(torch.float32)
norm2 = self.norm2.to(torch.float32)
h = x.replace(norm1(x.feats))
h = h.replace(F.silu(h.feats))
h = self.conv1(h)
subdiv_binarized = subdiv.replace(subdiv.feats > 0) if subdiv is not None else None
h = self.updown(h, subdiv_binarized)
x = self.updown(x, subdiv_binarized)
h = h.replace(self.norm2(h.feats))
h = h.replace(norm2(h.feats))
h = h.replace(F.silu(h.feats))
h = self.conv2(h)
h = h + self.skip_connection(x)
@ -1139,6 +1144,9 @@ class SparseUnetVaeDecoder(nn.Module):
def forward(self, x: SparseTensor, guide_subs: Optional[List[SparseTensor]] = None, return_subs: bool = False) -> SparseTensor:
dtype = next(self.from_latent.parameters()).dtype
device = next(self.from_latent.parameters()).device
x.feats = x.feats.to(dtype).to(device)
h = self.from_latent(x)
h = h.type(self.dtype)
subs = []
@ -1152,7 +1160,7 @@ class SparseUnetVaeDecoder(nn.Module):
h = block(h, subdiv=guide_subs[i] if guide_subs is not None else None)
else:
h = block(h)
h = h.type(x.dtype)
h = h.type(x.feats.dtype)
h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
h = self.output_layer(h)
if return_subs:
@ -1520,6 +1528,8 @@ class Vae(nn.Module):
def decode_shape_slat(self, slat, resolution: int):
# Decode `slat` with the shape decoder (`self.shape_dec`) and return whatever the
# decoder returns when called with return_subs=True.
# NOTE(review): `slat` is presumably a SparseTensor — the decode node wraps its
# samples in one before calling this; confirm against other callers.
# `resolution` is handed to the decoder through set_resolution() before decoding.
self.shape_dec.set_resolution(resolution)
# Look up the torch device via comfy.model_management and move the decoder onto
# it before running it.
device = comfy.model_management.get_torch_device()
self.shape_dec = self.shape_dec.to(device)
# return_subs=True: the shape decode node unpacks this result as `mesh, subs`.
return self.shape_dec(slat, return_subs=True)
def decode_tex_slat(self, slat, subs):

View File

@ -513,8 +513,8 @@ class VAE:
init_txt_model = True
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
# TODO
self.memory_used_decode = lambda shape, dtype: (6500 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (6500 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (2500 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (2500 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
self.first_stage_model = comfy.ldm.trellis2.vae.Vae(init_txt_model)
elif "decoder.conv_in.weight" in sd:
if sd['decoder.conv_in.weight'].shape[1] == 64:

View File

@ -1,8 +1,9 @@
from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO, Types
from comfy.ldm.trellis2.vae import SparseTensor
from comfy.utils import ProgressBar
import torch.nn.functional as TF
import comfy.model_management
from comfy.utils import ProgressBar
from PIL import Image
import numpy as np
import torch
@ -135,6 +136,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
category="latent/3d",
inputs=[
IO.Latent.Input("samples"),
IO.Voxel.Input("structure_output"),
IO.Vae.Input("vae"),
IO.Int.Input("resolution", tooltip="Shape Generation Resolution"),
],
@ -145,11 +147,14 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
)
@classmethod
def execute(cls, samples, vae, resolution):
def execute(cls, samples, structure_output, vae, resolution):
vae = vae.first_stage_model
decoded = structure_output.data.unsqueeze(1)
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
samples = samples["samples"]
std = shape_slat_normalization["std"].to(samples)
mean = shape_slat_normalization["mean"].to(samples)
samples = SparseTensor(feats = samples, coords=coords)
samples = samples * std + mean
mesh, subs = vae.decode_shape_slat(samples, resolution)
@ -163,6 +168,7 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
category="latent/3d",
inputs=[
IO.Latent.Input("samples"),
IO.Voxel.Input("structure_output"),
IO.Vae.Input("vae"),
IO.AnyType.Input("shape_subs"),
],
@ -172,11 +178,14 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
)
@classmethod
def execute(cls, samples, vae, shape_subs):
def execute(cls, samples, structure_output, vae, shape_subs):
vae = vae.first_stage_model
decoded = structure_output.data.unsqueeze(1)
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
samples = samples["samples"]
std = tex_slat_normalization["std"].to(samples)
mean = tex_slat_normalization["mean"].to(samples)
samples = SparseTensor(feats = samples, coords=coords)
samples = samples * std + mean
mesh = vae.decode_tex_slat(samples, shape_subs)