mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-23 16:29:25 +08:00
Cube3D is an autoregressive VQ-token shape model (DualStreamRoformer) plus a VQ-VAE shape tokenizer (OneDAutoEncoder), not a diffusion model. It is wired natively following the Causal-WAN AR-video pattern: the GPT loads as a normal MODEL and generation runs through a dedicated 'cube' sampler instead of KSampler. - comfy/ldm/cube/gpt.py: DualStreamRoformer port (dual-stream RoPE attention, per-head RMSNorm, SwiGLU, KV cache; rope_theta=10000). - comfy/ldm/cube/vae.py: OneDAutoEncoder decode path (codebook lookup, decoder, occupancy decoder, dense-grid extraction + skimage marching cubes). - model_detection/supported_models/model_base: register shape_gpt as Cube3D MODEL (dims inferred from state dict; apply_model guarded to point at SamplerCube). - sd.py: detect shape_tokenizer and build CubeShapeVAE. - k_diffusion/sampling.py: sample_cube autoregressive sampler (decaying CFG + optional top-p), faithful to upstream Engine.run_gpt. - comfy_extras/nodes_cube.py: EmptyCubeLatent, CubeCodebookPatch (inject VQ codebook into wte), SamplerCube, VAEDecodeCube (-> MESH). Reuses CLIP-L conditioning, CFGGuider/SamplerCustomAdvanced, and SaveGLB. Amp-Thread-ID: https://ampcode.com/threads/T-019ec361-addb-70d8-a74b-438ce8a1e096 Co-authored-by: Amp <amp@ampcode.com>
418 lines
17 KiB
Python
418 lines
17 KiB
Python
"""
|
|
Native port of Roblox/cube's shape GPT (DualStreamRoformer).
|
|
|
|
Reference: https://github.com/Roblox/cube (cube3d/model/gpt/dual_stream_roformer.py
|
|
and cube3d/model/transformers/*).
|
|
|
|
This is an autoregressive transformer over discrete VQ shape tokens, conditioned on
|
|
CLIP text embeddings. It is NOT a diffusion model; it is driven by the dedicated
|
|
`sample_cube` sampler (see comfy/k_diffusion/sampling.py), not KSampler.
|
|
|
|
The forward pass is kept faithful to upstream so token IDs match bit-for-bit:
|
|
* rope_theta = 10000
|
|
* per-head RMSNorm on Q and K
|
|
* dual-stream (MM-DiT style) joint attention; last dual block is cond_pre_only
|
|
* two separate RoPE frequency tensors (dual blocks offset cond tokens by S)
|
|
* SwiGLU MLP, non-affine LayerNorm upcast to fp32
|
|
"""
|
|
|
|
from typing import Optional
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Norms (faithful to cube3d/model/transformers/norm.py)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class CubeLayerNorm(nn.Module):
    """LayerNorm without learnable scale or shift, evaluated in float32.

    The input is upcast to float32, normalized over its trailing ``dim``
    features, and the result is cast back to the input's type.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        # F.layer_norm takes the normalized shape as a sequence, not an int.
        self.dim = (dim,)
        self.eps = eps

    def forward(self, x):
        """Return ``x`` layer-normalized over its last axis, in ``x``'s type."""
        normalized = F.layer_norm(
            x.to(torch.float32),
            self.dim,
            weight=None,
            bias=None,
            eps=self.eps,
        )
        return normalized.type_as(x)
