Cube3D: route VAE decode through managed comfy.sd.VAE.decode

Stop fighting ComfyUI's model management. VAEDecodeCube was manually
calling load_models_gpu + .to(vae.device) and the VAE forced
disable_offload=True because it bypassed the managed decode path.

Now CubeShapeVAE.decode(samples) is the entry point that comfy.sd.VAE.decode
calls, so loading/device/dtype are handled automatically (like Hunyuan3Dv2):
- removed disable_offload=True (let the offload system manage weights)
- removed manual load_models_gpu + .to(device) from the node
- process_output set to identity (default clamps [0,1] in-place and would
  destroy the occupancy isosurface)
- decode() pre-inverts VAE.decode's trailing movedim(1,-1) so the node
  receives grid logits unchanged (parity preserved)
- memory_used_decode sized by num_tokens (shape[-1]) for the new latent layout

Amp-Thread-ID: https://ampcode.com/threads/T-019ec361-addb-70d8-a74b-438ce8a1e096
Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
Jedrzej Kosinski 2026-06-14 23:28:22 -07:00
parent a6c7397b71
commit aeb3c77ae9
3 changed files with 31 additions and 13 deletions

View File

@ -288,6 +288,9 @@ def generate_dense_grid_points(bbox_min, bbox_max, resolution_base, indexing="ij
class CubeShapeVAE(nn.Module):
"""Decode-only OneDAutoEncoder. Encoder weights load with strict=False (ignored)."""
# Fixed query bounds for the occupancy grid (upstream default).
decode_bounds = (-1.05, -1.05, -1.05, 1.05, 1.05, 1.05)
def __init__(self, num_encoder_latents=1024, embed_dim=32, width=768, num_heads=12,
num_freqs=128, num_decoder_layers=24, num_codes=16384, out_dim=1, eps=1e-6,
dtype=None, device=None):
@ -303,6 +306,19 @@ class CubeShapeVAE(nn.Module):
self.occupancy_decoder = OneDOccupancyDecoder(self.embedder, out_dim, width, num_heads,
eps=eps, dtype=dtype, device=device)
@torch.no_grad()
def decode(self, samples, resolution_base=8.0, chunk_size=100_000, **kwargs):
"""Token IDs -> occupancy grid logits. Entry point for comfy.sd.VAE.decode, which
manages model loading/device/dtype. `samples` arrive as (B, 1, num_tokens) in the
VAE working dtype on the load device. VAE.decode applies a trailing movedim(1, -1),
so pre-invert it here to hand the node grid logits as (B, gx, gy, gz)."""
ids = samples.reshape(samples.shape[0], -1)[:, :self.cfg_num_encoder_latents]
ids = ids.round().long().clamp(0, self.cfg_num_codes - 1)
latents = self.decode_indices(ids)
grid_logits, _, _, _ = self.extract_geometry(
latents, bounds=self.decode_bounds, resolution_base=resolution_base, chunk_size=chunk_size)
return grid_logits.movedim(-1, 1)
@torch.no_grad()
def decode_indices(self, shape_ids):
z_q = self.bottleneck.block.lookup_codebook(shape_ids)

View File

@ -783,10 +783,6 @@ class VAE:
elif "bottleneck.block.codebook.weight" in sd:
self.cube3d = True
self.latent_dim = 1
# VAEDecodeCube calls first_stage_model.decode_indices/extract_geometry
# directly (not through the patcher-managed forward), so the weights must
# be fully resident on-device. Disable dynamic streaming offload.
self.disable_offload = True
embed_dim = sd["bottleneck.block.codebook.weight"].shape[1]
num_codes = sd["bottleneck.block.codebook.weight"].shape[0]
width = sd["bottleneck.block.c_out.weight"].shape[0]
@ -800,7 +796,13 @@ class VAE:
num_heads=num_heads, num_freqs=num_freqs, num_decoder_layers=num_decoder_layers,
num_codes=num_codes,
)
self.memory_used_decode = lambda shape, dtype: (1000 * shape[1] * 768) * model_management.dtype_size(dtype)
# Decode goes through the managed comfy.sd.VAE.decode path; the grid logits
# are float32 regardless of weight dtype, so keep process_output identity
# (the default clamps to [0, 1] in-place and would destroy the isosurface).
self.process_output = lambda image: image
self.process_input = lambda image: image
# shape is the token-ID latent (B, 1, num_tokens); size by num_tokens.
self.memory_used_decode = lambda shape, dtype: (1000 * shape[-1] * 768) * model_management.dtype_size(dtype)
self.working_dtypes = [torch.float32]
elif "vocoder.backbone.channel_layers.0.0.bias" in sd: #Ace Step Audio

View File

@ -121,15 +121,15 @@ class VAEDecodeCube(IO.ComfyNode):
@classmethod
def execute(cls, vae, samples, resolution_base, chunk_size) -> IO.NodeOutput:
comfy.model_management.load_models_gpu([vae.patcher])
tok = vae.first_stage_model
ids = samples["samples"]
ids = ids.reshape(ids.shape[0], -1)[:, :tok.cfg_num_encoder_latents].long()
ids = ids.clamp(0, tok.cfg_num_codes - 1).to(vae.device)
# Managed decode: comfy.sd.VAE.decode handles model loading + device/dtype and
# returns the occupancy grid logits (B, gx, gy, gz). Marching cubes runs here.
grid = vae.decode(samples["samples"],
vae_options={"resolution_base": resolution_base, "chunk_size": chunk_size})
latents = tok.decode_indices(ids)
grid, grid_size, bbox_size, bbox_min = tok.extract_geometry(
latents, resolution_base=resolution_base, chunk_size=chunk_size)
bounds = vae.first_stage_model.decode_bounds
bbox_min = np.array(bounds[0:3])
bbox_size = np.array(bounds[3:6]) - bbox_min
grid_size = list(grid.shape[1:])
verts_list, faces_list = [], []
for i in range(grid.shape[0]):