mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 21:20:49 +08:00
Cleanup VAE code some
This commit is contained in:
parent
41f5f4b2c0
commit
36e8f62dd4
@ -249,6 +249,53 @@ class TripoSplat(LatentFormat):
|
||||
def process_out(self, latent):
|
||||
return latent
|
||||
|
||||
class Trellis2(LatentFormat):
|
||||
latent_channels = 32
|
||||
|
||||
class Trellis2SLAT(Trellis2):
|
||||
# Sparse structured latent: per-token feats [N, 32]. process_out denormalizes
|
||||
# the decoded feats (latent * std + mean); subclasses carry each space's stats.
|
||||
latents_mean = None
|
||||
latents_std = None
|
||||
|
||||
def process_in(self, latent):
|
||||
mean = self.latents_mean.to(latent.device, latent.dtype)
|
||||
std = self.latents_std.to(latent.device, latent.dtype)
|
||||
return (latent - mean) / std
|
||||
|
||||
def process_out(self, latent):
|
||||
mean = self.latents_mean.to(latent.device, latent.dtype)
|
||||
std = self.latents_std.to(latent.device, latent.dtype)
|
||||
return latent * std + mean
|
||||
|
||||
class Trellis2ShapeSLAT(Trellis2SLAT):
|
||||
latents_mean = torch.tensor([
|
||||
0.781296, 0.018091, -0.495192, -0.558457, 1.060530, 0.093252, 1.518149, -0.933218,
|
||||
-0.732996, 2.604095, -0.118341, -2.143904, 0.495076, -2.179512, -2.130751, -0.996944,
|
||||
0.261421, -2.217463, 1.260067, -0.150213, 3.790713, 1.481266, -1.046058, -1.523667,
|
||||
-0.059621, 2.220780, 1.621212, 0.877230, 0.567247, -3.175944, -3.186688, 1.578665
|
||||
])[None]
|
||||
latents_std = torch.tensor([
|
||||
5.972266, 4.706852, 5.445010, 5.209927, 5.320220, 4.547237, 5.020802, 5.444004,
|
||||
5.226681, 5.683095, 4.831436, 5.286469, 5.652043, 5.367606, 5.525084, 4.730578,
|
||||
4.805265, 5.124013, 5.530808, 5.619001, 5.103930, 5.417670, 5.269677, 5.547194,
|
||||
5.634698, 5.235274, 6.110351, 5.511298, 6.237273, 4.879207, 5.347008, 5.405691
|
||||
])[None]
|
||||
|
||||
class Trellis2TexSLAT(Trellis2SLAT):
|
||||
latents_mean = torch.tensor([
|
||||
3.501659, 2.212398, 2.226094, 0.251093, -0.026248, -0.687364, 0.439898, -0.928075,
|
||||
0.029398, -0.339596, -0.869527, 1.038479, -0.972385, 0.126042, -1.129303, 0.455149,
|
||||
-1.209521, 2.069067, 0.544735, 2.569128, -0.323407, 2.293000, -1.925608, -1.217717,
|
||||
1.213905, 0.971588, -0.023631, 0.106750, 2.021786, 0.250524, -0.662387, -0.768862
|
||||
])[None]
|
||||
latents_std = torch.tensor([
|
||||
2.665652, 2.743913, 2.765121, 2.595319, 3.037293, 2.291316, 2.144656, 2.911822,
|
||||
2.969419, 2.501689, 2.154811, 3.163343, 2.621215, 2.381943, 3.186697, 3.021588,
|
||||
2.295916, 3.234985, 3.233086, 2.260140, 2.874801, 2.810596, 3.292720, 2.674999,
|
||||
2.680878, 2.372054, 2.451546, 2.353556, 2.995195, 2.379849, 2.786195, 2.775190
|
||||
])[None]
|
||||
|
||||
class Mochi(LatentFormat):
|
||||
latent_channels = 12
|
||||
latent_dimensions = 3
|
||||
@ -770,10 +817,6 @@ class Hunyuan3Dv2_1(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 1
|
||||
|
||||
class Trellis2(LatentFormat):
|
||||
latent_channels = 32
|
||||
|
||||
|
||||
class Hunyuan3Dv2mini(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 1
|
||||
|
||||
@ -1114,6 +1114,15 @@ class VAE:
|
||||
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
|
||||
return pixel_samples
|
||||
|
||||
def prepare_decode(self, sample_shape, memory_required=None):
|
||||
"""For VAEs whose real decode entry point bypasses decode()"""
|
||||
if memory_required is None:
|
||||
memory_required = self.memory_used_decode(sample_shape, self.vae_dtype)
|
||||
memory_required = max(1, int(memory_required))
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_required, force_full_load=self.disable_offload)
|
||||
free_memory = self.patcher.get_free_memory(self.device)
|
||||
return max(1, int(free_memory / memory_required))
|
||||
|
||||
def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
|
||||
self.throw_exception_if_invalid()
|
||||
memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
|
||||
|
||||
@ -5,6 +5,7 @@ from comfy.ldm.trellis2.model import build_proj_transform_matrix, _project_point
|
||||
from comfy.ldm.trellis2.naf.model import build_naf_from_state_dict
|
||||
|
||||
from comfy_extras.nodes_mesh_postprocess import pack_variable_mesh_batch
|
||||
import comfy.latent_formats
|
||||
import comfy.model_management
|
||||
import comfy.utils
|
||||
import folder_paths
|
||||
@ -18,57 +19,12 @@ ShapeSubdivides = io.Custom("SHAPE_SUBDIVIDES")
|
||||
NAFModel = io.Custom("NAF_MODEL")
|
||||
|
||||
|
||||
def prepare_trellis_vae_for_decode(vae, sample_shape):
|
||||
memory_required = vae.memory_used_decode(sample_shape, vae.vae_dtype)
|
||||
if len(sample_shape) == 5:
|
||||
memory_required *= max(1, int(sample_shape[4]))
|
||||
memory_required = max(1, int(memory_required))
|
||||
device = comfy.model_management.get_torch_device()
|
||||
comfy.model_management.load_models_gpu(
|
||||
[vae.patcher],
|
||||
memory_required=memory_required,
|
||||
force_full_load=getattr(vae, "disable_offload", False),
|
||||
)
|
||||
free_memory = vae.patcher.get_free_memory(device)
|
||||
batch_number = max(1, int(free_memory / memory_required))
|
||||
return batch_number
|
||||
|
||||
shape_slat_normalization = {
|
||||
"mean": torch.tensor([
|
||||
0.781296, 0.018091, -0.495192, -0.558457, 1.060530, 0.093252, 1.518149, -0.933218,
|
||||
-0.732996, 2.604095, -0.118341, -2.143904, 0.495076, -2.179512, -2.130751, -0.996944,
|
||||
0.261421, -2.217463, 1.260067, -0.150213, 3.790713, 1.481266, -1.046058, -1.523667,
|
||||
-0.059621, 2.220780, 1.621212, 0.877230, 0.567247, -3.175944, -3.186688, 1.578665
|
||||
])[None],
|
||||
"std": torch.tensor([
|
||||
5.972266, 4.706852, 5.445010, 5.209927, 5.320220, 4.547237, 5.020802, 5.444004,
|
||||
5.226681, 5.683095, 4.831436, 5.286469, 5.652043, 5.367606, 5.525084, 4.730578,
|
||||
4.805265, 5.124013, 5.530808, 5.619001, 5.103930, 5.417670, 5.269677, 5.547194,
|
||||
5.634698, 5.235274, 6.110351, 5.511298, 6.237273, 4.879207, 5.347008, 5.405691
|
||||
])[None]
|
||||
}
|
||||
|
||||
tex_slat_normalization = {
|
||||
"mean": torch.tensor([
|
||||
3.501659, 2.212398, 2.226094, 0.251093, -0.026248, -0.687364, 0.439898, -0.928075,
|
||||
0.029398, -0.339596, -0.869527, 1.038479, -0.972385, 0.126042, -1.129303, 0.455149,
|
||||
-1.209521, 2.069067, 0.544735, 2.569128, -0.323407, 2.293000, -1.925608, -1.217717,
|
||||
1.213905, 0.971588, -0.023631, 0.106750, 2.021786, 0.250524, -0.662387, -0.768862
|
||||
])[None],
|
||||
"std": torch.tensor([
|
||||
2.665652, 2.743913, 2.765121, 2.595319, 3.037293, 2.291316, 2.144656, 2.911822,
|
||||
2.969419, 2.501689, 2.154811, 3.163343, 2.621215, 2.381943, 3.186697, 3.021588,
|
||||
2.295916, 3.234985, 3.233086, 2.260140, 2.874801, 2.810596, 3.292720, 2.674999,
|
||||
2.680878, 2.372054, 2.451546, 2.353556, 2.995195, 2.379849, 2.786195, 2.775190
|
||||
])[None]
|
||||
}
|
||||
shape_slat_format = comfy.latent_formats.Trellis2ShapeSLAT()
|
||||
tex_slat_format = comfy.latent_formats.Trellis2TexSLAT()
|
||||
|
||||
def shape_norm(shape_latent, coords):
|
||||
std = shape_slat_normalization["std"].to(shape_latent)
|
||||
mean = shape_slat_normalization["mean"].to(shape_latent)
|
||||
samples = SparseTensor(feats = shape_latent, coords=coords)
|
||||
samples = samples * std + mean
|
||||
return samples
|
||||
feats = shape_slat_format.process_out(shape_latent)
|
||||
return SparseTensor(feats=feats, coords=coords)
|
||||
|
||||
|
||||
def infer_batched_coord_layout(coords):
|
||||
@ -177,7 +133,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
|
||||
sample_tensor = samples["samples"]
|
||||
device = comfy.model_management.get_torch_device()
|
||||
coords = samples["coords"]
|
||||
prepare_trellis_vae_for_decode(vae, sample_tensor.shape)
|
||||
vae.prepare_decode(sample_tensor.shape)
|
||||
trellis_vae = vae.first_stage_model
|
||||
coord_counts = samples.get("coord_counts")
|
||||
|
||||
@ -243,7 +199,7 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
|
||||
sample_tensor = samples["samples"]
|
||||
device = comfy.model_management.get_torch_device()
|
||||
coords = samples["coords"]
|
||||
prepare_trellis_vae_for_decode(vae, sample_tensor.shape)
|
||||
vae.prepare_decode(sample_tensor.shape)
|
||||
trellis_vae = vae.first_stage_model
|
||||
coord_counts = samples.get("coord_counts")
|
||||
model_frame = samples.get("model_frame", "y_up")
|
||||
@ -252,16 +208,13 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
|
||||
samples = samples["samples"]
|
||||
samples, coords = flatten_batched_sparse_latent(samples, coords, coord_counts)
|
||||
samples = samples.to(device)
|
||||
std = tex_slat_normalization["std"].to(samples)
|
||||
mean = tex_slat_normalization["mean"].to(samples)
|
||||
samples = SparseTensor(feats = samples, coords=coords.to(device))
|
||||
samples = samples * std + mean
|
||||
feats = tex_slat_format.process_out(samples)
|
||||
samples = SparseTensor(feats=feats, coords=coords.to(device))
|
||||
|
||||
voxel = trellis_vae.decode_tex_slat(samples.to(vae.vae_dtype), shape_subdivides)
|
||||
# Keep all decoded channels. The texture VAE emits 6: base_color (0:3),
|
||||
# metallic (3), roughness (4), alpha (5) — all in [0, 1]. Vertex-color
|
||||
# consumers (PaintMesh) slice [:3]; BakeTextureFromVoxel uses the full
|
||||
# PBR set. Older 3-channel checkpoints pass through unchanged.
|
||||
# consumers (PaintMesh) slice [:3]
|
||||
color_feats = voxel.feats
|
||||
voxel_coords = voxel.coords
|
||||
|
||||
@ -316,7 +269,7 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode):
|
||||
resolution = int(resolution)
|
||||
sample_tensor = samples["samples"]
|
||||
sample_tensor = sample_tensor[:, :8]
|
||||
batch_number = prepare_trellis_vae_for_decode(vae, sample_tensor.shape)
|
||||
batch_number = vae.prepare_decode(sample_tensor.shape)
|
||||
shape_vae = vae.first_stage_model
|
||||
load_device = comfy.model_management.get_torch_device()
|
||||
decoded_batches = []
|
||||
@ -381,7 +334,7 @@ class Trellis2UpsampleStage(IO.ComfyNode):
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, shape_latent, vae, target_resolution, max_tokens):
|
||||
device = comfy.model_management.get_torch_device()
|
||||
prepare_trellis_vae_for_decode(vae, shape_latent["samples"].shape)
|
||||
vae.prepare_decode(shape_latent["samples"].shape)
|
||||
|
||||
coord_counts = shape_latent.get("coord_counts")
|
||||
shape_vae = vae.first_stage_model
|
||||
|
||||
Loading…
Reference in New Issue
Block a user