|
|
|
|
|
|
class CubeRMSNorm(nn.Module):
    """RMS normalization over the last axis with a learnable gain.

    The statistics are computed on a float32 copy of the input; the scaled
    result is cast back to the input's type before it is returned.
    """

    def __init__(self, dim, eps=1e-5, dtype=None, device=None):
        super().__init__()
        self.eps = eps
        # Gain starts at one, so a fresh module only rescales by the RMS.
        gain = torch.ones(dim, dtype=dtype, device=device)
        self.weight = nn.Parameter(gain)

    def forward(self, x):
        """Return ``x / rms(x) * weight`` computed in float32, in ``x``'s type."""
        x32 = x.float()
        mean_square = x32.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        scaled = (x32 * inv_rms) * self.weight
        return scaled.type_as(x)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# RoPE (faithful to cube3d/model/transformers/rope.py)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def apply_rotary_emb(x, freqs_cis, curr_pos_id=None):
    """Rotate consecutive feature pairs of ``x`` by complex rotary factors.

    The last axis of ``x`` is viewed as (real, imag) pairs, multiplied by the
    selected entries of ``freqs_cis`` and flattened back to its original
    layout. The arithmetic runs in float32 and the result is returned in
    ``x``'s type.

    Args:
        x: Tensor whose axis 2 is used as the sequence axis and whose last
            axis holds an even number of features (the indexing below
            treats it as 4-D).
        freqs_cis: Complex rotation factors indexed as ``[batch, position, pair]``.
        curr_pos_id: When given, the positions to pick from ``freqs_cis``;
            otherwise the last ``x.shape[2]`` positions are used.
    """
    pairs = x.float().reshape(*x.shape[:-1], -1, 2)
    x_complex = torch.view_as_complex(pairs)

    if curr_pos_id is not None:
        rotation = freqs_cis[:, curr_pos_id, :]
    else:
        seq_len = x.shape[2]
        rotation = freqs_cis[:, -seq_len:]
    # Insert a singleton axis at dim 1 so the factors broadcast over x's dim 1.
    rotation = rotation.unsqueeze(1)

    rotated = torch.view_as_real(x_complex * rotation).flatten(3)
    return rotated.type_as(x)
|
|
|
|
|
|
def precompute_freqs_cis(dim, t, theta=10000.0):
    """Build unit-magnitude complex rotary factors for the positions in ``t``.

    Args:
        dim: Feature width being rotated; one frequency is produced for every
            second index in ``range(dim)``.
        t: Tensor of positions (any shape); the frequencies are created on
            its device.
        theta: Base of the geometric frequency progression.

    Returns:
        Complex tensor of shape ``(*t.shape, n_freqs)`` whose angle is
        ``position * theta ** (-2i / dim)`` and whose magnitude is one.
    """
    exponents = torch.arange(0, dim, 2, dtype=torch.float32, device=t.device) / dim
    inv_freq = 1.0 / (theta ** exponents)
    flat_positions = t.contiguous().view(-1)
    angles = torch.outer(flat_positions, inv_freq).reshape(*t.shape, -1)
    unit_magnitude = torch.ones_like(angles)
    return torch.polar(unit_magnitude, angles)
|
|
|
|
|
|
def sdpa_with_rope(q, k, v, freqs_cis, attn_mask=None, curr_pos_id=None, is_causal=False):
    """Apply rotary embeddings to ``q`` and ``k``, then run scaled dot-product attention.

    ``q`` is rotated at ``curr_pos_id`` when one is given; ``k`` is always
    rotated with the trailing ``k.shape[2]`` positions of ``freqs_cis``.
    ``v`` is passed through untouched. Built-in causal masking is only
    requested when no explicit ``attn_mask`` is supplied.

    NOTE(review): ``is_causal`` is forwarded even when ``q`` is shorter than
    ``k``. PyTorch documents that for a non-square mask the causal triangle
    is aligned to the upper-left corner - confirm that this is the intended
    behavior for callers that decode one token against a longer key buffer.
    """
    rotated_q = apply_rotary_emb(q, freqs_cis, curr_pos_id=curr_pos_id)
    rotated_k = apply_rotary_emb(k, freqs_cis, curr_pos_id=None)
    # An explicit mask and the is_causal flag must not be combined.
    use_builtin_causal = is_causal and attn_mask is None
    return F.scaled_dot_product_attention(
        rotated_q,
        rotated_k,
        v,
        attn_mask=attn_mask,
        dropout_p=0.0,
        is_causal=use_builtin_causal,
    )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# KV cache
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class Cache:
    """Holder for one layer's pre-allocated key and value buffers."""

    def __init__(self, key_states, value_states):
        self.key_states, self.value_states = key_states, value_states

    def update(self, curr_pos_id, k, v):
        """Write ``k`` and ``v`` in place at index ``curr_pos_id`` along dim 2."""
        for buffer, fresh in ((self.key_states, k), (self.value_states, v)):
            buffer.index_copy_(2, curr_pos_id, fresh)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Shared building blocks
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class SwiGLUMLP(nn.Module):
    """Gated feed-forward block: ``down(silu(gate(x)) * up(x))``.

    ``operations`` is the namespace that supplies the ``Linear`` layer class.
    """

    def __init__(self, embed_dim, hidden_dim, bias=True, dtype=None, device=None, operations=None):
        super().__init__()
        linear_kwargs = dict(bias=bias, dtype=dtype, device=device)
        self.gate_proj = operations.Linear(embed_dim, hidden_dim, **linear_kwargs)
        self.up_proj = operations.Linear(embed_dim, hidden_dim, **linear_kwargs)
        self.down_proj = operations.Linear(hidden_dim, embed_dim, **linear_kwargs)

    def forward(self, x):
        """Project ``x`` up through the SiLU gate and back down to ``embed_dim``."""
        gate = F.silu(self.gate_proj(x))
        gated = gate * self.up_proj(x)
        return self.down_proj(gated)
