Review amendments to this patch (hunk line counts updated accordingly):
  * comfy/sd.py: prepare_decode() now calls throw_exception_if_invalid() first,
    matching decode_tiled(), and documents what it returns.
  * comfy/latent_formats.py: Trellis2SLAT raises NotImplementedError when used
    without subclass stats instead of failing on None.to(...); the duplicated
    mean/std transfer is shared through a private helper.
NOTE(review): the removed prepare_trellis_vae_for_decode() multiplied the memory
estimate by sample_shape[4] for 5-D shapes; prepare_decode() does not. Confirm
that memory_used_decode already accounts for that dimension.
NOTE(review): blank and indented context lines were reconstructed from a
whitespace-collapsed copy, and the post-image index hashes are stale; regenerate
the patch before applying.

diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py
index 1ff9ada9d..e505709e3 100644
--- a/comfy/latent_formats.py
+++ b/comfy/latent_formats.py
@@ -249,6 +249,64 @@ class TripoSplat(LatentFormat):
     def process_out(self, latent):
         return latent
 
+class Trellis2(LatentFormat):
+    latent_channels = 32
+
+class Trellis2SLAT(Trellis2):
+    # Sparse structured latent: per-token feats [N, 32]. process_out denormalizes
+    # the decoded feats (latent * std + mean); subclasses carry each space's stats.
+    latents_mean = None
+    latents_std = None
+
+    def _stats(self, latent):
+        """Return (mean, std) moved to latent's device and dtype."""
+        if self.latents_mean is None or self.latents_std is None:
+            # The base class carries no stats; fail with a clear message.
+            raise NotImplementedError(f"{type(self).__name__} does not define latents_mean/latents_std")
+        mean = self.latents_mean.to(latent.device, latent.dtype)
+        std = self.latents_std.to(latent.device, latent.dtype)
+        return mean, std
+
+    def process_in(self, latent):
+        """Normalize: (latent - mean) / std."""
+        mean, std = self._stats(latent)
+        return (latent - mean) / std
+
+    def process_out(self, latent):
+        """Denormalize: latent * std + mean."""
+        mean, std = self._stats(latent)
+        return latent * std + mean
+
+class Trellis2ShapeSLAT(Trellis2SLAT):
+    # Per-channel stats, shaped [1, 32] so they broadcast over [N, 32] feats.
+    latents_mean = torch.tensor([
+        0.781296, 0.018091, -0.495192, -0.558457, 1.060530, 0.093252, 1.518149, -0.933218,
+        -0.732996, 2.604095, -0.118341, -2.143904, 0.495076, -2.179512, -2.130751, -0.996944,
+        0.261421, -2.217463, 1.260067, -0.150213, 3.790713, 1.481266, -1.046058, -1.523667,
+        -0.059621, 2.220780, 1.621212, 0.877230, 0.567247, -3.175944, -3.186688, 1.578665
+    ])[None]
+    latents_std = torch.tensor([
+        5.972266, 4.706852, 5.445010, 5.209927, 5.320220, 4.547237, 5.020802, 5.444004,
+        5.226681, 5.683095, 4.831436, 5.286469, 5.652043, 5.367606, 5.525084, 4.730578,
+        4.805265, 5.124013, 5.530808, 5.619001, 5.103930, 5.417670, 5.269677, 5.547194,
+        5.634698, 5.235274, 6.110351, 5.511298, 6.237273, 4.879207, 5.347008, 5.405691
+    ])[None]
+
+class Trellis2TexSLAT(Trellis2SLAT):
+    # Per-channel stats, shaped [1, 32] so they broadcast over [N, 32] feats.
+    latents_mean = torch.tensor([
+        3.501659, 2.212398, 2.226094, 0.251093, -0.026248, -0.687364, 0.439898, -0.928075,
+        0.029398, -0.339596, -0.869527, 1.038479, -0.972385, 0.126042, -1.129303, 0.455149,
+        -1.209521, 2.069067, 0.544735, 2.569128, -0.323407, 2.293000, -1.925608, -1.217717,
+        1.213905, 0.971588, -0.023631, 0.106750, 2.021786, 0.250524, -0.662387, -0.768862
+    ])[None]
+    latents_std = torch.tensor([
+        2.665652, 2.743913, 2.765121, 2.595319, 3.037293, 2.291316, 2.144656, 2.911822,
+        2.969419, 2.501689, 2.154811, 3.163343, 2.621215, 2.381943, 3.186697, 3.021588,
+        2.295916, 3.234985, 3.233086, 2.260140, 2.874801, 2.810596, 3.292720, 2.674999,
+        2.680878, 2.372054, 2.451546, 2.353556, 2.995195, 2.379849, 2.786195, 2.775190
+    ])[None]
+
 class Mochi(LatentFormat):
     latent_channels = 12
     latent_dimensions = 3
