From f31c2e1d1d7359f05995804f636444313b794068 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Wed, 25 Feb 2026 04:26:33 +0200 Subject: [PATCH] vae shape decode fixes --- comfy/ldm/trellis2/model.py | 2 +- comfy/ldm/trellis2/vae.py | 30 ++++++++++++++++++++---------- comfy/sd.py | 4 ++-- comfy_extras/nodes_trellis2.py | 15 ++++++++++++--- 4 files changed, 35 insertions(+), 16 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index fb5276f94..45740faea 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -802,7 +802,7 @@ class Trellis2(nn.Module): # may remove the else if texture doesn't require special handling batched_coords = coords feats_flat = x - x = SparseTensor(feats=feats_flat, coords=batched_coords) + x = SparseTensor(feats=feats_flat, coords=batched_coords.to(torch.int32)) if mode == "shape_generation": # TODO diff --git a/comfy/ldm/trellis2/vae.py b/comfy/ldm/trellis2/vae.py index 36e2f3df5..0b1975092 100644 --- a/comfy/ldm/trellis2/vae.py +++ b/comfy/ldm/trellis2/vae.py @@ -1,11 +1,12 @@ import math import torch -import torch.nn as nn -from typing import List, Any, Dict, Optional, overload, Union, Tuple -from fractions import Fraction -import torch.nn.functional as F -from dataclasses import dataclass import numpy as np +import torch.nn as nn +import comfy.model_management +import torch.nn.functional as F +from fractions import Fraction +from dataclasses import dataclass +from typing import List, Any, Dict, Optional, overload, Union, Tuple from comfy.ldm.trellis2.cumesh import TorchHashMap, Mesh, MeshWithVoxel, sparse_submanifold_conv3d @@ -55,11 +56,12 @@ def sparse_conv3d_forward(self, x): Co, Kd, Kh, Kw, Ci = self.weight.shape neighbor_cache_key = f'SubMConv3d_neighbor_cache_{Kw}x{Kh}x{Kd}_dilation{self.dilation}' neighbor_cache = x.get_spatial_cache(neighbor_cache_key) + x = x.to(self.weight.dtype).to(self.weight.device) out, neighbor_cache_ = sparse_submanifold_conv3d( x.feats, x.coords, - torch.Size([*x.shape, *x.spatial_shape]), + x.spatial_shape, self.weight, self.bias, neighbor_cache, @@ -100,7 +102,8 @@ class SparseConvNeXtBlock3d(nn.Module): def _forward(self, x): h = self.conv(x) - h = h.replace(self.norm(h.feats)) + norm = self.norm.to(torch.float32) + h = h.replace(norm(h.feats)) h = h.replace(self.mlp(h.feats)) return h + x @@ -208,13 +211,15 @@ class SparseResBlockC2S3d(nn.Module): def forward(self, x, subdiv = None): if self.pred_subdiv: subdiv = self.to_subdiv(x) - h = x.replace(self.norm1(x.feats)) + norm1 = self.norm1.to(torch.float32) + norm2 = self.norm2.to(torch.float32) + h = x.replace(norm1(x.feats)) h = h.replace(F.silu(h.feats)) h = self.conv1(h) subdiv_binarized = subdiv.replace(subdiv.feats > 0) if subdiv is not None else None h = self.updown(h, subdiv_binarized) x = self.updown(x, subdiv_binarized) - h = h.replace(self.norm2(h.feats)) + h = h.replace(norm2(h.feats)) h = h.replace(F.silu(h.feats)) h = self.conv2(h) h = h + self.skip_connection(x) @@ -1139,6 +1144,9 @@ class SparseUnetVaeDecoder(nn.Module): def forward(self, x: SparseTensor, guide_subs: Optional[List[SparseTensor]] = None, return_subs: bool = False) -> SparseTensor: + dtype = next(self.from_latent.parameters()).dtype + device = next(self.from_latent.parameters()).device + x.feats = x.feats.to(dtype).to(device) h = self.from_latent(x) h = h.type(self.dtype) subs = [] @@ -1152,7 +1160,7 @@ class SparseUnetVaeDecoder(nn.Module): h = block(h, subdiv=guide_subs[i] if guide_subs is not None else None) else: h = block(h) - h = h.type(x.dtype) + h = h.type(x.feats.dtype) h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) h = self.output_layer(h) if return_subs: @@ -1520,6 +1528,8 @@ class Vae(nn.Module): def decode_shape_slat(self, slat, resolution: int): self.shape_dec.set_resolution(resolution) + device = comfy.model_management.get_torch_device() + self.shape_dec = self.shape_dec.to(device) return self.shape_dec(slat, return_subs=True) def decode_tex_slat(self, slat, subs): diff --git a/comfy/sd.py b/comfy/sd.py index fecd16c88..f9898b0de 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -513,8 +513,8 @@ class VAE: init_txt_model = True self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] # TODO - self.memory_used_decode = lambda shape, dtype: (6500 * shape[2] * shape[3]) * model_management.dtype_size(dtype) - self.memory_used_encode = lambda shape, dtype: (6500 * shape[2] * shape[3]) * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (2500 * shape[2] * shape[3]) * model_management.dtype_size(dtype) + self.memory_used_encode = lambda shape, dtype: (2500 * shape[2] * shape[3]) * model_management.dtype_size(dtype) self.first_stage_model = comfy.ldm.trellis2.vae.Vae(init_txt_model) elif "decoder.conv_in.weight" in sd: if sd['decoder.conv_in.weight'].shape[1] == 64: diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index f40ff5161..96510e916 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -1,8 +1,9 @@ from typing_extensions import override from comfy_api.latest import ComfyExtension, IO, Types +from comfy.ldm.trellis2.vae import SparseTensor +from comfy.utils import ProgressBar import torch.nn.functional as TF import comfy.model_management -from comfy.utils import ProgressBar from PIL import Image import numpy as np import torch @@ -135,6 +136,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): category="latent/3d", inputs=[ IO.Latent.Input("samples"), + IO.Voxel.Input("structure_output"), IO.Vae.Input("vae"), IO.Int.Input("resolution", tooltip="Shape Generation Resolution"), ], @@ -145,11 +147,14 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): ) @classmethod - def execute(cls, samples, vae, resolution): + def execute(cls, samples, structure_output, vae, resolution): vae = vae.first_stage_model + decoded = structure_output.data.unsqueeze(1) + coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() samples = samples["samples"] std = shape_slat_normalization["std"].to(samples) mean = shape_slat_normalization["mean"].to(samples) + samples = SparseTensor(feats = samples, coords=coords) samples = samples * std + mean mesh, subs = vae.decode_shape_slat(samples, resolution) @@ -163,6 +168,7 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): category="latent/3d", inputs=[ IO.Latent.Input("samples"), + IO.Voxel.Input("structure_output"), IO.Vae.Input("vae"), IO.AnyType.Input("shape_subs"), ], @@ -172,11 +178,14 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): ) @classmethod - def execute(cls, samples, vae, shape_subs): + def execute(cls, samples, structure_output, vae, shape_subs): vae = vae.first_stage_model + decoded = structure_output.data.unsqueeze(1) + coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() samples = samples["samples"] std = tex_slat_normalization["std"].to(samples) mean = tex_slat_normalization["mean"].to(samples) + samples = SparseTensor(feats = samples, coords=coords) samples = samples * std + mean mesh = vae.decode_tex_slat(samples, shape_subs)