|
|
|
|
|
|
class SelfAttentionWithRotaryEmbedding(nn.Module):
    """Multi-head self-attention with per-head RMSNorm on Q/K and rotary embeddings.

    Queries and keys come from one fused bias-free projection (``c_qk``);
    values and the output use their own projections. ``eps`` is accepted for
    signature parity with the other layers but is not used by this class.
    """

    def __init__(self, embed_dim, num_heads, bias=True, eps=1e-6, dtype=None, device=None, operations=None):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        per_head = embed_dim // num_heads
        factory = dict(dtype=dtype, device=device)
        self.c_qk = operations.Linear(embed_dim, 2 * embed_dim, bias=False, **factory)
        self.c_v = operations.Linear(embed_dim, embed_dim, bias=bias, **factory)
        self.c_proj = operations.Linear(embed_dim, embed_dim, bias=bias, **factory)
        self.q_norm = CubeRMSNorm(per_head, **factory)
        self.k_norm = CubeRMSNorm(per_head, **factory)

    def _split_heads(self, t, batch, length):
        """Reshape (batch, length, embed) into (batch, heads, length, head_dim)."""
        return t.view(batch, length, self.num_heads, -1).transpose(1, 2)

    def forward(self, x, freqs_cis, attn_mask=None, is_causal=False, kv_cache=None, curr_pos_id=None, decode=False):
        """Attend over ``x`` (batch, length, embed), optionally through a KV cache.

        With a cache and ``decode=False`` the fresh keys/values fill the
        leading cache slots; with ``decode=True`` they are written at
        ``curr_pos_id``. In both cases attention then runs against the whole
        cache buffers.
        """
        batch, length, width = x.shape
        q, k = self.c_qk(x).chunk(2, dim=-1)
        v = self.c_v(x)

        q = self.q_norm(self._split_heads(q, batch, length))
        k = self.k_norm(self._split_heads(k, batch, length))
        v = self._split_heads(v, batch, length)

        if kv_cache is not None:
            if decode:
                kv_cache.update(curr_pos_id, k, v)
            else:
                filled = k.shape[2]
                kv_cache.key_states[:, :, :filled, :].copy_(k)
                kv_cache.value_states[:, :, :filled, :].copy_(v)
            k, v = kv_cache.key_states, kv_cache.value_states

        # Only the decode path rotates the query at an explicit position.
        query_pos = curr_pos_id if decode else None
        attended = sdpa_with_rope(
            q, k, v,
            freqs_cis=freqs_cis,
            attn_mask=attn_mask,
            curr_pos_id=query_pos,
            is_causal=is_causal,
        )
        merged = attended.transpose(1, 2).contiguous().view(batch, length, width)
        return self.c_proj(merged)
