From aeb3c77ae9ec2fa000ff28a4108a5f0def82a641 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sun, 14 Jun 2026 23:28:22 -0700 Subject: [PATCH] Cube3D: route VAE decode through managed comfy.sd.VAE.decode Stop fighting ComfyUI's model management. VAEDecodeCube was manually calling load_models_gpu + .to(vae.device) and the VAE forced disable_offload=True because it bypassed the managed decode path. Now CubeShapeVAE.decode(samples) is the entry point that comfy.sd.VAE.decode calls, so loading/device/dtype are handled automatically (like Hunyuan3Dv2): - removed disable_offload=True (let the offload system manage weights) - removed manual load_models_gpu + .to(device) from the node - process_output set to identity (default clamps [0,1] in-place and would destroy the occupancy isosurface) - decode() pre-inverts VAE.decode's trailing movedim(1,-1) so the node receives grid logits unchanged (parity preserved) - memory_used_decode sized by num_tokens (shape[-1]) for the new latent layout Amp-Thread-ID: https://ampcode.com/threads/T-019ec361-addb-70d8-a74b-438ce8a1e096 Co-authored-by: Amp --- comfy/ldm/cube/vae.py | 16 ++++++++++++++++ comfy/sd.py | 12 +++++++----- comfy_extras/nodes_cube.py | 16 ++++++++-------- 3 files changed, 31 insertions(+), 13 deletions(-) diff --git a/comfy/ldm/cube/vae.py b/comfy/ldm/cube/vae.py index 001741dac..201dc73d5 100644 --- a/comfy/ldm/cube/vae.py +++ b/comfy/ldm/cube/vae.py @@ -288,6 +288,9 @@ def generate_dense_grid_points(bbox_min, bbox_max, resolution_base, indexing="ij class CubeShapeVAE(nn.Module): """Decode-only OneDAutoEncoder. Encoder weights load with strict=False (ignored).""" + # Fixed query bounds for the occupancy grid (upstream default). + decode_bounds = (-1.05, -1.05, -1.05, 1.05, 1.05, 1.05) + def __init__(self, num_encoder_latents=1024, embed_dim=32, width=768, num_heads=12, num_freqs=128, num_decoder_layers=24, num_codes=16384, out_dim=1, eps=1e-6, dtype=None, device=None): @@ -303,6 +306,19 @@ class CubeShapeVAE(nn.Module): self.occupancy_decoder = OneDOccupancyDecoder(self.embedder, out_dim, width, num_heads, eps=eps, dtype=dtype, device=device) + @torch.no_grad() + def decode(self, samples, resolution_base=8.0, chunk_size=100_000, **kwargs): + """Token IDs -> occupancy grid logits. Entry point for comfy.sd.VAE.decode, which + manages model loading/device/dtype. `samples` arrive as (B, 1, num_tokens) in the + VAE working dtype on the load device. VAE.decode applies a trailing movedim(1, -1), + so pre-invert it here to hand the node grid logits as (B, gx, gy, gz).""" + ids = samples.reshape(samples.shape[0], -1)[:, :self.cfg_num_encoder_latents] + ids = ids.round().long().clamp(0, self.cfg_num_codes - 1) + latents = self.decode_indices(ids) + grid_logits, _, _, _ = self.extract_geometry( + latents, bounds=self.decode_bounds, resolution_base=resolution_base, chunk_size=chunk_size) + return grid_logits.movedim(-1, 1) + @torch.no_grad() def decode_indices(self, shape_ids): z_q = self.bottleneck.block.lookup_codebook(shape_ids) diff --git a/comfy/sd.py b/comfy/sd.py index dcafb87b9..c2ececeeb 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -783,10 +783,6 @@ class VAE: elif "bottleneck.block.codebook.weight" in sd: self.cube3d = True self.latent_dim = 1 - # VAEDecodeCube calls first_stage_model.decode_indices/extract_geometry - # directly (not through the patcher-managed forward), so the weights must - # be fully resident on-device. Disable dynamic streaming offload. - self.disable_offload = True embed_dim = sd["bottleneck.block.codebook.weight"].shape[1] num_codes = sd["bottleneck.block.codebook.weight"].shape[0] width = sd["bottleneck.block.c_out.weight"].shape[0] @@ -800,7 +796,13 @@ class VAE: num_heads=num_heads, num_freqs=num_freqs, num_decoder_layers=num_decoder_layers, num_codes=num_codes, ) - self.memory_used_decode = lambda shape, dtype: (1000 * shape[1] * 768) * model_management.dtype_size(dtype) + # Decode goes through the managed comfy.sd.VAE.decode path; the grid logits + # are float32 regardless of weight dtype, so keep process_output identity + # (the default clamps to [0, 1] in-place and would destroy the isosurface). + self.process_output = lambda image: image + self.process_input = lambda image: image + # shape is the token-ID latent (B, 1, num_tokens); size by num_tokens. + self.memory_used_decode = lambda shape, dtype: (1000 * shape[-1] * 768) * model_management.dtype_size(dtype) self.working_dtypes = [torch.float32] elif "vocoder.backbone.channel_layers.0.0.bias" in sd: #Ace Step Audio diff --git a/comfy_extras/nodes_cube.py b/comfy_extras/nodes_cube.py index b10813805..bab879ec6 100644 --- a/comfy_extras/nodes_cube.py +++ b/comfy_extras/nodes_cube.py @@ -121,15 +121,15 @@ class VAEDecodeCube(IO.ComfyNode): @classmethod def execute(cls, vae, samples, resolution_base, chunk_size) -> IO.NodeOutput: - comfy.model_management.load_models_gpu([vae.patcher]) - tok = vae.first_stage_model - ids = samples["samples"] - ids = ids.reshape(ids.shape[0], -1)[:, :tok.cfg_num_encoder_latents].long() - ids = ids.clamp(0, tok.cfg_num_codes - 1).to(vae.device) + # Managed decode: comfy.sd.VAE.decode handles model loading + device/dtype and + # returns the occupancy grid logits (B, gx, gy, gz). Marching cubes runs here. + grid = vae.decode(samples["samples"], + vae_options={"resolution_base": resolution_base, "chunk_size": chunk_size}) - latents = tok.decode_indices(ids) - grid, grid_size, bbox_size, bbox_min = tok.extract_geometry( - latents, resolution_base=resolution_base, chunk_size=chunk_size) + bounds = vae.first_stage_model.decode_bounds + bbox_min = np.array(bounds[0:3]) + bbox_size = np.array(bounds[3:6]) - bbox_min + grid_size = list(grid.shape[1:]) verts_list, faces_list = [], [] for i in range(grid.shape[0]):