@@ -770,10 +828,6 @@ class Hunyuan3Dv2_1(LatentFormat):
     latent_channels = 64
     latent_dimensions = 1
 
-class Trellis2(LatentFormat):
-    latent_channels = 32
-
-
 class Hunyuan3Dv2mini(LatentFormat):
     latent_channels = 64
     latent_dimensions = 1
diff --git a/comfy/sd.py b/comfy/sd.py
index c08688f42..cf8d35a7c 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -1114,6 +1114,21 @@ class VAE:
         pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
         return pixel_samples
 
+    def prepare_decode(self, sample_shape, memory_required=None):
+        """Load this VAE for a decode that bypasses decode().
+
+        memory_required defaults to memory_used_decode(sample_shape, vae_dtype)
+        and is clamped to at least 1. Returns how many multiples of it fit in
+        the free memory reported after loading (always >= 1).
+        """
+        self.throw_exception_if_invalid()
+        if memory_required is None:
+            memory_required = self.memory_used_decode(sample_shape, self.vae_dtype)
+        memory_required = max(1, int(memory_required))
+        model_management.load_models_gpu([self.patcher], memory_required=memory_required, force_full_load=self.disable_offload)
+        free_memory = self.patcher.get_free_memory(self.device)
+        return max(1, int(free_memory / memory_required))
+
     def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
         self.throw_exception_if_invalid()
         memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py
index b1bf6807c..bb948831f 100644
--- a/comfy_extras/nodes_trellis2.py
+++ b/comfy_extras/nodes_trellis2.py
@@ -5,6 +5,7 @@ from comfy.ldm.trellis2.model import build_proj_transform_matrix, _project_point
 from comfy.ldm.trellis2.naf.model import build_naf_from_state_dict
 from comfy_extras.nodes_mesh_postprocess import pack_variable_mesh_batch
 
+import comfy.latent_formats
 import comfy.model_management
 import comfy.utils
 import folder_paths
@@ -18,57 +19,12 @@ ShapeSubdivides = io.Custom("SHAPE_SUBDIVIDES")
 NAFModel = io.Custom("NAF_MODEL")
 
 
-def prepare_trellis_vae_for_decode(vae, sample_shape):
-    memory_required = vae.memory_used_decode(sample_shape, vae.vae_dtype)
-    if len(sample_shape) == 5:
-        memory_required *= max(1, int(sample_shape[4]))
-    memory_required = max(1, int(memory_required))
-    device = comfy.model_management.get_torch_device()
-    comfy.model_management.load_models_gpu(
-        [vae.patcher],
-        memory_required=memory_required,
-        force_full_load=getattr(vae, "disable_offload", False),
-    )
-    free_memory = vae.patcher.get_free_memory(device)
-    batch_number = max(1, int(free_memory / memory_required))
-    return batch_number
-
-shape_slat_normalization = {
-    "mean": torch.tensor([
-        0.781296, 0.018091, -0.495192, -0.558457, 1.060530, 0.093252, 1.518149, -0.933218,
-        -0.732996, 2.604095, -0.118341, -2.143904, 0.495076, -2.179512, -2.130751, -0.996944,
-        0.261421, -2.217463, 1.260067, -0.150213, 3.790713, 1.481266, -1.046058, -1.523667,
-        -0.059621, 2.220780, 1.621212, 0.877230, 0.567247, -3.175944, -3.186688, 1.578665
-    ])[None],
-    "std": torch.tensor([
-        5.972266, 4.706852, 5.445010, 5.209927, 5.320220, 4.547237, 5.020802, 5.444004,
-        5.226681, 5.683095, 4.831436, 5.286469, 5.652043, 5.367606, 5.525084, 4.730578,
-        4.805265, 5.124013, 5.530808, 5.619001, 5.103930, 5.417670, 5.269677, 5.547194,
-        5.634698, 5.235274, 6.110351, 5.511298, 6.237273, 4.879207, 5.347008, 5.405691
-    ])[None]
-}
-
-tex_slat_normalization = {
-    "mean": torch.tensor([
-        3.501659, 2.212398, 2.226094, 0.251093, -0.026248, -0.687364, 0.439898, -0.928075,
-        0.029398, -0.339596, -0.869527, 1.038479, -0.972385, 0.126042, -1.129303, 0.455149,
-        -1.209521, 2.069067, 0.544735, 2.569128, -0.323407, 2.293000, -1.925608, -1.217717,
-        1.213905, 0.971588, -0.023631, 0.106750, 2.021786, 0.250524, -0.662387, -0.768862
-    ])[None],
-    "std": torch.tensor([
-        2.665652, 2.743913, 2.765121, 2.595319, 3.037293, 2.291316, 2.144656, 2.911822,
-        2.969419, 2.501689, 2.154811, 3.163343, 2.621215, 2.381943, 3.186697, 3.021588,
-        2.295916, 3.234985, 3.233086, 2.260140, 2.874801, 2.810596, 3.292720, 2.674999,
-        2.680878, 2.372054, 2.451546, 2.353556, 2.995195, 2.379849, 2.786195, 2.775190
-    ])[None]
-}
+shape_slat_format = comfy.latent_formats.Trellis2ShapeSLAT()
+tex_slat_format = comfy.latent_formats.Trellis2TexSLAT()
 
 def shape_norm(shape_latent, coords):