|
|
|
|
|
|
class DecoderLayerWithRotaryEmbedding(nn.Module):
    """Pre-norm decoder layer for a single stream: attention then SwiGLU MLP.

    Each sub-layer is applied to a layer-normalized input and added back to
    the residual.
    """

    def __init__(self, embed_dim, num_heads, bias=True, eps=1e-6, dtype=None, device=None, operations=None):
        super().__init__()
        shared = dict(bias=bias, dtype=dtype, device=device, operations=operations)
        self.ln_1 = CubeLayerNorm(embed_dim, eps=eps)
        self.attn = SelfAttentionWithRotaryEmbedding(embed_dim, num_heads, eps=eps, **shared)
        self.ln_2 = CubeLayerNorm(embed_dim, eps=eps)
        self.mlp = SwiGLUMLP(embed_dim, embed_dim * 4, **shared)

    def forward(self, x, freqs_cis, attn_mask=None, is_causal=True, kv_cache=None, curr_pos_id=None, decode=False):
        """Return ``x`` after the attention and MLP residual updates."""
        attn_out = self.attn(
            self.ln_1(x),
            freqs_cis=freqs_cis,
            attn_mask=attn_mask,
            is_causal=is_causal,
            kv_cache=kv_cache,
            curr_pos_id=curr_pos_id,
            decode=decode,
        )
        x = x + attn_out
        mlp_out = self.mlp(self.ln_2(x))
        return x + mlp_out
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Dual-stream blocks (faithful to dual_stream_attention.py)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class DismantledPreAttention(nn.Module):
    """Projection half of an attention layer: produces per-head (q, k, v).

    With ``query=True`` queries and keys come from one fused bias-free
    projection and the query is RMS-normalized; with ``query=False`` only
    keys and values are produced and ``q`` is returned as ``None``.
    """

    def __init__(self, embed_dim, num_heads, query=True, bias=True, dtype=None, device=None, operations=None):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.query = query
        per_head = embed_dim // num_heads
        factory = dict(dtype=dtype, device=device)
        if query:
            self.c_qk = operations.Linear(embed_dim, 2 * embed_dim, bias=False, **factory)
            self.q_norm = CubeRMSNorm(per_head, **factory)
        else:
            self.c_k = operations.Linear(embed_dim, embed_dim, bias=bias, **factory)
        self.k_norm = CubeRMSNorm(per_head, **factory)
        self.c_v = operations.Linear(embed_dim, embed_dim, bias=bias, **factory)
        self.num_heads = num_heads

    def _to_mha(self, x):
        """Reshape (batch, length, embed) into (batch, heads, length, head_dim)."""
        batch, length = x.shape[:2]
        return x.view(batch, length, self.num_heads, -1).transpose(1, 2)

    def forward(self, x):
        """Return the tuple ``(q, k, v)``; ``q`` is ``None`` when ``query`` is False."""
        if self.query:
            q_flat, k_flat = self.c_qk(x).chunk(2, dim=-1)
            q = self.q_norm(self._to_mha(q_flat))
        else:
            q, k_flat = None, self.c_k(x)
        k = self.k_norm(self._to_mha(k_flat))
        v = self._to_mha(self.c_v(x))
        return (q, k, v)
|
|
|
|
|
|
class DismantledPostAttention(nn.Module):
    """Output half of an attention layer: projection residual, then MLP residual."""

    def __init__(self, embed_dim, bias=True, eps=1e-6, dtype=None, device=None, operations=None):
        super().__init__()
        factory = dict(dtype=dtype, device=device)
        self.c_proj = operations.Linear(embed_dim, embed_dim, bias=bias, **factory)
        self.ln_3 = CubeLayerNorm(embed_dim, eps=eps)
        self.mlp = SwiGLUMLP(embed_dim, embed_dim * 4, bias=bias, operations=operations, **factory)

    def forward(self, x, a):
        """Add the projected attention output ``a`` to ``x``, then the MLP update."""
        after_attn = x + self.c_proj(a)
        return after_attn + self.mlp(self.ln_3(after_attn))
