mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-25 01:09:24 +08:00
Cube3D: route VAE decode through managed comfy.sd.VAE.decode
Stop fighting ComfyUI's model management. VAEDecodeCube was manually calling load_models_gpu + .to(vae.device), and the VAE forced disable_offload=True because the node bypassed the managed decode path.

Now CubeShapeVAE.decode(samples) is the entry point that comfy.sd.VAE.decode calls, so loading/device/dtype are handled automatically (like Hunyuan3Dv2):

- removed disable_offload=True (let the offload system manage weights)
- removed manual load_models_gpu + .to(device) from the node
- process_output set to identity (the default clamps to [0, 1] in-place and would destroy the occupancy isosurface)
- decode() pre-inverts VAE.decode's trailing movedim(1, -1) so the node receives grid logits unchanged (parity preserved)
- memory_used_decode sized by num_tokens (shape[-1]) for the new latent layout

Amp-Thread-ID: https://ampcode.com/threads/T-019ec361-addb-70d8-a74b-438ce8a1e096
Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
parent
a6c7397b71
commit
aeb3c77ae9
@ -288,6 +288,9 @@ def generate_dense_grid_points(bbox_min, bbox_max, resolution_base, indexing="ij
|
|||||||
class CubeShapeVAE(nn.Module):
|
class CubeShapeVAE(nn.Module):
|
||||||
"""Decode-only OneDAutoEncoder. Encoder weights load with strict=False (ignored)."""
|
"""Decode-only OneDAutoEncoder. Encoder weights load with strict=False (ignored)."""
|
||||||
|
|
||||||
|
# Fixed query bounds for the occupancy grid (upstream default).
|
||||||
|
decode_bounds = (-1.05, -1.05, -1.05, 1.05, 1.05, 1.05)
|
||||||
|
|
||||||
def __init__(self, num_encoder_latents=1024, embed_dim=32, width=768, num_heads=12,
|
def __init__(self, num_encoder_latents=1024, embed_dim=32, width=768, num_heads=12,
|
||||||
num_freqs=128, num_decoder_layers=24, num_codes=16384, out_dim=1, eps=1e-6,
|
num_freqs=128, num_decoder_layers=24, num_codes=16384, out_dim=1, eps=1e-6,
|
||||||
dtype=None, device=None):
|
dtype=None, device=None):
|
||||||
@ -303,6 +306,19 @@ class CubeShapeVAE(nn.Module):
|
|||||||
self.occupancy_decoder = OneDOccupancyDecoder(self.embedder, out_dim, width, num_heads,
|
self.occupancy_decoder = OneDOccupancyDecoder(self.embedder, out_dim, width, num_heads,
|
||||||
eps=eps, dtype=dtype, device=device)
|
eps=eps, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def decode(self, samples, resolution_base=8.0, chunk_size=100_000, **kwargs):
|
||||||
|
"""Token IDs -> occupancy grid logits. Entry point for comfy.sd.VAE.decode, which
|
||||||
|
manages model loading/device/dtype. `samples` arrive as (B, 1, num_tokens) in the
|
||||||
|
VAE working dtype on the load device. VAE.decode applies a trailing movedim(1, -1),
|
||||||
|
so pre-invert it here to hand the node grid logits as (B, gx, gy, gz)."""
|
||||||
|
ids = samples.reshape(samples.shape[0], -1)[:, :self.cfg_num_encoder_latents]
|
||||||
|
ids = ids.round().long().clamp(0, self.cfg_num_codes - 1)
|
||||||
|
latents = self.decode_indices(ids)
|
||||||
|
grid_logits, _, _, _ = self.extract_geometry(
|
||||||
|
latents, bounds=self.decode_bounds, resolution_base=resolution_base, chunk_size=chunk_size)
|
||||||
|
return grid_logits.movedim(-1, 1)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def decode_indices(self, shape_ids):
|
def decode_indices(self, shape_ids):
|
||||||
z_q = self.bottleneck.block.lookup_codebook(shape_ids)
|
z_q = self.bottleneck.block.lookup_codebook(shape_ids)
|
||||||
|
|||||||
12
comfy/sd.py
12
comfy/sd.py
@ -783,10 +783,6 @@ class VAE:
|
|||||||
elif "bottleneck.block.codebook.weight" in sd:
|
elif "bottleneck.block.codebook.weight" in sd:
|
||||||
self.cube3d = True
|
self.cube3d = True
|
||||||
self.latent_dim = 1
|
self.latent_dim = 1
|
||||||
# VAEDecodeCube calls first_stage_model.decode_indices/extract_geometry
|
|
||||||
# directly (not through the patcher-managed forward), so the weights must
|
|
||||||
# be fully resident on-device. Disable dynamic streaming offload.
|
|
||||||
self.disable_offload = True
|
|
||||||
embed_dim = sd["bottleneck.block.codebook.weight"].shape[1]
|
embed_dim = sd["bottleneck.block.codebook.weight"].shape[1]
|
||||||
num_codes = sd["bottleneck.block.codebook.weight"].shape[0]
|
num_codes = sd["bottleneck.block.codebook.weight"].shape[0]
|
||||||
width = sd["bottleneck.block.c_out.weight"].shape[0]
|
width = sd["bottleneck.block.c_out.weight"].shape[0]
|
||||||
@ -800,7 +796,13 @@ class VAE:
|
|||||||
num_heads=num_heads, num_freqs=num_freqs, num_decoder_layers=num_decoder_layers,
|
num_heads=num_heads, num_freqs=num_freqs, num_decoder_layers=num_decoder_layers,
|
||||||
num_codes=num_codes,
|
num_codes=num_codes,
|
||||||
)
|
)
|
||||||
self.memory_used_decode = lambda shape, dtype: (1000 * shape[1] * 768) * model_management.dtype_size(dtype)
|
# Decode goes through the managed comfy.sd.VAE.decode path; the grid logits
|
||||||
|
# are float32 regardless of weight dtype, so keep process_output identity
|
||||||
|
# (the default clamps to [0, 1] in-place and would destroy the isosurface).
|
||||||
|
self.process_output = lambda image: image
|
||||||
|
self.process_input = lambda image: image
|
||||||
|
# shape is the token-ID latent (B, 1, num_tokens); size by num_tokens.
|
||||||
|
self.memory_used_decode = lambda shape, dtype: (1000 * shape[-1] * 768) * model_management.dtype_size(dtype)
|
||||||
self.working_dtypes = [torch.float32]
|
self.working_dtypes = [torch.float32]
|
||||||
|
|
||||||
elif "vocoder.backbone.channel_layers.0.0.bias" in sd: #Ace Step Audio
|
elif "vocoder.backbone.channel_layers.0.0.bias" in sd: #Ace Step Audio
|
||||||
|
|||||||
@ -121,15 +121,15 @@ class VAEDecodeCube(IO.ComfyNode):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, vae, samples, resolution_base, chunk_size) -> IO.NodeOutput:
|
def execute(cls, vae, samples, resolution_base, chunk_size) -> IO.NodeOutput:
|
||||||
comfy.model_management.load_models_gpu([vae.patcher])
|
# Managed decode: comfy.sd.VAE.decode handles model loading + device/dtype and
|
||||||
tok = vae.first_stage_model
|
# returns the occupancy grid logits (B, gx, gy, gz). Marching cubes runs here.
|
||||||
ids = samples["samples"]
|
grid = vae.decode(samples["samples"],
|
||||||
ids = ids.reshape(ids.shape[0], -1)[:, :tok.cfg_num_encoder_latents].long()
|
vae_options={"resolution_base": resolution_base, "chunk_size": chunk_size})
|
||||||
ids = ids.clamp(0, tok.cfg_num_codes - 1).to(vae.device)
|
|
||||||
|
|
||||||
latents = tok.decode_indices(ids)
|
bounds = vae.first_stage_model.decode_bounds
|
||||||
grid, grid_size, bbox_size, bbox_min = tok.extract_geometry(
|
bbox_min = np.array(bounds[0:3])
|
||||||
latents, resolution_base=resolution_base, chunk_size=chunk_size)
|
bbox_size = np.array(bounds[3:6]) - bbox_min
|
||||||
|
grid_size = list(grid.shape[1:])
|
||||||
|
|
||||||
verts_list, faces_list = [], []
|
verts_list, faces_list = [], []
|
||||||
for i in range(grid.shape[0]):
|
for i in range(grid.shape[0]):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user