Cube3D: fix graph integration (3D latent, VAE device, fp32 cond, scikit-image)

Amp-Thread-ID: https://ampcode.com/threads/T-019ec361-addb-70d8-a74b-438ce8a1e096
Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
Jedrzej Kosinski 2026-06-14 22:59:11 -07:00
parent 01a8783bee
commit 871f7bc390
4 changed files with 20 additions and 7 deletions

View File

@@ -2016,12 +2016,16 @@ def sample_cube(model, x, sigmas, extra_args=None, callback=None, disable=None,
bbox = torch.zeros((c.shape[0], 3), device=device, dtype=c.dtype)
return torch.cat([c, cube.bbox_proj(bbox).unsqueeze(1)], dim=1)
with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=autocast_enabled):
cond = add_bbox(cube.encode_text(pos.to(device=device, dtype=weight_dtype)))
if use_cfg:
ucond = add_bbox(cube.encode_text(neg.to(device=device, dtype=weight_dtype)))
cond = torch.cat([cond, ucond], dim=0)
# Conditioning (text_proj + bbox_proj) is computed in the model's weight dtype
# OUTSIDE the bf16 autocast block, matching upstream cube's Engine.prepare_inputs
# (run_clip/encode_text run in full precision). The autocast only covers the
# autoregressive transformer forward, exactly like Engine.run_gpt.
cond = add_bbox(cube.encode_text(pos.to(device=device, dtype=weight_dtype)))
if use_cfg:
ucond = add_bbox(cube.encode_text(neg.to(device=device, dtype=weight_dtype)))
cond = torch.cat([cond, ucond], dim=0)
with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=autocast_enabled):
bos = torch.full((cond.shape[0], 1), cube.shape_bos_id, dtype=torch.long, device=device)
embed = cube.encode_token(bos)
Bp, input_seq_len, dim = embed.shape

View File

@@ -783,6 +783,10 @@ class VAE:
elif "bottleneck.block.codebook.weight" in sd:
self.cube3d = True
self.latent_dim = 1
# VAEDecodeCube calls first_stage_model.decode_indices/extract_geometry
# directly (not through the patcher-managed forward), so the weights must
# be fully resident on-device. Disable dynamic streaming offload.
self.disable_offload = True
embed_dim = sd["bottleneck.block.codebook.weight"].shape[1]
num_codes = sd["bottleneck.block.codebook.weight"].shape[0]
width = sd["bottleneck.block.c_out.weight"].shape[0]

View File

@@ -38,7 +38,10 @@ class EmptyCubeLatent(IO.ComfyNode):
@classmethod
def execute(cls, num_tokens, batch_size) -> IO.NodeOutput:
latent = torch.zeros([batch_size, num_tokens], device=comfy.model_management.intermediate_device())
# Trailing singleton dim keeps this a 3D latent so it flows through ComfyUI's
# conds/noise pipeline (encode_model_conds reads noise.shape[2]); the sampler
# only uses dim 1 (num_tokens).
latent = torch.zeros([batch_size, num_tokens, 1], device=comfy.model_management.intermediate_device())
return IO.NodeOutput({"samples": latent, "type": "cube_tokens"})
@@ -121,7 +124,8 @@ class VAEDecodeCube(IO.ComfyNode):
def execute(cls, vae, samples, resolution_base, chunk_size) -> IO.NodeOutput:
comfy.model_management.load_models_gpu([vae.patcher])
tok = vae.first_stage_model
ids = samples["samples"][:, :tok.cfg_num_encoder_latents].long()
ids = samples["samples"]
ids = ids.reshape(ids.shape[0], -1)[:, :tok.cfg_num_encoder_latents].long()
ids = ids.clamp(0, tok.cfg_num_codes - 1).to(vae.device)
latents = tok.decode_indices(ids)

View File

@@ -31,6 +31,7 @@ blake3
#non essential dependencies:
kornia>=0.7.1
spandrel
scikit-image # marching cubes for Cube3D (VAEDecodeCube)
pydantic~=2.0
pydantic-settings~=2.0
PyOpenGL