|
|
|
|
|
|
class DualStreamAttentionWithRotaryEmbedding(nn.Module):
    """Joint attention over a conditioning stream ``c`` and a main stream ``x``.

    Each stream has its own projections. Keys and values of both streams are
    concatenated along the sequence axis (conditioning first). When
    ``cond_pre_only`` is set the conditioning stream contributes no queries,
    so only the ``x`` stream receives an attention output.
    """

    def __init__(self, embed_dim, num_heads, cond_pre_only=False, bias=True, dtype=None, device=None, operations=None):
        super().__init__()
        self.cond_pre_only = cond_pre_only
        shared = dict(bias=bias, dtype=dtype, device=device, operations=operations)
        self.pre_x = DismantledPreAttention(embed_dim, num_heads, query=True, **shared)
        self.pre_c = DismantledPreAttention(embed_dim, num_heads, query=not cond_pre_only, **shared)

    def forward(self, x, c, freqs_cis, attn_mask=None, is_causal=False, kv_cache=None, curr_pos_id=None, decode=False):
        """Return ``(y_x, y_c)``; ``y_c`` is ``None`` when only ``x`` was queried.

        Decoding with a cache projects only ``x`` and writes its key/value at
        ``curr_pos_id``; every other call projects both streams and, if a
        cache is present, fills its leading slots.
        """
        cached_decode = kv_cache is not None and decode
        if cached_decode:
            # Causality is then carried by the sliced mask row, not the flag.
            is_causal = False
            q, k, v = self.pre_x(x)
        else:
            c_q, c_k, c_v = self.pre_c(c)
            x_q, x_k, x_v = self.pre_x(x)
            q = x_q if self.cond_pre_only else torch.cat([c_q, x_q], dim=2)
            k = torch.cat([c_k, x_k], dim=2)
            v = torch.cat([c_v, x_v], dim=2)

        if kv_cache is not None:
            if decode:
                kv_cache.update(curr_pos_id, k, v)
            else:
                filled = k.shape[2]
                kv_cache.key_states[:, :, :filled, :].copy_(k)
                kv_cache.value_states[:, :, :filled, :].copy_(v)
            k, v = kv_cache.key_states, kv_cache.value_states

        if attn_mask is not None:
            # Keep only the mask rows that belong to the queries being run.
            if decode:
                attn_mask = attn_mask[..., curr_pos_id, :]
            else:
                n_queries = q.shape[2]
                attn_mask = attn_mask[..., -n_queries:, :]

        query_pos = curr_pos_id if decode else None
        attended = sdpa_with_rope(
            q, k, v,
            freqs_cis=freqs_cis,
            attn_mask=attn_mask,
            curr_pos_id=query_pos,
            is_causal=is_causal,
        )
        batch, x_len, width = x.shape[0], x.shape[1], x.shape[2]
        y = attended.transpose(1, 2).contiguous().view(batch, -1, width)

        if y.shape[1] == x_len:
            return y, None
        # Outputs follow the concatenation order: conditioning first, then x.
        y_c, y_x = torch.split(y, [c.shape[1], x_len], dim=1)
        return y_x, y_c
