Cleanup VAE code some

This commit is contained in:
kijai 2026-06-27 00:43:32 +03:00
parent 41f5f4b2c0
commit 36e8f62dd4
3 changed files with 68 additions and 63 deletions

View File

@ -249,6 +249,53 @@ class TripoSplat(LatentFormat):
def process_out(self, latent):
    """Return ``latent`` unchanged; this format applies no output transform."""
    result = latent
    return result
class Trellis2(LatentFormat):
    """Latent format declaring 32 latent channels for Trellis2 latents."""

    latent_channels = 32
class Trellis2SLAT(Trellis2):
    """Sparse structured latent (SLAT) format: per-token feats ``[N, 32]``.

    ``process_in`` normalizes feats as ``(latent - mean) / std`` and
    ``process_out`` applies the inverse, ``latent * std + mean``. The base
    class leaves both statistics as ``None``; subclasses carry the stats of
    their own latent space.
    """

    latents_mean = None
    latents_std = None

    def _stats_for(self, latent):
        """Return ``(mean, std)`` moved to ``latent``'s device and dtype."""
        device, dtype = latent.device, latent.dtype
        mean = self.latents_mean.to(device, dtype)
        std = self.latents_std.to(device, dtype)
        return mean, std

    def process_in(self, latent):
        """Normalize ``latent``: subtract the mean, then divide by the std."""
        mean, std = self._stats_for(latent)
        return (latent - mean) / std

    def process_out(self, latent):
        """Denormalize ``latent``: multiply by the std, then add the mean."""
        mean, std = self._stats_for(latent)
        return latent * std + mean
class Trellis2ShapeSLAT(Trellis2SLAT):
    """Mean/std statistics for the Trellis2 shape SLAT space.

    Each table holds 32 values and is stored with a leading singleton
    dimension (shape ``[1, 32]``).
    """

    latents_mean = torch.tensor([
        0.781296, 0.018091, -0.495192, -0.558457, 1.060530, 0.093252, 1.518149, -0.933218,
        -0.732996, 2.604095, -0.118341, -2.143904, 0.495076, -2.179512, -2.130751, -0.996944,
        0.261421, -2.217463, 1.260067, -0.150213, 3.790713, 1.481266, -1.046058, -1.523667,
        -0.059621, 2.220780, 1.621212, 0.877230, 0.567247, -3.175944, -3.186688, 1.578665,
    ]).unsqueeze(0)

    latents_std = torch.tensor([
        5.972266, 4.706852, 5.445010, 5.209927, 5.320220, 4.547237, 5.020802, 5.444004,
        5.226681, 5.683095, 4.831436, 5.286469, 5.652043, 5.367606, 5.525084, 4.730578,
        4.805265, 5.124013, 5.530808, 5.619001, 5.103930, 5.417670, 5.269677, 5.547194,
        5.634698, 5.235274, 6.110351, 5.511298, 6.237273, 4.879207, 5.347008, 5.405691,
    ]).unsqueeze(0)
class Trellis2TexSLAT(Trellis2SLAT):
    """Mean/std statistics for the Trellis2 texture SLAT space.

    Each table holds 32 values and is stored with a leading singleton
    dimension (shape ``[1, 32]``).
    """

    latents_mean = torch.tensor([
        3.501659, 2.212398, 2.226094, 0.251093, -0.026248, -0.687364, 0.439898, -0.928075,
        0.029398, -0.339596, -0.869527, 1.038479, -0.972385, 0.126042, -1.129303, 0.455149,
        -1.209521, 2.069067, 0.544735, 2.569128, -0.323407, 2.293000, -1.925608, -1.217717,
        1.213905, 0.971588, -0.023631, 0.106750, 2.021786, 0.250524, -0.662387, -0.768862,
    ]).unsqueeze(0)

    latents_std = torch.tensor([
        2.665652, 2.743913, 2.765121, 2.595319, 3.037293, 2.291316, 2.144656, 2.911822,
        2.969419, 2.501689, 2.154811, 3.163343, 2.621215, 2.381943, 3.186697, 3.021588,
        2.295916, 3.234985, 3.233086, 2.260140, 2.874801, 2.810596, 3.292720, 2.674999,
        2.680878, 2.372054, 2.451546, 2.353556, 2.995195, 2.379849, 2.786195, 2.775190,
    ]).unsqueeze(0)
class Mochi(LatentFormat):
latent_channels = 12
latent_dimensions = 3
@ -770,10 +817,6 @@ class Hunyuan3Dv2_1(LatentFormat):
latent_channels = 64
latent_dimensions = 1
class Trellis2(LatentFormat):
    """Latent format declaring 32 latent channels for Trellis2 latents."""

    latent_channels = 32