-    std = shape_slat_normalization["std"].to(shape_latent)
-    mean = shape_slat_normalization["mean"].to(shape_latent)
-    samples = SparseTensor(feats = shape_latent, coords=coords)
-    samples = samples * std + mean
-    return samples
+    feats = shape_slat_format.process_out(shape_latent)
+    return SparseTensor(feats=feats, coords=coords)
 
 
 def infer_batched_coord_layout(coords):
@@ -177,7 +133,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
         sample_tensor = samples["samples"]
         device = comfy.model_management.get_torch_device()
         coords = samples["coords"]
-        prepare_trellis_vae_for_decode(vae, sample_tensor.shape)
+        vae.prepare_decode(sample_tensor.shape)
         trellis_vae = vae.first_stage_model
         coord_counts = samples.get("coord_counts")
 
@@ -243,7 +199,7 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
         sample_tensor = samples["samples"]
         device = comfy.model_management.get_torch_device()
         coords = samples["coords"]
-        prepare_trellis_vae_for_decode(vae, sample_tensor.shape)
+        vae.prepare_decode(sample_tensor.shape)
         trellis_vae = vae.first_stage_model
         coord_counts = samples.get("coord_counts")
         model_frame = samples.get("model_frame", "y_up")
@@ -252,16 +208,13 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
         samples = samples["samples"]
         samples, coords = flatten_batched_sparse_latent(samples, coords, coord_counts)
         samples = samples.to(device)
-        std = tex_slat_normalization["std"].to(samples)
-        mean = tex_slat_normalization["mean"].to(samples)
-        samples = SparseTensor(feats = samples, coords=coords.to(device))
-        samples = samples * std + mean
+        feats = tex_slat_format.process_out(samples)
+        samples = SparseTensor(feats=feats, coords=coords.to(device))
 
         voxel = trellis_vae.decode_tex_slat(samples.to(vae.vae_dtype), shape_subdivides)
         # Keep all decoded channels. The texture VAE emits 6: base_color (0:3),
         # metallic (3), roughness (4), alpha (5) — all in [0, 1]. Vertex-color
-        # consumers (PaintMesh) slice [:3]; BakeTextureFromVoxel uses the full
-        # PBR set. Older 3-channel checkpoints pass through unchanged.
+        # consumers (PaintMesh) slice [:3].
         color_feats = voxel.feats
         voxel_coords = voxel.coords
 
@@ -316,7 +269,7 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode):
         resolution = int(resolution)
         sample_tensor = samples["samples"]
         sample_tensor = sample_tensor[:, :8]
-        batch_number = prepare_trellis_vae_for_decode(vae, sample_tensor.shape)
+        batch_number = vae.prepare_decode(sample_tensor.shape)
         shape_vae = vae.first_stage_model
         load_device = comfy.model_management.get_torch_device()
         decoded_batches = []
@@ -381,7 +334,7 @@ class Trellis2UpsampleStage(IO.ComfyNode):
     @classmethod
     def execute(cls, positive, negative, shape_latent, vae, target_resolution, max_tokens):
         device = comfy.model_management.get_torch_device()
-        prepare_trellis_vae_for_decode(vae, shape_latent["samples"].shape)
+        vae.prepare_decode(shape_latent["samples"].shape)
         coord_counts = shape_latent.get("coord_counts")
         shape_vae = vae.first_stage_model
 