|
|
|
|
|
|
class DualStreamDecoderLayerWithRotaryEmbedding(nn.Module):
    """Decoder layer that updates a main stream ``x`` and a conditioning stream ``c``.

    Both streams are layer-normalized, attended jointly, and then finished by
    their own post-attention block. With ``cond_pre_only`` no post block is
    created for the conditioning stream.
    """

    def __init__(self, embed_dim, num_heads, cond_pre_only=False, bias=True, eps=1e-6,
                 dtype=None, device=None, operations=None):
        super().__init__()
        shared = dict(bias=bias, dtype=dtype, device=device, operations=operations)
        self.ln_1 = CubeLayerNorm(embed_dim, eps=eps)
        self.ln_2 = CubeLayerNorm(embed_dim, eps=eps)
        self.attn = DualStreamAttentionWithRotaryEmbedding(
            embed_dim, num_heads, cond_pre_only=cond_pre_only, **shared
        )
        self.post_1 = DismantledPostAttention(embed_dim, eps=eps, **shared)
        if not cond_pre_only:
            self.post_2 = DismantledPostAttention(embed_dim, eps=eps, **shared)

    def forward(self, x, c, freqs_cis, attn_mask=None, is_causal=True, kv_cache=None, curr_pos_id=None, decode=False):
        """Return the updated ``(x, c)``; ``c`` is ``None`` if attention produced no output for it."""
        normed_x = self.ln_1(x)
        normed_c = None if c is None else self.ln_2(c)
        a_x, a_c = self.attn(
            normed_x,
            normed_c,
            freqs_cis=freqs_cis,
            attn_mask=attn_mask,
            is_causal=is_causal,
            kv_cache=kv_cache,
            curr_pos_id=curr_pos_id,
            decode=decode,
        )
        x = self.post_1(x, a_x)
        c = self.post_2(c, a_c) if a_c is not None else None
        return x, c
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# DualStreamRoformer
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class DualStreamRoformer(nn.Module):
    """Autoregressive shape-token transformer with dual- and single-stream blocks.

    The model runs ``n_layer`` dual-stream blocks (conditioning + shape
    tokens) followed by ``n_single_layer`` single-stream blocks (shape tokens
    only), a final non-affine LayerNorm and a bias-free LM head. The last
    dual block is built with ``cond_pre_only=True``.

    Three special ids are appended after the ``shape_model_vocab_size``
    codebook entries: BOS, EOS and padding (in that order), so
    ``vocab_size`` ends up three larger than the codebook.
    """

    def __init__(
        self,
        n_layer=23,
        n_single_layer=1,
        rope_theta=10000,
        n_head=12,
        n_embd=1536,
        bias=True,
        eps=1e-6,
        shape_model_vocab_size=16384,
        shape_model_embed_dim=32,
        text_model_embed_dim=768,
        use_bbox=True,
        image_model=None,  # detection key; unused
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        self.dtype = dtype
        self.n_layer = n_layer
        self.n_single_layer = n_single_layer
        self.n_head = n_head
        self.n_embd = n_embd
        self.rope_theta = rope_theta
        self.head_dim = n_embd // n_head

        factory = dict(dtype=dtype, device=device)
        self.text_proj = operations.Linear(text_model_embed_dim, n_embd, bias=bias, **factory)
        self.shape_proj = operations.Linear(shape_model_embed_dim, n_embd, bias=True, **factory)

        # Special ids live directly after the codebook entries.
        self.shape_bos_id = shape_model_vocab_size
        self.shape_eos_id = shape_model_vocab_size + 1
        self.padding_id = shape_model_vocab_size + 2
        self.vocab_size = shape_model_vocab_size + 3

        block_kwargs = dict(bias=bias, eps=eps, operations=operations, **factory)
        self.transformer = nn.ModuleDict({
            "wte": operations.Embedding(self.vocab_size, n_embd, padding_idx=self.padding_id, **factory),
            "dual_blocks": nn.ModuleList([
                # Only the final dual block skips the conditioning query/output path.
                DualStreamDecoderLayerWithRotaryEmbedding(
                    n_embd, n_head, cond_pre_only=(i == n_layer - 1), **block_kwargs
                )
                for i in range(n_layer)
            ]),
            "single_blocks": nn.ModuleList([
                DecoderLayerWithRotaryEmbedding(n_embd, n_head, **block_kwargs)
                for _ in range(n_single_layer)
            ]),
            "ln_f": CubeLayerNorm(n_embd, eps=eps),
        })

        self.lm_head = operations.Linear(n_embd, self.vocab_size, bias=False, **factory)

        self.use_bbox = use_bbox
        if use_bbox:
            self.bbox_proj = operations.Linear(3, n_embd, bias=True, **factory)

    def encode_text(self, text_embed):
        """Project text embeddings to the model width via ``text_proj``."""
        return self.text_proj(text_embed)

    def encode_token(self, tokens):
        """Look up token ids in the ``wte`` embedding table."""
        return self.transformer.wte(tokens)

    def init_kv_cache(self, batch_size, cond_len, max_shape_tokens, dtype, device):
        """Allocate one zero-filled ``Cache`` per block.

        Dual-block caches hold ``cond_len + max_shape_tokens`` positions;
        single-block caches hold ``max_shape_tokens``. The returned list is
        ordered dual blocks first, then single blocks, matching the order in
        which ``forward`` indexes it.
        """
        def new_cache(length):
            shape = (batch_size, self.n_head, length, self.head_dim)
            return Cache(
                torch.zeros(shape, dtype=dtype, device=device),
                torch.zeros(shape, dtype=dtype, device=device),
            )

        dual = [new_cache(cond_len + max_shape_tokens) for _ in self.transformer.dual_blocks]
        single = [new_cache(max_shape_tokens) for _ in self.transformer.single_blocks]
        return dual + single

    def forward(self, embed, cond, kv_cache=None, curr_pos_id=None, decode=False):
        """Run the transformer and return LM-head logits.

        Args:
            embed: Shape-token embeddings; dims 0 and 1 are read as batch and
                sequence length.
            cond: Conditioning embeddings; dim 1 is read as its length.
            kv_cache: Optional list of ``Cache`` objects as produced by
                ``init_kv_cache``.
            curr_pos_id: Position index (within ``embed``) being decoded;
                presumably a 1-D long tensor, since it is used both for
                indexing and for ``index_copy_`` - confirm against the caller.
            decode: When true and a cache is given, only the ``curr_pos_id``
                slice of ``embed`` is pushed through the blocks.

        Raises:
            ValueError: If ``decode`` is requested with a cache but without
                ``curr_pos_id``.
        """
        if decode and kv_cache is not None and curr_pos_id is None:
            # Without this guard embed[:, None, :] would silently add an axis
            # and the failure would surface much later inside index_copy_.
            raise ValueError("curr_pos_id is required when decoding with a kv_cache")

        batch, seq_len = embed.shape[:2]
        cond_len = cond.shape[1]
        device = embed.device

        total = cond_len + seq_len
        attn_mask = torch.tril(torch.ones(total, total, dtype=torch.bool, device=device))

        # Single-stream blocks see positions 0..seq_len-1.
        position_ids = torch.arange(seq_len, dtype=torch.long, device=device).unsqueeze(0).expand(batch, -1)
        s_freqs_cis = precompute_freqs_cis(self.head_dim, position_ids, theta=self.rope_theta)

        # Dual-stream blocks prepend the conditioning tokens, all at position 0.
        position_ids = torch.cat([
            torch.zeros([batch, cond_len], dtype=torch.long, device=device),
            position_ids,
        ], dim=1)
        d_freqs_cis = precompute_freqs_cis(self.head_dim, position_ids, theta=self.rope_theta)

        if kv_cache is not None and decode:
            embed = embed[:, curr_pos_id, :]

        # Dual blocks index the joint (cond + shape) sequence, so shift by cond_len.
        dual_pos_id = curr_pos_id + cond_len if curr_pos_id is not None else None

        h, c = embed, cond
        dual_blocks = self.transformer.dual_blocks
        for idx, block in enumerate(dual_blocks):
            h, c = block(
                h, c=c, freqs_cis=d_freqs_cis, attn_mask=attn_mask, is_causal=True,
                kv_cache=kv_cache[idx] if kv_cache is not None else None,
                curr_pos_id=dual_pos_id,
                decode=decode,
            )
        # Single-block caches follow the dual-block caches in the list.
        for idx, block in enumerate(self.transformer.single_blocks, start=len(dual_blocks)):
            h = block(
                h, freqs_cis=s_freqs_cis, attn_mask=None, is_causal=True,
                kv_cache=kv_cache[idx] if kv_cache is not None else None,
                curr_pos_id=curr_pos_id, decode=decode,
            )

        return self.lm_head(self.transformer.ln_f(h))
|