class Hunyuan3Dv2mini(LatentFormat):
latent_channels = 64
latent_dimensions = 1

View File

@ -1114,6 +1114,15 @@ class VAE:
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
return pixel_samples
def prepare_decode(self, sample_shape, memory_required=None):
    """Load this VAE for a decode that does not go through ``decode()``.

    Args:
        sample_shape: Shape of the latent about to be decoded. Only used to
            estimate memory when ``memory_required`` is not given.
        memory_required: Optional memory estimate. When ``None`` it is
            computed with ``memory_used_decode(sample_shape, vae_dtype)``.

    Returns:
        How many times the (integer, at least 1) memory estimate fits into
        the memory reported free after loading; never less than 1.
    """
    estimate = (
        self.memory_used_decode(sample_shape, self.vae_dtype)
        if memory_required is None
        else memory_required
    )
    # Clamp to a positive integer so the division below is always defined.
    estimate = max(1, int(estimate))
    model_management.load_models_gpu(
        [self.patcher],
        memory_required=estimate,
        force_full_load=self.disable_offload,
    )
    available = self.patcher.get_free_memory(self.device)
    return max(1, int(available / estimate))
def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
self.throw_exception_if_invalid()
memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile

View File

@ -5,6 +5,7 @@ from comfy.ldm.trellis2.model import build_proj_transform_matrix, _project_point
from comfy.ldm.trellis2.naf.model import build_naf_from_state_dict
from comfy_extras.nodes_mesh_postprocess import pack_variable_mesh_batch
import comfy.latent_formats
import comfy.model_management
import comfy.utils
import folder_paths
@ -18,57 +19,12 @@ ShapeSubdivides = io.Custom("SHAPE_SUBDIVIDES")
NAFModel = io.Custom("NAF_MODEL")
def prepare_trellis_vae_for_decode(vae, sample_shape):
    """Load ``vae`` for decoding and return how many samples fit in free memory.

    The estimate comes from ``vae.memory_used_decode(sample_shape,
    vae.vae_dtype)``; for 5-D shapes it is further multiplied by the last
    dimension. The result is the number of times that estimate fits into the
    memory reported free after loading, never less than 1.
    """
    estimate = vae.memory_used_decode(sample_shape, vae.vae_dtype)
    if len(sample_shape) == 5:
        # Scale by the fifth dimension (treated as at least 1).
        estimate *= max(1, int(sample_shape[4]))
    # Clamp to a positive integer so the division below is always defined.
    estimate = max(1, int(estimate))

    torch_device = comfy.model_management.get_torch_device()
    # Not every VAE object is assumed to expose disable_offload.
    full_load = getattr(vae, "disable_offload", False)
    comfy.model_management.load_models_gpu(
        [vae.patcher],
        memory_required=estimate,
        force_full_load=full_load,
    )
    available = vae.patcher.get_free_memory(torch_device)
    return max(1, int(available / estimate))
