mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-23 16:29:25 +08:00
Add native Roblox Cube3D text-to-3D support
Cube3D is an autoregressive VQ-token shape model (DualStreamRoformer) plus a VQ-VAE shape tokenizer (OneDAutoEncoder), not a diffusion model. It is wired natively following the Causal-WAN AR-video pattern: the GPT loads as a normal MODEL and generation runs through a dedicated 'cube' sampler instead of KSampler. - comfy/ldm/cube/gpt.py: DualStreamRoformer port (dual-stream RoPE attention, per-head RMSNorm, SwiGLU, KV cache; rope_theta=10000). - comfy/ldm/cube/vae.py: OneDAutoEncoder decode path (codebook lookup, decoder, occupancy decoder, dense-grid extraction + skimage marching cubes). - model_detection/supported_models/model_base: register shape_gpt as Cube3D MODEL (dims inferred from state dict; apply_model guarded to point at SamplerCube). - sd.py: detect shape_tokenizer and build CubeShapeVAE. - k_diffusion/sampling.py: sample_cube autoregressive sampler (decaying CFG + optional top-p), faithful to upstream Engine.run_gpt. - comfy_extras/nodes_cube.py: EmptyCubeLatent, CubeCodebookPatch (inject VQ codebook into wte), SamplerCube, VAEDecodeCube (-> MESH). Reuses CLIP-L conditioning, CFGGuider/SamplerCustomAdvanced, and SaveGLB. Amp-Thread-ID: https://ampcode.com/threads/T-019ec361-addb-70d8-a74b-438ce8a1e096 Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
parent
4388eb781a
commit
01a8783bee
@ -1955,3 +1955,110 @@ def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=No
|
||||
transformer_options.pop("ar_state", None)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def _cube_process_logits(logits, top_p, generator):
|
||||
"""Token selection. top_p>=1 or <=0 -> greedy argmax (upstream default, deterministic)."""
|
||||
if top_p is None or top_p >= 1.0 or top_p <= 0.0:
|
||||
return torch.argmax(logits, dim=-1, keepdim=True)
|
||||
sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True)
|
||||
remove = sorted_logits.softmax(dim=-1).cumsum(dim=-1) > top_p
|
||||
remove[..., 0] = False
|
||||
idx_remove = remove.scatter(-1, sorted_idx, remove)
|
||||
logits = logits.masked_fill(idx_remove, float("-inf"))
|
||||
probs = torch.softmax(logits, dim=-1)
|
||||
return torch.multinomial(probs, num_samples=1, generator=generator)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_cube(model, x, sigmas, extra_args=None, callback=None, disable=None, top_p=1.0):
|
||||
"""
|
||||
Autoregressive sampler for Roblox Cube3D shape GPT (DualStreamRoformer).
|
||||
|
||||
Not a diffusion sampler: the noised input `x` and `sigmas` values are ignored;
|
||||
only x's shape (batch, num_tokens) is used. Generates a 1024-long sequence of VQ
|
||||
token IDs from CLIP text conditioning, with upstream's linearly-decaying CFG and
|
||||
optional top-p. Plugs into SamplerCustomAdvanced via the SamplerCube node.
|
||||
|
||||
Faithful to cube3d.inference.engine.Engine.run_gpt:
|
||||
gamma_i = cfg * (T - i) / T ; logits = (1+gamma)*cond - gamma*uncond
|
||||
fp32 weights + bf16 autocast on cuda.
|
||||
"""
|
||||
import comfy.model_management
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
|
||||
guider = model.inner_model # CFGGuider
|
||||
base_model = guider.inner_model # BaseModel (Cube3D)
|
||||
cube = base_model.diffusion_model
|
||||
cfg = getattr(guider, "cfg", 3.0)
|
||||
|
||||
def get_cond(name):
|
||||
conds = guider.conds.get(name, None)
|
||||
if not conds:
|
||||
return None
|
||||
return conds[0]["model_conds"]["c_crossattn"].cond
|
||||
|
||||
pos = get_cond("positive")
|
||||
neg = get_cond("negative")
|
||||
if pos is None:
|
||||
raise ValueError("sample_cube requires positive conditioning (CLIP-L text embeds).")
|
||||
|
||||
device = x.device
|
||||
weight_dtype = base_model.get_dtype()
|
||||
T = x.shape[1]
|
||||
use_cfg = (cfg is not None) and (cfg > 0.0) and (neg is not None)
|
||||
autocast_enabled = (device.type == "cuda")
|
||||
cache_dtype = torch.bfloat16 if autocast_enabled else weight_dtype
|
||||
|
||||
def add_bbox(c):
|
||||
if not getattr(cube, "use_bbox", False):
|
||||
return c
|
||||
bbox = torch.zeros((c.shape[0], 3), device=device, dtype=c.dtype)
|
||||
return torch.cat([c, cube.bbox_proj(bbox).unsqueeze(1)], dim=1)
|
||||
|
||||
with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=autocast_enabled):
|
||||
cond = add_bbox(cube.encode_text(pos.to(device=device, dtype=weight_dtype)))
|
||||
if use_cfg:
|
||||
ucond = add_bbox(cube.encode_text(neg.to(device=device, dtype=weight_dtype)))
|
||||
cond = torch.cat([cond, ucond], dim=0)
|
||||
|
||||
bos = torch.full((cond.shape[0], 1), cube.shape_bos_id, dtype=torch.long, device=device)
|
||||
embed = cube.encode_token(bos)
|
||||
Bp, input_seq_len, dim = embed.shape
|
||||
embed_buffer = torch.zeros((Bp, input_seq_len + T, dim), dtype=embed.dtype, device=device)
|
||||
embed_buffer[:, :input_seq_len, :].copy_(embed)
|
||||
|
||||
kv_cache = cube.init_kv_cache(Bp, cond.shape[1], T + 1, cache_dtype, device)
|
||||
|
||||
num_codes = cube.vocab_size - 3
|
||||
seed = extra_args.get("seed", 0)
|
||||
generator = None
|
||||
if device.type != "mps":
|
||||
generator = torch.Generator(device=device).manual_seed(int(seed))
|
||||
|
||||
output_ids = []
|
||||
for i in trange(T, disable=disable):
|
||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||
curr_pos_id = torch.tensor([i], dtype=torch.long, device=device)
|
||||
logits = cube(embed_buffer, cond, kv_cache=kv_cache, curr_pos_id=curr_pos_id, decode=(i > 0))
|
||||
logits = logits[:, 0, :num_codes]
|
||||
|
||||
if use_cfg:
|
||||
cond_logits, uncond_logits = logits.float().chunk(2, dim=0)
|
||||
gamma = cfg * (T - i) / T
|
||||
logits = (1.0 + gamma) * cond_logits - gamma * uncond_logits
|
||||
else:
|
||||
logits = logits.float()
|
||||
|
||||
next_id = _cube_process_logits(logits, top_p, generator)
|
||||
output_ids.append(next_id)
|
||||
|
||||
next_embed = cube.encode_token(next_id)
|
||||
if use_cfg:
|
||||
next_embed = torch.cat([next_embed, next_embed], dim=0)
|
||||
embed_buffer[:, i + input_seq_len, :].copy_(next_embed.squeeze(1))
|
||||
|
||||
if callback is not None:
|
||||
callback({"x": x, "i": i, "sigma": sigmas[0], "sigma_hat": sigmas[0], "denoised": x})
|
||||
|
||||
return torch.cat(output_ids, dim=1).to(torch.float32)
|
||||
|
||||
417
comfy/ldm/cube/gpt.py
Normal file
417
comfy/ldm/cube/gpt.py
Normal file
@ -0,0 +1,417 @@
|
||||
"""
|
||||
Native port of Roblox/cube's shape GPT (DualStreamRoformer).
|
||||
|
||||
Reference: https://github.com/Roblox/cube (cube3d/model/gpt/dual_stream_roformer.py
|
||||
and cube3d/model/transformers/*).
|
||||
|
||||
This is an autoregressive transformer over discrete VQ shape tokens, conditioned on
|
||||
CLIP text embeddings. It is NOT a diffusion model; it is driven by the dedicated
|
||||
`sample_cube` sampler (see comfy/k_diffusion/sampling.py), not KSampler.
|
||||
|
||||
The forward pass is kept faithful to upstream so token IDs match bit-for-bit:
|
||||
* rope_theta = 10000
|
||||
* per-head RMSNorm on Q and K
|
||||
* dual-stream (MM-DiT style) joint attention; last dual block is cond_pre_only
|
||||
* two separate RoPE frequency tensors (dual blocks offset cond tokens by S)
|
||||
* SwiGLU MLP, non-affine LayerNorm upcast to fp32
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Norms (faithful to cube3d/model/transformers/norm.py)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class CubeLayerNorm(nn.Module):
|
||||
"""Non-affine LayerNorm that upcasts to fp32 then back (matches cube)."""
|
||||
|
||||
def __init__(self, dim, eps=1e-6):
|
||||
super().__init__()
|
||||
self.dim = (dim,)
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x):
|
||||
y = F.layer_norm(x.float(), self.dim, None, None, self.eps)
|
||||
return y.type_as(x)
|
||||
|
||||
|
||||
class CubeRMSNorm(nn.Module):
|
||||
"""Per-head RMSNorm with learnable weight, computed in fp32 (matches cube)."""
|
||||
|
||||
def __init__(self, dim, eps=1e-5, dtype=None, device=None):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.ones(dim, dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x):
|
||||
xf = x.float()
|
||||
out = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
return (out * self.weight).type_as(x)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RoPE (faithful to cube3d/model/transformers/rope.py)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def apply_rotary_emb(x, freqs_cis, curr_pos_id=None):
|
||||
x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
||||
if curr_pos_id is None:
|
||||
freqs_cis = freqs_cis[:, -x.shape[2]:].unsqueeze(1)
|
||||
else:
|
||||
freqs_cis = freqs_cis[:, curr_pos_id, :].unsqueeze(1)
|
||||
y = torch.view_as_real(x_ * freqs_cis).flatten(3)
|
||||
return y.type_as(x)
|
||||
|
||||
|
||||
def precompute_freqs_cis(dim, t, theta=10000.0):
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=t.device) / dim))
|
||||
freqs = torch.outer(t.contiguous().view(-1), freqs).reshape(*t.shape, -1)
|
||||
return torch.polar(torch.ones_like(freqs), freqs)
|
||||
|
||||
|
||||
def sdpa_with_rope(q, k, v, freqs_cis, attn_mask=None, curr_pos_id=None, is_causal=False):
|
||||
q = apply_rotary_emb(q, freqs_cis, curr_pos_id=curr_pos_id)
|
||||
k = apply_rotary_emb(k, freqs_cis, curr_pos_id=None)
|
||||
return F.scaled_dot_product_attention(
|
||||
q, k, v, attn_mask=attn_mask, dropout_p=0.0,
|
||||
is_causal=is_causal and attn_mask is None,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# KV cache
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class Cache:
|
||||
def __init__(self, key_states, value_states):
|
||||
self.key_states = key_states
|
||||
self.value_states = value_states
|
||||
|
||||
def update(self, curr_pos_id, k, v):
|
||||
self.key_states.index_copy_(2, curr_pos_id, k)
|
||||
self.value_states.index_copy_(2, curr_pos_id, v)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared building blocks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class SwiGLUMLP(nn.Module):
|
||||
def __init__(self, embed_dim, hidden_dim, bias=True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.gate_proj = operations.Linear(embed_dim, hidden_dim, bias=bias, dtype=dtype, device=device)
|
||||
self.up_proj = operations.Linear(embed_dim, hidden_dim, bias=bias, dtype=dtype, device=device)
|
||||
self.down_proj = operations.Linear(hidden_dim, embed_dim, bias=bias, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class SelfAttentionWithRotaryEmbedding(nn.Module):
|
||||
def __init__(self, embed_dim, num_heads, bias=True, eps=1e-6, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
assert embed_dim % num_heads == 0
|
||||
self.num_heads = num_heads
|
||||
head_dim = embed_dim // num_heads
|
||||
self.c_qk = operations.Linear(embed_dim, 2 * embed_dim, bias=False, dtype=dtype, device=device)
|
||||
self.c_v = operations.Linear(embed_dim, embed_dim, bias=bias, dtype=dtype, device=device)
|
||||
self.c_proj = operations.Linear(embed_dim, embed_dim, bias=bias, dtype=dtype, device=device)
|
||||
self.q_norm = CubeRMSNorm(head_dim, dtype=dtype, device=device)
|
||||
self.k_norm = CubeRMSNorm(head_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, freqs_cis, attn_mask=None, is_causal=False, kv_cache=None, curr_pos_id=None, decode=False):
|
||||
b, l, d = x.shape
|
||||
q, k = self.c_qk(x).chunk(2, dim=-1)
|
||||
v = self.c_v(x)
|
||||
q = q.view(b, l, self.num_heads, -1).transpose(1, 2)
|
||||
k = k.view(b, l, self.num_heads, -1).transpose(1, 2)
|
||||
v = v.view(b, l, self.num_heads, -1).transpose(1, 2)
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
if kv_cache is not None:
|
||||
if not decode:
|
||||
kv_cache.key_states[:, :, :k.shape[2], :].copy_(k)
|
||||
kv_cache.value_states[:, :, :k.shape[2], :].copy_(v)
|
||||
else:
|
||||
kv_cache.update(curr_pos_id, k, v)
|
||||
k = kv_cache.key_states
|
||||
v = kv_cache.value_states
|
||||
y = sdpa_with_rope(q, k, v, freqs_cis=freqs_cis, attn_mask=attn_mask,
|
||||
curr_pos_id=curr_pos_id if decode else None, is_causal=is_causal)
|
||||
y = y.transpose(1, 2).contiguous().view(b, l, d)
|
||||
return self.c_proj(y)
|
||||
|
||||
|
||||
class DecoderLayerWithRotaryEmbedding(nn.Module):
|
||||
"""Single-stream decoder layer (shape tokens only)."""
|
||||
|
||||
def __init__(self, embed_dim, num_heads, bias=True, eps=1e-6, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.ln_1 = CubeLayerNorm(embed_dim, eps=eps)
|
||||
self.attn = SelfAttentionWithRotaryEmbedding(embed_dim, num_heads, bias=bias, eps=eps,
|
||||
dtype=dtype, device=device, operations=operations)
|
||||
self.ln_2 = CubeLayerNorm(embed_dim, eps=eps)
|
||||
self.mlp = SwiGLUMLP(embed_dim, embed_dim * 4, bias=bias, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, x, freqs_cis, attn_mask=None, is_causal=True, kv_cache=None, curr_pos_id=None, decode=False):
|
||||
x = x + self.attn(self.ln_1(x), freqs_cis=freqs_cis, attn_mask=attn_mask, is_causal=is_causal,
|
||||
kv_cache=kv_cache, curr_pos_id=curr_pos_id, decode=decode)
|
||||
x = x + self.mlp(self.ln_2(x))
|
||||
return x
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dual-stream blocks (faithful to dual_stream_attention.py)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class DismantledPreAttention(nn.Module):
|
||||
def __init__(self, embed_dim, num_heads, query=True, bias=True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
assert embed_dim % num_heads == 0
|
||||
self.query = query
|
||||
head_dim = embed_dim // num_heads
|
||||
if query:
|
||||
self.c_qk = operations.Linear(embed_dim, 2 * embed_dim, bias=False, dtype=dtype, device=device)
|
||||
self.q_norm = CubeRMSNorm(head_dim, dtype=dtype, device=device)
|
||||
else:
|
||||
self.c_k = operations.Linear(embed_dim, embed_dim, bias=bias, dtype=dtype, device=device)
|
||||
self.k_norm = CubeRMSNorm(head_dim, dtype=dtype, device=device)
|
||||
self.c_v = operations.Linear(embed_dim, embed_dim, bias=bias, dtype=dtype, device=device)
|
||||
self.num_heads = num_heads
|
||||
|
||||
def _to_mha(self, x):
|
||||
return x.view(*x.shape[:2], self.num_heads, -1).transpose(1, 2)
|
||||
|
||||
def forward(self, x):
|
||||
if self.query:
|
||||
q, k = self.c_qk(x).chunk(2, dim=-1)
|
||||
q = self.q_norm(self._to_mha(q))
|
||||
else:
|
||||
q = None
|
||||
k = self.c_k(x)
|
||||
k = self.k_norm(self._to_mha(k))
|
||||
v = self._to_mha(self.c_v(x))
|
||||
return (q, k, v)
|
||||
|
||||
|
||||
class DismantledPostAttention(nn.Module):
|
||||
def __init__(self, embed_dim, bias=True, eps=1e-6, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.c_proj = operations.Linear(embed_dim, embed_dim, bias=bias, dtype=dtype, device=device)
|
||||
self.ln_3 = CubeLayerNorm(embed_dim, eps=eps)
|
||||
self.mlp = SwiGLUMLP(embed_dim, embed_dim * 4, bias=bias, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, x, a):
|
||||
x = x + self.c_proj(a)
|
||||
x = x + self.mlp(self.ln_3(x))
|
||||
return x
|
||||
|
||||
|
||||
class DualStreamAttentionWithRotaryEmbedding(nn.Module):
|
||||
def __init__(self, embed_dim, num_heads, cond_pre_only=False, bias=True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.cond_pre_only = cond_pre_only
|
||||
self.pre_x = DismantledPreAttention(embed_dim, num_heads, query=True, bias=bias,
|
||||
dtype=dtype, device=device, operations=operations)
|
||||
self.pre_c = DismantledPreAttention(embed_dim, num_heads, query=not cond_pre_only, bias=bias,
|
||||
dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, x, c, freqs_cis, attn_mask=None, is_causal=False, kv_cache=None, curr_pos_id=None, decode=False):
|
||||
if kv_cache is None or not decode:
|
||||
qkv_c = self.pre_c(c)
|
||||
qkv_x = self.pre_x(x)
|
||||
if self.cond_pre_only:
|
||||
q = qkv_x[0]
|
||||
else:
|
||||
q = torch.cat([qkv_c[0], qkv_x[0]], dim=2)
|
||||
k = torch.cat([qkv_c[1], qkv_x[1]], dim=2)
|
||||
v = torch.cat([qkv_c[2], qkv_x[2]], dim=2)
|
||||
else:
|
||||
is_causal = False
|
||||
q, k, v = self.pre_x(x)
|
||||
|
||||
if kv_cache is not None:
|
||||
if not decode:
|
||||
kv_cache.key_states[:, :, :k.shape[2], :].copy_(k)
|
||||
kv_cache.value_states[:, :, :k.shape[2], :].copy_(v)
|
||||
else:
|
||||
kv_cache.update(curr_pos_id, k, v)
|
||||
k = kv_cache.key_states
|
||||
v = kv_cache.value_states
|
||||
|
||||
if attn_mask is not None:
|
||||
if decode:
|
||||
attn_mask = attn_mask[..., curr_pos_id, :]
|
||||
else:
|
||||
attn_mask = attn_mask[..., -q.shape[2]:, :]
|
||||
|
||||
y = sdpa_with_rope(q, k, v, freqs_cis=freqs_cis, attn_mask=attn_mask,
|
||||
curr_pos_id=curr_pos_id if decode else None, is_causal=is_causal)
|
||||
y = y.transpose(1, 2).contiguous().view(x.shape[0], -1, x.shape[2])
|
||||
|
||||
if y.shape[1] == x.shape[1]:
|
||||
return y, None
|
||||
y_c, y_x = torch.split(y, [c.shape[1], x.shape[1]], dim=1)
|
||||
return y_x, y_c
|
||||
|
||||
|
||||
class DualStreamDecoderLayerWithRotaryEmbedding(nn.Module):
|
||||
def __init__(self, embed_dim, num_heads, cond_pre_only=False, bias=True, eps=1e-6,
|
||||
dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.ln_1 = CubeLayerNorm(embed_dim, eps=eps)
|
||||
self.ln_2 = CubeLayerNorm(embed_dim, eps=eps)
|
||||
self.attn = DualStreamAttentionWithRotaryEmbedding(embed_dim, num_heads, cond_pre_only=cond_pre_only,
|
||||
bias=bias, dtype=dtype, device=device, operations=operations)
|
||||
self.post_1 = DismantledPostAttention(embed_dim, bias=bias, eps=eps, dtype=dtype, device=device, operations=operations)
|
||||
if not cond_pre_only:
|
||||
self.post_2 = DismantledPostAttention(embed_dim, bias=bias, eps=eps, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, x, c, freqs_cis, attn_mask=None, is_causal=True, kv_cache=None, curr_pos_id=None, decode=False):
|
||||
a_x, a_c = self.attn(
|
||||
self.ln_1(x),
|
||||
self.ln_2(c) if c is not None else None,
|
||||
freqs_cis=freqs_cis, attn_mask=attn_mask, is_causal=is_causal,
|
||||
kv_cache=kv_cache, curr_pos_id=curr_pos_id, decode=decode,
|
||||
)
|
||||
x = self.post_1(x, a_x)
|
||||
if a_c is not None:
|
||||
c = self.post_2(c, a_c)
|
||||
else:
|
||||
c = None
|
||||
return x, c
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DualStreamRoformer
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class DualStreamRoformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
n_layer=23,
|
||||
n_single_layer=1,
|
||||
rope_theta=10000,
|
||||
n_head=12,
|
||||
n_embd=1536,
|
||||
bias=True,
|
||||
eps=1e-6,
|
||||
shape_model_vocab_size=16384,
|
||||
shape_model_embed_dim=32,
|
||||
text_model_embed_dim=768,
|
||||
use_bbox=True,
|
||||
image_model=None, # detection key; unused
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
self.n_layer = n_layer
|
||||
self.n_single_layer = n_single_layer
|
||||
self.n_head = n_head
|
||||
self.n_embd = n_embd
|
||||
self.rope_theta = rope_theta
|
||||
self.head_dim = n_embd // n_head
|
||||
|
||||
self.text_proj = operations.Linear(text_model_embed_dim, n_embd, bias=bias, dtype=dtype, device=device)
|
||||
self.shape_proj = operations.Linear(shape_model_embed_dim, n_embd, bias=True, dtype=dtype, device=device)
|
||||
|
||||
self.vocab_size = shape_model_vocab_size
|
||||
self.shape_bos_id = self.vocab_size
|
||||
self.shape_eos_id = self.vocab_size + 1
|
||||
self.padding_id = self.vocab_size + 2
|
||||
self.vocab_size += 3
|
||||
|
||||
self.transformer = nn.ModuleDict(dict(
|
||||
wte=operations.Embedding(self.vocab_size, n_embd, padding_idx=self.padding_id, dtype=dtype, device=device),
|
||||
dual_blocks=nn.ModuleList([
|
||||
DualStreamDecoderLayerWithRotaryEmbedding(
|
||||
n_embd, n_head, cond_pre_only=(i == n_layer - 1), bias=bias, eps=eps,
|
||||
dtype=dtype, device=device, operations=operations,
|
||||
)
|
||||
for i in range(n_layer)
|
||||
]),
|
||||
single_blocks=nn.ModuleList([
|
||||
DecoderLayerWithRotaryEmbedding(n_embd, n_head, bias=bias, eps=eps,
|
||||
dtype=dtype, device=device, operations=operations)
|
||||
for _ in range(n_single_layer)
|
||||
]),
|
||||
ln_f=CubeLayerNorm(n_embd, eps=eps),
|
||||
))
|
||||
|
||||
self.lm_head = operations.Linear(n_embd, self.vocab_size, bias=False, dtype=dtype, device=device)
|
||||
|
||||
self.use_bbox = use_bbox
|
||||
if use_bbox:
|
||||
self.bbox_proj = operations.Linear(3, n_embd, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def encode_text(self, text_embed):
|
||||
return self.text_proj(text_embed)
|
||||
|
||||
def encode_token(self, tokens):
|
||||
return self.transformer.wte(tokens)
|
||||
|
||||
def init_kv_cache(self, batch_size, cond_len, max_shape_tokens, dtype, device):
|
||||
max_all = cond_len + max_shape_tokens
|
||||
kv = [
|
||||
Cache(
|
||||
torch.zeros((batch_size, self.n_head, max_all, self.head_dim), dtype=dtype, device=device),
|
||||
torch.zeros((batch_size, self.n_head, max_all, self.head_dim), dtype=dtype, device=device),
|
||||
)
|
||||
for _ in range(len(self.transformer.dual_blocks))
|
||||
]
|
||||
kv += [
|
||||
Cache(
|
||||
torch.zeros((batch_size, self.n_head, max_shape_tokens, self.head_dim), dtype=dtype, device=device),
|
||||
torch.zeros((batch_size, self.n_head, max_shape_tokens, self.head_dim), dtype=dtype, device=device),
|
||||
)
|
||||
for _ in range(len(self.transformer.single_blocks))
|
||||
]
|
||||
return kv
|
||||
|
||||
def forward(self, embed, cond, kv_cache=None, curr_pos_id=None, decode=False):
|
||||
b, l = embed.shape[:2]
|
||||
s = cond.shape[1]
|
||||
device = embed.device
|
||||
|
||||
attn_mask = torch.tril(torch.ones(s + l, s + l, dtype=torch.bool, device=device))
|
||||
|
||||
position_ids = torch.arange(l, dtype=torch.long, device=device).unsqueeze(0).expand(b, -1)
|
||||
s_freqs_cis = precompute_freqs_cis(self.head_dim, position_ids, theta=self.rope_theta)
|
||||
|
||||
position_ids = torch.cat([
|
||||
torch.zeros([b, s], dtype=torch.long, device=device),
|
||||
position_ids,
|
||||
], dim=1)
|
||||
d_freqs_cis = precompute_freqs_cis(self.head_dim, position_ids, theta=self.rope_theta)
|
||||
|
||||
if kv_cache is not None and decode:
|
||||
embed = embed[:, curr_pos_id, :]
|
||||
|
||||
h = embed
|
||||
c = cond
|
||||
layer_idx = 0
|
||||
for block in self.transformer.dual_blocks:
|
||||
h, c = block(
|
||||
h, c=c, freqs_cis=d_freqs_cis, attn_mask=attn_mask, is_causal=True,
|
||||
kv_cache=kv_cache[layer_idx] if kv_cache is not None else None,
|
||||
curr_pos_id=curr_pos_id + s if curr_pos_id is not None else None,
|
||||
decode=decode,
|
||||
)
|
||||
layer_idx += 1
|
||||
for block in self.transformer.single_blocks:
|
||||
h = block(
|
||||
h, freqs_cis=s_freqs_cis, attn_mask=None, is_causal=True,
|
||||
kv_cache=kv_cache[layer_idx] if kv_cache is not None else None,
|
||||
curr_pos_id=curr_pos_id, decode=decode,
|
||||
)
|
||||
layer_idx += 1
|
||||
|
||||
h = self.transformer.ln_f(h)
|
||||
return self.lm_head(h)
|
||||
345
comfy/ldm/cube/vae.py
Normal file
345
comfy/ldm/cube/vae.py
Normal file
@ -0,0 +1,345 @@
|
||||
"""
|
||||
Native port of Roblox/cube's shape tokenizer decode path (OneDAutoEncoder).
|
||||
|
||||
Reference: https://github.com/Roblox/cube (cube3d/model/autoencoder/*).
|
||||
|
||||
Only the DECODE path is ported (token IDs -> latents -> occupancy grid -> mesh);
|
||||
the point-cloud encoder is not needed for text-to-3D generation. Encoder weights in
|
||||
the checkpoint are loaded with strict=False and ignored.
|
||||
|
||||
Module/parameter names mirror upstream so the checkpoint loads directly:
|
||||
embedder.weight
|
||||
bottleneck.block.{codebook, cb_weight, cb_bias, c_in, c_x, c_out, ...}
|
||||
decoder.{positional_encodings, blocks.N...}
|
||||
occupancy_decoder.{query_in, attn_out, ln_f, c_head}
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Norms
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class CubeLayerNorm(nn.Module):
|
||||
"""LayerNorm upcasting to fp32. affine=False by default (no params)."""
|
||||
|
||||
def __init__(self, dim, eps=1e-6, elementwise_affine=False, dtype=None, device=None):
|
||||
super().__init__()
|
||||
self.dim = (dim,)
|
||||
self.eps = eps
|
||||
if elementwise_affine:
|
||||
self.weight = nn.Parameter(torch.ones(dim, dtype=dtype, device=device))
|
||||
self.bias = nn.Parameter(torch.zeros(dim, dtype=dtype, device=device))
|
||||
else:
|
||||
self.weight = None
|
||||
self.bias = None
|
||||
|
||||
def forward(self, x):
|
||||
w = self.weight.float() if self.weight is not None else None
|
||||
b = self.bias.float() if self.bias is not None else None
|
||||
y = F.layer_norm(x.float(), self.dim, w, b, self.eps)
|
||||
return y.type_as(x)
|
||||
|
||||
|
||||
class CubeRMSNorm(nn.Module):
|
||||
def __init__(self, dim, eps=1e-5, elementwise_affine=True, dtype=None, device=None):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
if elementwise_affine:
|
||||
self.weight = nn.Parameter(torch.ones(dim, dtype=dtype, device=device))
|
||||
else:
|
||||
self.register_buffer("weight", torch.ones(dim), persistent=False)
|
||||
|
||||
def forward(self, x):
|
||||
xf = x.float()
|
||||
out = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
return (out * self.weight.float()).type_as(x)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fourier embedder
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class PhaseModulatedFourierEmbedder(nn.Module):
|
||||
def __init__(self, num_freqs, input_dim=3, dtype=None, device=None):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.empty(input_dim, num_freqs, dtype=dtype, device=device))
|
||||
carrier = (num_freqs / 8) ** torch.linspace(1, 0, num_freqs)
|
||||
carrier = (carrier + torch.linspace(0, 1, num_freqs)) * 2 * math.pi
|
||||
self.register_buffer("carrier", carrier, persistent=False)
|
||||
self.out_dim = input_dim * (num_freqs * 2 + 1)
|
||||
|
||||
def forward(self, x):
|
||||
m = x.float().unsqueeze(-1)
|
||||
w = self.weight.float()
|
||||
carrier = self.carrier.float()
|
||||
fm = (m * w).view(*x.shape[:-1], -1)
|
||||
pm = (m * 0.5 * math.pi + carrier).view(*x.shape[:-1], -1)
|
||||
return torch.cat([x, fm.cos() + pm.cos(), fm.sin() + pm.sin()], dim=-1).type_as(x)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Attention building blocks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, embed_dim, hidden_dim, bias=True, dtype=None, device=None):
|
||||
super().__init__()
|
||||
self.up_proj = ops.Linear(embed_dim, hidden_dim, bias=bias, dtype=dtype, device=device)
|
||||
self.down_proj = ops.Linear(hidden_dim, embed_dim, bias=bias, dtype=dtype, device=device)
|
||||
self.act_fn = nn.GELU(approximate="none")
|
||||
|
||||
def forward(self, x):
|
||||
return self.down_proj(self.act_fn(self.up_proj(x)))
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
def __init__(self, embed_dim, num_heads, bias=True, eps=1e-6, dtype=None, device=None):
|
||||
super().__init__()
|
||||
assert embed_dim % num_heads == 0
|
||||
self.num_heads = num_heads
|
||||
head_dim = embed_dim // num_heads
|
||||
self.c_qk = ops.Linear(embed_dim, 2 * embed_dim, bias=bias, dtype=dtype, device=device)
|
||||
self.c_v = ops.Linear(embed_dim, embed_dim, bias=bias, dtype=dtype, device=device)
|
||||
self.c_proj = ops.Linear(embed_dim, embed_dim, bias=bias, dtype=dtype, device=device)
|
||||
self.q_norm = CubeRMSNorm(head_dim, dtype=dtype, device=device)
|
||||
self.k_norm = CubeRMSNorm(head_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, attn_mask=None, is_causal=False):
|
||||
b, l, d = x.shape
|
||||
q, k = self.c_qk(x).chunk(2, dim=-1)
|
||||
v = self.c_v(x)
|
||||
q = self.q_norm(q.view(b, l, self.num_heads, -1).transpose(1, 2))
|
||||
k = self.k_norm(k.view(b, l, self.num_heads, -1).transpose(1, 2))
|
||||
v = v.view(b, l, self.num_heads, -1).transpose(1, 2)
|
||||
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0,
|
||||
is_causal=is_causal and attn_mask is None)
|
||||
y = y.transpose(1, 2).contiguous().view(b, l, d)
|
||||
return self.c_proj(y)
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(self, embed_dim, num_heads, q_dim=None, kv_dim=None, bias=True, dtype=None, device=None):
|
||||
super().__init__()
|
||||
assert embed_dim % num_heads == 0
|
||||
q_dim = q_dim or embed_dim
|
||||
kv_dim = kv_dim or embed_dim
|
||||
self.c_q = ops.Linear(q_dim, embed_dim, bias=bias, dtype=dtype, device=device)
|
||||
self.c_k = ops.Linear(kv_dim, embed_dim, bias=bias, dtype=dtype, device=device)
|
||||
self.c_v = ops.Linear(kv_dim, embed_dim, bias=bias, dtype=dtype, device=device)
|
||||
self.c_proj = ops.Linear(embed_dim, embed_dim, bias=bias, dtype=dtype, device=device)
|
||||
self.num_heads = num_heads
|
||||
|
||||
def forward(self, x, c, attn_mask=None):
|
||||
q, k, v = self.c_q(x), self.c_k(c), self.c_v(c)
|
||||
b, l, d = q.shape
|
||||
s = k.shape[1]
|
||||
q = q.view(b, l, self.num_heads, -1).transpose(1, 2)
|
||||
k = k.view(b, s, self.num_heads, -1).transpose(1, 2)
|
||||
v = v.view(b, s, self.num_heads, -1).transpose(1, 2)
|
||||
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)
|
||||
y = y.transpose(1, 2).contiguous().view(b, l, d)
|
||||
return self.c_proj(y)
|
||||
|
||||
|
||||
class EncoderLayer(nn.Module):
|
||||
def __init__(self, embed_dim, num_heads, bias=True, eps=1e-6, dtype=None, device=None):
|
||||
super().__init__()
|
||||
self.ln_1 = CubeLayerNorm(embed_dim, eps=eps)
|
||||
self.attn = SelfAttention(embed_dim, num_heads, bias=bias, eps=eps, dtype=dtype, device=device)
|
||||
self.ln_2 = CubeLayerNorm(embed_dim, eps=eps)
|
||||
self.mlp = MLP(embed_dim, embed_dim * 4, bias=bias, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, attn_mask=None, is_causal=False):
|
||||
x = x + self.attn(self.ln_1(x), attn_mask=attn_mask, is_causal=is_causal)
|
||||
x = x + self.mlp(self.ln_2(x))
|
||||
return x
|
||||
|
||||
|
||||
class EncoderCrossAttentionLayer(nn.Module):
|
||||
def __init__(self, embed_dim, num_heads, q_dim=None, kv_dim=None, bias=True, eps=1e-6, dtype=None, device=None):
|
||||
super().__init__()
|
||||
q_dim = q_dim or embed_dim
|
||||
kv_dim = kv_dim or embed_dim
|
||||
self.attn = CrossAttention(embed_dim, num_heads, q_dim=q_dim, kv_dim=kv_dim, bias=bias, dtype=dtype, device=device)
|
||||
self.ln_1 = CubeLayerNorm(q_dim, eps=eps)
|
||||
self.ln_2 = CubeLayerNorm(kv_dim, eps=eps)
|
||||
self.ln_f = CubeLayerNorm(embed_dim, eps=eps)
|
||||
self.mlp = MLP(embed_dim, embed_dim * 4, bias=bias, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, c, attn_mask=None):
|
||||
x = x + self.attn(self.ln_1(x), self.ln_2(c), attn_mask=attn_mask)
|
||||
x = x + self.mlp(self.ln_f(x))
|
||||
return x
|
||||
|
||||
|
||||
class MLPEmbedder(nn.Module):
|
||||
def __init__(self, in_dim, embed_dim, bias=True, dtype=None, device=None):
|
||||
super().__init__()
|
||||
self.in_layer = ops.Linear(in_dim, embed_dim, bias=bias, dtype=dtype, device=device)
|
||||
self.silu = nn.SiLU()
|
||||
self.out_layer = ops.Linear(embed_dim, embed_dim, bias=bias, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
return self.out_layer(self.silu(self.in_layer(x)))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Spherical VQ (decode-only parts)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class SphericalVectorQuantizer(nn.Module):
|
||||
def __init__(self, embed_dim, num_codes, width=None, dtype=None, device=None):
|
||||
super().__init__()
|
||||
self.num_codes = num_codes
|
||||
self.codebook = ops.Embedding(num_codes, embed_dim, dtype=dtype, device=device)
|
||||
width = width or embed_dim
|
||||
if width != embed_dim:
|
||||
self.c_in = ops.Linear(width, embed_dim, dtype=dtype, device=device)
|
||||
self.c_x = ops.Linear(width, embed_dim, dtype=dtype, device=device)
|
||||
self.c_out = ops.Linear(embed_dim, width, dtype=dtype, device=device)
|
||||
else:
|
||||
self.c_in = self.c_out = self.c_x = nn.Identity()
|
||||
self.norm = CubeRMSNorm(embed_dim, elementwise_affine=False, dtype=dtype, device=device)
|
||||
# "kl" codebook regularization (released config)
|
||||
self.cb_weight = nn.Parameter(torch.ones([embed_dim], dtype=dtype, device=device))
|
||||
self.cb_bias = nn.Parameter(torch.zeros([embed_dim], dtype=dtype, device=device))
|
||||
|
||||
def cb_norm(self, x):
|
||||
return x * self.cb_weight + self.cb_bias
|
||||
|
||||
def get_codebook(self):
|
||||
return self.norm(self.cb_norm(self.codebook.weight))
|
||||
|
||||
def lookup_codebook(self, q):
|
||||
z_q = F.embedding(q, self.get_codebook())
|
||||
return self.c_out(z_q)
|
||||
|
||||
|
||||
class OneDBottleNeck(nn.Module):
|
||||
def __init__(self, block):
|
||||
super().__init__()
|
||||
self.block = block
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Decoders
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class OneDDecoder(nn.Module):
|
||||
def __init__(self, num_latents, width, num_heads, num_layers, eps=1e-6, dtype=None, device=None):
|
||||
super().__init__()
|
||||
self.register_buffer("query", torch.empty([0, width]), persistent=False)
|
||||
self.positional_encodings = nn.Parameter(torch.empty(num_latents, width, dtype=dtype, device=device))
|
||||
self.blocks = nn.ModuleList([
|
||||
EncoderLayer(width, num_heads, eps=eps, dtype=dtype, device=device)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
def forward(self, z):
|
||||
h = z + self.positional_encodings[:z.shape[1]].unsqueeze(0).to(z.dtype)
|
||||
for block in self.blocks:
|
||||
h = block(h)
|
||||
return h
|
||||
|
||||
|
||||
class OneDOccupancyDecoder(nn.Module):
|
||||
def __init__(self, embedder, out_features, width, num_heads, eps=1e-6, dtype=None, device=None):
|
||||
super().__init__()
|
||||
self.embedder = embedder
|
||||
self.query_in = MLPEmbedder(embedder.out_dim, width, dtype=dtype, device=device)
|
||||
self.attn_out = EncoderCrossAttentionLayer(width, num_heads, dtype=dtype, device=device)
|
||||
self.ln_f = CubeLayerNorm(width, eps=eps, elementwise_affine=True, dtype=dtype, device=device)
|
||||
self.c_head = ops.Linear(width, out_features, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, queries, latents):
|
||||
x = self.query_in(self.embedder(queries))
|
||||
x = self.attn_out(x, latents)
|
||||
return self.c_head(self.ln_f(x))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Top-level shape VAE
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def generate_dense_grid_points(bbox_min, bbox_max, resolution_base, indexing="ij"):
|
||||
length = bbox_max - bbox_min
|
||||
num_cells = np.exp2(resolution_base)
|
||||
x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
|
||||
y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
|
||||
z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
|
||||
xs, ys, zs = np.meshgrid(x, y, z, indexing=indexing)
|
||||
xyz = np.stack((xs, ys, zs), axis=-1).reshape(-1, 3)
|
||||
grid_size = [int(num_cells) + 1] * 3
|
||||
return xyz, grid_size, length
|
||||
|
||||
|
||||
class CubeShapeVAE(nn.Module):
|
||||
"""Decode-only OneDAutoEncoder. Encoder weights load with strict=False (ignored)."""
|
||||
|
||||
def __init__(self, num_encoder_latents=1024, embed_dim=32, width=768, num_heads=12,
|
||||
num_freqs=128, num_decoder_layers=24, num_codes=16384, out_dim=1, eps=1e-6,
|
||||
dtype=None, device=None):
|
||||
super().__init__()
|
||||
self.cfg_num_encoder_latents = num_encoder_latents
|
||||
self.cfg_num_codes = num_codes
|
||||
self.embedder = PhaseModulatedFourierEmbedder(num_freqs=num_freqs, input_dim=3, dtype=dtype, device=device)
|
||||
self.bottleneck = OneDBottleNeck(
|
||||
SphericalVectorQuantizer(embed_dim, num_codes, width, dtype=dtype, device=device)
|
||||
)
|
||||
self.decoder = OneDDecoder(num_encoder_latents, width, num_heads, num_decoder_layers,
|
||||
eps=eps, dtype=dtype, device=device)
|
||||
self.occupancy_decoder = OneDOccupancyDecoder(self.embedder, out_dim, width, num_heads,
|
||||
eps=eps, dtype=dtype, device=device)
|
||||
|
||||
@torch.no_grad()
|
||||
def decode_indices(self, shape_ids):
|
||||
z_q = self.bottleneck.block.lookup_codebook(shape_ids)
|
||||
return self.decoder(z_q)
|
||||
|
||||
@torch.no_grad()
|
||||
def query(self, queries, latents):
|
||||
return self.occupancy_decoder(queries, latents).squeeze(-1)
|
||||
|
||||
@torch.no_grad()
|
||||
def extract_geometry(self, latents, bounds=(-1.05, -1.05, -1.05, 1.05, 1.05, 1.05),
|
||||
resolution_base=8.0, chunk_size=100_000):
|
||||
bbox_min = np.array(bounds[0:3])
|
||||
bbox_max = np.array(bounds[3:6])
|
||||
bbox_size = bbox_max - bbox_min
|
||||
|
||||
xyz, grid_size, _ = generate_dense_grid_points(bbox_min, bbox_max, resolution_base, indexing="ij")
|
||||
xyz = torch.from_numpy(xyz)
|
||||
batch_size = latents.shape[0]
|
||||
batch_logits = []
|
||||
for start in range(0, xyz.shape[0], chunk_size):
|
||||
queries = xyz[start:start + chunk_size, :]
|
||||
n = queries.shape[0]
|
||||
if start > 0 and n < chunk_size:
|
||||
queries = F.pad(queries, [0, 0, 0, chunk_size - n])
|
||||
bq = queries.unsqueeze(0).expand(batch_size, -1, -1).to(latents)
|
||||
batch_logits.append(self.query(bq, latents)[:, :n])
|
||||
|
||||
grid_logits = torch.cat(batch_logits, dim=1).detach().view(
|
||||
batch_size, grid_size[0], grid_size[1], grid_size[2]).float()
|
||||
return grid_logits, grid_size, bbox_size, bbox_min
|
||||
|
||||
|
||||
def grid_logits_to_mesh(grid_logit, grid_size, bbox_size, bbox_min, level=0.0):
|
||||
"""Marching cubes via skimage (matches upstream CPU fallback path)."""
|
||||
from skimage import measure
|
||||
vertices, faces, _, _ = measure.marching_cubes(grid_logit.cpu().numpy(), level, method="lewiner")
|
||||
vertices = vertices / np.array(grid_size) * bbox_size + bbox_min
|
||||
faces = faces[:, [2, 1, 0]]
|
||||
return vertices.astype(np.float32), np.ascontiguousarray(faces)
|
||||
@ -44,6 +44,7 @@ import comfy.ldm.lumina.model
|
||||
import comfy.ldm.wan.model
|
||||
import comfy.ldm.wan.model_animate
|
||||
import comfy.ldm.wan.ar_model
|
||||
import comfy.ldm.cube.gpt
|
||||
import comfy.ldm.wan.model_wandancer
|
||||
import comfy.ldm.hunyuan3d.model
|
||||
import comfy.ldm.triposplat.model
|
||||
@ -1903,6 +1904,26 @@ class Hunyuan3Dv2(BaseModel):
|
||||
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
|
||||
return out
|
||||
|
||||
class Cube3D(BaseModel):
|
||||
"""Roblox Cube3D shape GPT (autoregressive). Generation goes through the
|
||||
dedicated `cube` sampler (SamplerCustomAdvanced), never KSampler/apply_model."""
|
||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.cube.gpt.DualStreamRoformer)
|
||||
|
||||
def _apply_model(self, *args, **kwargs):
|
||||
raise RuntimeError(
|
||||
"Cube3D is an autoregressive token model. Use the 'cube' sampler "
|
||||
"(SamplerCube + SamplerCustomAdvanced), not KSampler."
|
||||
)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
return out
|
||||
|
||||
|
||||
class Hunyuan3Dv2_1(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3dv2_1.hunyuandit.HunYuanDiTPlain)
|
||||
|
||||
@ -654,6 +654,23 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
|
||||
return dit_config
|
||||
|
||||
if '{}shape_proj.weight'.format(key_prefix) in state_dict_keys and '{}lm_head.weight'.format(key_prefix) in state_dict_keys: # Roblox Cube3D shape GPT
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "cube3d"
|
||||
n_embd = state_dict['{}transformer.wte.weight'.format(key_prefix)].shape[1]
|
||||
dit_config["n_embd"] = n_embd
|
||||
dit_config["shape_model_vocab_size"] = state_dict['{}transformer.wte.weight'.format(key_prefix)].shape[0] - 3
|
||||
dit_config["n_layer"] = count_blocks(state_dict_keys, '{}transformer.dual_blocks.'.format(key_prefix) + '{}.')
|
||||
dit_config["n_single_layer"] = count_blocks(state_dict_keys, '{}transformer.single_blocks.'.format(key_prefix) + '{}.')
|
||||
head_dim = state_dict['{}transformer.dual_blocks.0.attn.pre_x.q_norm.weight'.format(key_prefix)].shape[0]
|
||||
dit_config["n_head"] = n_embd // head_dim
|
||||
dit_config["shape_model_embed_dim"] = state_dict['{}shape_proj.weight'.format(key_prefix)].shape[1]
|
||||
dit_config["text_model_embed_dim"] = state_dict['{}text_proj.weight'.format(key_prefix)].shape[1]
|
||||
dit_config["use_bbox"] = '{}bbox_proj.weight'.format(key_prefix) in state_dict_keys
|
||||
dit_config["bias"] = '{}text_proj.bias'.format(key_prefix) in state_dict_keys
|
||||
dit_config["rope_theta"] = 10000
|
||||
return dit_config
|
||||
|
||||
if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D
|
||||
in_shape = state_dict['{}latent_in.weight'.format(key_prefix)].shape
|
||||
dit_config = {}
|
||||
|
||||
21
comfy/sd.py
21
comfy/sd.py
@ -16,6 +16,7 @@ import comfy.ldm.cosmos.vae
|
||||
import comfy.ldm.wan.vae
|
||||
import comfy.ldm.wan.vae2_2
|
||||
import comfy.ldm.hunyuan3d.vae
|
||||
import comfy.ldm.cube.vae
|
||||
import comfy.ldm.triposplat.vae
|
||||
import comfy.ldm.ace.vae.music_dcae_pipeline
|
||||
import comfy.ldm.cogvideo.vae
|
||||
@ -489,6 +490,7 @@ class VAE:
|
||||
self.disable_offload = False
|
||||
self.not_video = False
|
||||
self.size = None
|
||||
self.cube3d = False
|
||||
|
||||
self.downscale_index_formula = None
|
||||
self.upscale_index_formula = None
|
||||
@ -777,6 +779,25 @@ class VAE:
|
||||
self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE()
|
||||
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
||||
# Roblox Cube3D shape tokenizer (OneDAutoEncoder, decode-only)
|
||||
elif "bottleneck.block.codebook.weight" in sd:
|
||||
self.cube3d = True
|
||||
self.latent_dim = 1
|
||||
embed_dim = sd["bottleneck.block.codebook.weight"].shape[1]
|
||||
num_codes = sd["bottleneck.block.codebook.weight"].shape[0]
|
||||
width = sd["bottleneck.block.c_out.weight"].shape[0]
|
||||
num_encoder_latents = sd["decoder.positional_encodings"].shape[0]
|
||||
head_dim = sd["decoder.blocks.0.attn.q_norm.weight"].shape[0]
|
||||
num_heads = width // head_dim
|
||||
num_freqs = sd["embedder.weight"].shape[1]
|
||||
num_decoder_layers = len({k.split(".")[2] for k in sd if k.startswith("decoder.blocks.")})
|
||||
self.first_stage_model = comfy.ldm.cube.vae.CubeShapeVAE(
|
||||
num_encoder_latents=num_encoder_latents, embed_dim=embed_dim, width=width,
|
||||
num_heads=num_heads, num_freqs=num_freqs, num_decoder_layers=num_decoder_layers,
|
||||
num_codes=num_codes,
|
||||
)
|
||||
self.memory_used_decode = lambda shape, dtype: (1000 * shape[1] * 768) * model_management.dtype_size(dtype)
|
||||
self.working_dtypes = [torch.float32]
|
||||
|
||||
elif "vocoder.backbone.channel_layers.0.0.bias" in sd: #Ace Step Audio
|
||||
self.first_stage_model = comfy.ldm.ace.vae.music_dcae_pipeline.MusicDCAE(source_sample_rate=44100)
|
||||
|
||||
@ -1550,6 +1550,30 @@ class Hunyuan3Dv2mini(Hunyuan3Dv2):
|
||||
|
||||
latent_format = latent_formats.Hunyuan3Dv2mini
|
||||
|
||||
|
||||
class Cube3D(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "cube3d",
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
|
||||
sampling_settings = {}
|
||||
|
||||
latent_format = latent_formats.LatentFormat
|
||||
|
||||
memory_usage_factor = 1.0
|
||||
|
||||
# Upstream keeps fp32 weights and uses bf16 autocast during the forward pass
|
||||
# (see sample_cube). Prefer fp32 weights for parity; bf16 is the low-VRAM fallback.
|
||||
supported_inference_dtypes = [torch.float32, torch.bfloat16]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.Cube3D(self, device=device)
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return None
|
||||
|
||||
class TripoSplat(supported_models_base.BASE):
|
||||
# Image -> 3D gaussian splat flow denoiser
|
||||
unet_config = {
|
||||
@ -2292,6 +2316,7 @@ models = [
|
||||
Hunyuan3Dv2mini,
|
||||
Hunyuan3Dv2,
|
||||
Hunyuan3Dv2_1,
|
||||
Cube3D,
|
||||
TripoSplat,
|
||||
HiDream,
|
||||
HiDreamO1,
|
||||
|
||||
153
comfy_extras/nodes_cube.py
Normal file
153
comfy_extras/nodes_cube.py
Normal file
@ -0,0 +1,153 @@
|
||||
"""
|
||||
Nodes for native Roblox Cube3D text-to-3D support.
|
||||
|
||||
Graph:
|
||||
CLIPLoader(clip-l) -> CLIPTextEncode -> CONDITIONING
|
||||
UNETLoader(shape_gpt) -> MODEL --\
|
||||
VAELoader(shape_tokenizer) -> VAE -> CubeCodebookPatch -> MODEL
|
||||
CFGGuider(MODEL, pos, neg, cfg) + SamplerCube + (trivial sigmas) + EmptyCubeLatent
|
||||
-> SamplerCustomAdvanced -> LATENT (token IDs)
|
||||
VAEDecodeCube(VAE, LATENT) -> MESH -> SaveGLB
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
import comfy.ldm.cube.vae
|
||||
import comfy.model_management
|
||||
import comfy.samplers
|
||||
from comfy_api.latest import ComfyExtension, IO, Types
|
||||
from comfy_extras.nodes_save_3d import pack_variable_mesh_batch
|
||||
|
||||
|
||||
class EmptyCubeLatent(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="EmptyCubeLatent",
|
||||
category="latent/3d",
|
||||
inputs=[
|
||||
IO.Int.Input("num_tokens", default=1024, min=1, max=8192,
|
||||
tooltip="Shape token sequence length. Must match the tokenizer "
|
||||
"(1024 for cube3d-v0.5, 512 for v0.1)."),
|
||||
IO.Int.Input("batch_size", default=1, min=1, max=64),
|
||||
],
|
||||
outputs=[IO.Latent.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, num_tokens, batch_size) -> IO.NodeOutput:
|
||||
latent = torch.zeros([batch_size, num_tokens], device=comfy.model_management.intermediate_device())
|
||||
return IO.NodeOutput({"samples": latent, "type": "cube_tokens"})
|
||||
|
||||
|
||||
class CubeCodebookPatch(IO.ComfyNode):
|
||||
"""Inject the projected VQ codebook into the GPT token-embedding table.
|
||||
|
||||
Upstream copies shape_proj(tokenizer.codebook) into wte.weight[:num_codes] at load
|
||||
time; without it generation is garbage. Done here as a ModelPatcher object patch so
|
||||
it composes with normal model loading/offload."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="CubeCodebookPatch",
|
||||
display_name="Cube Codebook Patch",
|
||||
category="advanced/model",
|
||||
inputs=[
|
||||
IO.Model.Input("model"),
|
||||
IO.Vae.Input("vae"),
|
||||
],
|
||||
outputs=[IO.Model.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, vae) -> IO.NodeOutput:
|
||||
gpt = model.get_model_object("diffusion_model")
|
||||
codebook = vae.first_stage_model.bottleneck.block.get_codebook() # (num_codes, embed_dim) fp32
|
||||
w = gpt.shape_proj.weight
|
||||
proj = gpt.shape_proj(codebook.to(device=w.device, dtype=w.dtype)) # (num_codes, n_embd)
|
||||
|
||||
old = model.get_model_object("diffusion_model.transformer.wte.weight")
|
||||
new = old.clone()
|
||||
new[:proj.shape[0]] = proj.to(device=new.device, dtype=new.dtype)
|
||||
|
||||
m = model.clone()
|
||||
m.add_object_patch("diffusion_model.transformer.wte.weight",
|
||||
torch.nn.Parameter(new, requires_grad=False))
|
||||
return IO.NodeOutput(m)
|
||||
|
||||
|
||||
class SamplerCube(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="SamplerCube",
|
||||
display_name="Sampler Cube (autoregressive)",
|
||||
category="sampling/custom_sampling/samplers",
|
||||
inputs=[
|
||||
IO.Float.Input("top_p", default=1.0, min=0.0, max=1.0, step=0.01,
|
||||
tooltip="1.0 = deterministic greedy (upstream default). "
|
||||
"<1.0 enables nucleus sampling."),
|
||||
],
|
||||
outputs=[IO.Sampler.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, top_p) -> IO.NodeOutput:
|
||||
return IO.NodeOutput(comfy.samplers.ksampler("cube", {"top_p": top_p}))
|
||||
|
||||
|
||||
class VAEDecodeCube(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="VAEDecodeCube",
|
||||
display_name="VAE Decode Cube (3D)",
|
||||
category="latent/3d",
|
||||
inputs=[
|
||||
IO.Vae.Input("vae"),
|
||||
IO.Latent.Input("samples"),
|
||||
IO.Float.Input("resolution_base", default=8.0, min=4.0, max=10.0, step=0.5,
|
||||
tooltip="Grid cells per axis = 2^resolution_base. 8.0 matches "
|
||||
"upstream default (257^3 grid)."),
|
||||
IO.Int.Input("chunk_size", default=100000, min=1000, max=2000000, advanced=True),
|
||||
],
|
||||
outputs=[IO.Mesh.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, vae, samples, resolution_base, chunk_size) -> IO.NodeOutput:
|
||||
comfy.model_management.load_models_gpu([vae.patcher])
|
||||
tok = vae.first_stage_model
|
||||
ids = samples["samples"][:, :tok.cfg_num_encoder_latents].long()
|
||||
ids = ids.clamp(0, tok.cfg_num_codes - 1).to(vae.device)
|
||||
|
||||
latents = tok.decode_indices(ids)
|
||||
grid, grid_size, bbox_size, bbox_min = tok.extract_geometry(
|
||||
latents, resolution_base=resolution_base, chunk_size=chunk_size)
|
||||
|
||||
verts_list, faces_list = [], []
|
||||
for i in range(grid.shape[0]):
|
||||
v, f = comfy.ldm.cube.vae.grid_logits_to_mesh(grid[i], grid_size, bbox_size, bbox_min)
|
||||
verts_list.append(torch.from_numpy(v))
|
||||
faces_list.append(torch.from_numpy(f.astype(np.int64)))
|
||||
|
||||
mesh = pack_variable_mesh_batch(verts_list, faces_list)
|
||||
return IO.NodeOutput(mesh)
|
||||
|
||||
|
||||
class CubeExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
EmptyCubeLatent,
|
||||
CubeCodebookPatch,
|
||||
SamplerCube,
|
||||
VAEDecodeCube,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> CubeExtension:
|
||||
return CubeExtension()
|
||||
Loading…
Reference in New Issue
Block a user