mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-24 16:59:29 +08:00
Cube3D: route VAE decode through managed comfy.sd.VAE.decode
Stop fighting ComfyUI's model management. VAEDecodeCube was manually calling load_models_gpu + .to(vae.device) and the VAE forced disable_offload=True because it bypassed the managed decode path. Now CubeShapeVAE.decode(samples) is the entry point that comfy.sd.VAE.decode calls, so loading/device/dtype are handled automatically (like Hunyuan3Dv2): - removed disable_offload=True (let the offload system manage weights) - removed manual load_models_gpu + .to(device) from the node - process_output set to identity (default clamps [0,1] in-place and would destroy the occupancy isosurface) - decode() pre-inverts VAE.decode's trailing movedim(1,-1) so the node receives grid logits unchanged (parity preserved) - memory_used_decode sized by num_tokens (shape[-1]) for the new latent layout Amp-Thread-ID: https://ampcode.com/threads/T-019ec361-addb-70d8-a74b-438ce8a1e096 Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
parent
a6c7397b71
commit
aeb3c77ae9
@ -288,6 +288,9 @@ def generate_dense_grid_points(bbox_min, bbox_max, resolution_base, indexing="ij
|
||||
class CubeShapeVAE(nn.Module):
|
||||
"""Decode-only OneDAutoEncoder. Encoder weights load with strict=False (ignored)."""
|
||||
|
||||
# Fixed query bounds for the occupancy grid (upstream default).
|
||||
decode_bounds = (-1.05, -1.05, -1.05, 1.05, 1.05, 1.05)
|
||||
|
||||
def __init__(self, num_encoder_latents=1024, embed_dim=32, width=768, num_heads=12,
|
||||
num_freqs=128, num_decoder_layers=24, num_codes=16384, out_dim=1, eps=1e-6,
|
||||
dtype=None, device=None):
|
||||
@ -303,6 +306,19 @@ class CubeShapeVAE(nn.Module):
|
||||
self.occupancy_decoder = OneDOccupancyDecoder(self.embedder, out_dim, width, num_heads,
|
||||
eps=eps, dtype=dtype, device=device)
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(self, samples, resolution_base=8.0, chunk_size=100_000, **kwargs):
|
||||
"""Token IDs -> occupancy grid logits. Entry point for comfy.sd.VAE.decode, which
|
||||
manages model loading/device/dtype. `samples` arrive as (B, 1, num_tokens) in the
|
||||
VAE working dtype on the load device. VAE.decode applies a trailing movedim(1, -1),
|
||||
so pre-invert it here to hand the node grid logits as (B, gx, gy, gz)."""
|
||||
ids = samples.reshape(samples.shape[0], -1)[:, :self.cfg_num_encoder_latents]
|
||||
ids = ids.round().long().clamp(0, self.cfg_num_codes - 1)
|
||||
latents = self.decode_indices(ids)
|
||||
grid_logits, _, _, _ = self.extract_geometry(
|
||||
latents, bounds=self.decode_bounds, resolution_base=resolution_base, chunk_size=chunk_size)
|
||||
return grid_logits.movedim(-1, 1)
|
||||
|
||||
@torch.no_grad()
|
||||
def decode_indices(self, shape_ids):
|
||||
z_q = self.bottleneck.block.lookup_codebook(shape_ids)
|
||||
|
||||
12
comfy/sd.py
12
comfy/sd.py
@ -783,10 +783,6 @@ class VAE:
|
||||
elif "bottleneck.block.codebook.weight" in sd:
|
||||
self.cube3d = True
|
||||
self.latent_dim = 1
|
||||
# VAEDecodeCube calls first_stage_model.decode_indices/extract_geometry
|
||||
# directly (not through the patcher-managed forward), so the weights must
|
||||
# be fully resident on-device. Disable dynamic streaming offload.
|
||||
self.disable_offload = True
|
||||
embed_dim = sd["bottleneck.block.codebook.weight"].shape[1]
|
||||
num_codes = sd["bottleneck.block.codebook.weight"].shape[0]
|
||||
width = sd["bottleneck.block.c_out.weight"].shape[0]
|
||||
@ -800,7 +796,13 @@ class VAE:
|
||||
num_heads=num_heads, num_freqs=num_freqs, num_decoder_layers=num_decoder_layers,
|
||||
num_codes=num_codes,
|
||||
)
|
||||
self.memory_used_decode = lambda shape, dtype: (1000 * shape[1] * 768) * model_management.dtype_size(dtype)
|
||||
# Decode goes through the managed comfy.sd.VAE.decode path; the grid logits
|
||||
# are float32 regardless of weight dtype, so keep process_output identity
|
||||
# (the default clamps to [0, 1] in-place and would destroy the isosurface).
|
||||
self.process_output = lambda image: image
|
||||
self.process_input = lambda image: image
|
||||
# shape is the token-ID latent (B, 1, num_tokens); size by num_tokens.
|
||||
self.memory_used_decode = lambda shape, dtype: (1000 * shape[-1] * 768) * model_management.dtype_size(dtype)
|
||||
self.working_dtypes = [torch.float32]
|
||||
|
||||
elif "vocoder.backbone.channel_layers.0.0.bias" in sd: #Ace Step Audio
|
||||
|
||||
@ -121,15 +121,15 @@ class VAEDecodeCube(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, vae, samples, resolution_base, chunk_size) -> IO.NodeOutput:
|
||||
comfy.model_management.load_models_gpu([vae.patcher])
|
||||
tok = vae.first_stage_model
|
||||
ids = samples["samples"]
|
||||
ids = ids.reshape(ids.shape[0], -1)[:, :tok.cfg_num_encoder_latents].long()
|
||||
ids = ids.clamp(0, tok.cfg_num_codes - 1).to(vae.device)
|
||||
# Managed decode: comfy.sd.VAE.decode handles model loading + device/dtype and
|
||||
# returns the occupancy grid logits (B, gx, gy, gz). Marching cubes runs here.
|
||||
grid = vae.decode(samples["samples"],
|
||||
vae_options={"resolution_base": resolution_base, "chunk_size": chunk_size})
|
||||
|
||||
latents = tok.decode_indices(ids)
|
||||
grid, grid_size, bbox_size, bbox_min = tok.extract_geometry(
|
||||
latents, resolution_base=resolution_base, chunk_size=chunk_size)
|
||||
bounds = vae.first_stage_model.decode_bounds
|
||||
bbox_min = np.array(bounds[0:3])
|
||||
bbox_size = np.array(bounds[3:6]) - bbox_min
|
||||
grid_size = list(grid.shape[1:])
|
||||
|
||||
verts_list, faces_list = [], []
|
||||
for i in range(grid.shape[0]):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user