# Mean/std tables (32 values each, stored as shape [1, 32]) used to
# denormalize shape SLAT feats as ``feats * std + mean``.
shape_slat_normalization = {
    "mean": torch.tensor([
        0.781296, 0.018091, -0.495192, -0.558457, 1.060530, 0.093252, 1.518149, -0.933218,
        -0.732996, 2.604095, -0.118341, -2.143904, 0.495076, -2.179512, -2.130751, -0.996944,
        0.261421, -2.217463, 1.260067, -0.150213, 3.790713, 1.481266, -1.046058, -1.523667,
        -0.059621, 2.220780, 1.621212, 0.877230, 0.567247, -3.175944, -3.186688, 1.578665,
    ]).unsqueeze(0),
    "std": torch.tensor([
        5.972266, 4.706852, 5.445010, 5.209927, 5.320220, 4.547237, 5.020802, 5.444004,
        5.226681, 5.683095, 4.831436, 5.286469, 5.652043, 5.367606, 5.525084, 4.730578,
        4.805265, 5.124013, 5.530808, 5.619001, 5.103930, 5.417670, 5.269677, 5.547194,
        5.634698, 5.235274, 6.110351, 5.511298, 6.237273, 4.879207, 5.347008, 5.405691,
    ]).unsqueeze(0),
}
# Mean/std tables (32 values each, stored as shape [1, 32]) used to
# denormalize texture SLAT feats as ``feats * std + mean``.
tex_slat_normalization = {
    "mean": torch.tensor([
        3.501659, 2.212398, 2.226094, 0.251093, -0.026248, -0.687364, 0.439898, -0.928075,
        0.029398, -0.339596, -0.869527, 1.038479, -0.972385, 0.126042, -1.129303, 0.455149,
        -1.209521, 2.069067, 0.544735, 2.569128, -0.323407, 2.293000, -1.925608, -1.217717,
        1.213905, 0.971588, -0.023631, 0.106750, 2.021786, 0.250524, -0.662387, -0.768862,
    ]).unsqueeze(0),
    "std": torch.tensor([
        2.665652, 2.743913, 2.765121, 2.595319, 3.037293, 2.291316, 2.144656, 2.911822,
        2.969419, 2.501689, 2.154811, 3.163343, 2.621215, 2.381943, 3.186697, 3.021588,
        2.295916, 3.234985, 3.233086, 2.260140, 2.874801, 2.810596, 3.292720, 2.674999,
        2.680878, 2.372054, 2.451546, 2.353556, 2.995195, 2.379849, 2.786195, 2.775190,
    ]).unsqueeze(0),
}
# Module-level latent-format instances. Their process_out() is what
# shape_norm and the texture decode path call to denormalize SLAT feats.
shape_slat_format = comfy.latent_formats.Trellis2ShapeSLAT()
tex_slat_format = comfy.latent_formats.Trellis2TexSLAT()
# Denormalize shape SLAT feats and wrap them with their coords in a SparseTensor.
# NOTE(review): this span appears to hold two alternative bodies back to back
# (diff markers and indentation were lost); only one of them can be the live
# body — confirm against the actual file before relying on either.
def shape_norm(shape_latent, coords):
# First body: build the SparseTensor, then apply `* std + mean` using the
# shape_slat_normalization tables moved to shape_latent's device/dtype.
std = shape_slat_normalization["std"].to(shape_latent)
mean = shape_slat_normalization["mean"].to(shape_latent)
samples = SparseTensor(feats = shape_latent, coords=coords)
samples = samples * std + mean
return samples
# Second body: denormalize the raw feats via shape_slat_format.process_out,
# then wrap the result in a SparseTensor.
feats = shape_slat_format.process_out(shape_latent)
return SparseTensor(feats=feats, coords=coords)
def infer_batched_coord_layout(coords):
@ -177,7 +133,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
sample_tensor = samples["samples"]
device = comfy.model_management.get_torch_device()
coords = samples["coords"]
prepare_trellis_vae_for_decode(vae, sample_tensor.shape)
vae.prepare_decode(sample_tensor.shape)
trellis_vae = vae.first_stage_model
coord_counts = samples.get("coord_counts")
@ -243,7 +199,7 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
sample_tensor = samples["samples"]
device = comfy.model_management.get_torch_device()
coords = samples["coords"]
prepare_trellis_vae_for_decode(vae, sample_tensor.shape)
vae.prepare_decode(sample_tensor.shape)
trellis_vae = vae.first_stage_model
coord_counts = samples.get("coord_counts")
model_frame = samples.get("model_frame", "y_up")
@ -252,16 +208,13 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
samples = samples["samples"]
samples, coords = flatten_batched_sparse_latent(samples, coords, coord_counts)
samples = samples.to(device)
std = tex_slat_normalization["std"].to(samples)
mean = tex_slat_normalization["mean"].to(samples)
samples = SparseTensor(feats = samples, coords=coords.to(device))
samples = samples * std + mean
feats = tex_slat_format.process_out(samples)
samples = SparseTensor(feats=feats, coords=coords.to(device))
voxel = trellis_vae.decode_tex_slat(samples.to(vae.vae_dtype), shape_subdivides)
# Keep all decoded channels. The texture VAE emits 6: base_color (0:3),
# metallic (3), roughness (4), alpha (5) — all in [0, 1]. Vertex-color
# consumers (PaintMesh) slice [:3]; BakeTextureFromVoxel uses the full
# PBR set. Older 3-channel checkpoints pass through unchanged.
# consumers (PaintMesh) slice [:3]
color_feats = voxel.feats
voxel_coords = voxel.coords
@ -316,7 +269,7 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode):
resolution = int(resolution)
sample_tensor = samples["samples"]
sample_tensor = sample_tensor[:, :8]
batch_number = prepare_trellis_vae_for_decode(vae, sample_tensor.shape)
batch_number = vae.prepare_decode(sample_tensor.shape)
shape_vae = vae.first_stage_model
load_device = comfy.model_management.get_torch_device()
decoded_batches = []
@ -381,7 +334,7 @@ class Trellis2UpsampleStage(IO.ComfyNode):
@classmethod
def execute(cls, positive, negative, shape_latent, vae, target_resolution, max_tokens):
device = comfy.model_management.get_torch_device()
prepare_trellis_vae_for_decode(vae, shape_latent["samples"].shape)
vae.prepare_decode(shape_latent["samples"].shape)
coord_counts = shape_latent.get("coord_counts")
shape_vae = vae.first_stage_model