mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-27 02:10:08 +08:00
Replace cube's bespoke complex-number RoPE (torch.polar / view_as_complex) with
ComfyUI's shared Flux rotary embedding (comfy.ldm.flux.math):
* precompute_freqs_cis now returns Flux's real rotation freqs via rope().
* apply_rotary_emb applies them via apply_rope1, which at inference dispatches to
comfy-kitchen's optimized apply_rope kernel (comfy.quant_ops.ck). q and k are
still rotated separately to preserve the decode-time position asymmetry.
The pairing convention (adjacent dims) and rotation math are identical, so token
outputs are unchanged. The only numerical difference is that rope() computes the
rotation angles in fp64 before casting to fp32 (cube's original used fp32), so output
now matches upstream to fp32 rounding (~1e-6 on rotated q/k in a standalone check)
rather than bit-for-bit. Greedy argmax token selection is unaffected.
Deviation note: this is a deliberate, documented divergence from a strict upstream
port, taken to gain the shared optimized kernel. Needs GPU parity re-validation on the
2x4090 box (kosin-X570-AORUS-ULTRA) before merge.
Co-authored-by: Amp <amp@ampcode.com>
Amp-Thread-ID: https://ampcode.com/threads/T-019f013b-5892-71b9-af6b-c2ef28c67d2b
437 lines
18 KiB
Python
437 lines
18 KiB
Python
"""
|
|
Native port of Roblox/cube's shape GPT (DualStreamRoformer).
|
|
|
|
Reference: https://github.com/Roblox/cube (cube3d/model/gpt/dual_stream_roformer.py
|
|
and cube3d/model/transformers/*).
|
|
|
|
This is an autoregressive transformer over discrete VQ shape tokens, conditioned on
|
|
CLIP text embeddings. It is NOT a diffusion model; it is driven by the dedicated
|
|
`sample_cube` sampler (see comfy/k_diffusion/sampling.py), not KSampler.
|
|
|
|
The forward pass is kept faithful to upstream so token IDs match:
|
|
* rope_theta = 10000
|
|
* per-head RMSNorm on Q and K
|
|
* dual-stream (MM-DiT style) joint attention; last dual block is cond_pre_only
|
|
* two separate RoPE frequency tensors (dual blocks offset cond tokens by S)
|
|
* SwiGLU MLP, non-affine LayerNorm upcast to fp32
|
|
|
|
RoPE reuses ComfyUI's shared Flux rotary embedding (comfy.ldm.flux.math) so it
|
|
benefits from comfy-kitchen's optimized apply_rope kernel; see the RoPE section below.
|
|
"""
|
|
|
|
from typing import Optional
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
import comfy.ldm.flux.math
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Norms (faithful to cube3d/model/transformers/norm.py)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class CubeLayerNorm(nn.Module):
|
|
"""Non-affine LayerNorm that upcasts to fp32 then back (matches cube)."""
|
|
|
|
def __init__(self, dim, eps=1e-6):
|
|
super().__init__()
|
|
self.dim = (dim,)
|
|
self.eps = eps
|
|
|
|
def forward(self, x):
|
|
y = F.layer_norm(x.float(), self.dim, None, None, self.eps)
|
|
return y.type_as(x)
|
|
|
|
|
|
class CubeRMSNorm(nn.Module):
|
|
"""Per-head RMSNorm with learnable weight, computed in fp32 (matches cube)."""
|
|
|
|
def __init__(self, dim, eps=1e-5, dtype=None, device=None):
|
|
super().__init__()
|
|
self.eps = eps
|
|
self.weight = nn.Parameter(torch.ones(dim, dtype=dtype, device=device))
|
|
|
|
def forward(self, x):
|
|
xf = x.float()
|
|
out = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
|
|
return (out * self.weight).type_as(x)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# RoPE
|
|
#
|
|
# Reuses ComfyUI's shared Flux rotary embedding (comfy.ldm.flux.math):
|
|
# * rope() builds the real-valued rotation freqs (cos/-sin/sin/cos), shaped
|
|
# (B, L, head_dim/2, 2, 2);
|
|
# * apply_rope1() applies them and, at inference, dispatches to comfy-kitchen's
|
|
# optimized apply_rope kernel (comfy.quant_ops.ck).
|
|
# This replaces cube's original complex-number RoPE (torch.polar / view_as_complex),
|
|
# which was numerically equivalent but bypassed the kernel. The pairing convention is
|
|
# identical (adjacent dims), so the rotation math is the same; the only difference is
|
|
# rope() computes the angles in fp64 before casting to fp32, so outputs match upstream
|
|
# to fp32 rounding rather than bit-for-bit.
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def precompute_freqs_cis(dim, t, theta=10000.0):
|
|
# t: (B, L) integer position ids. Returns Flux-style real rotation freqs shaped
|
|
# (B, L, dim/2, 2, 2) for comfy.ldm.flux.math.apply_rope1.
|
|
return comfy.ldm.flux.math.rope(t, dim, theta)
|
|
|
|
|
|
def apply_rotary_emb(x, freqs_cis, curr_pos_id=None):
|
|
# x: (B, num_heads, L, head_dim). Select the rotation freqs for x's positions, add
|
|
# the head-broadcast axis, then apply via the shared Flux/comfy-kitchen op. q (the
|
|
# new token[s]) and k (the full sequence) are rotated separately because their
|
|
# lengths/positions differ during decode.
|
|
if curr_pos_id is None:
|
|
freqs_cis = freqs_cis[:, -x.shape[2]:]
|
|
else:
|
|
freqs_cis = freqs_cis[:, curr_pos_id]
|
|
freqs_cis = freqs_cis.unsqueeze(1)
|
|
return comfy.ldm.flux.math.apply_rope1(x, freqs_cis)
|
|
|
|
|
|
def sdpa_with_rope(q, k, v, freqs_cis, attn_mask=None, curr_pos_id=None, is_causal=False):
|
|
q = apply_rotary_emb(q, freqs_cis, curr_pos_id=curr_pos_id)
|
|
k = apply_rotary_emb(k, freqs_cis, curr_pos_id=None)
|
|
return F.scaled_dot_product_attention(
|
|
q, k, v, attn_mask=attn_mask, dropout_p=0.0,
|
|
is_causal=is_causal and attn_mask is None,
|
|
)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# KV cache
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class Cache:
|
|
def __init__(self, key_states, value_states):
|
|
self.key_states = key_states
|
|
self.value_states = value_states
|
|
|
|
def update(self, curr_pos_id, k, v):
|
|
self.key_states.index_copy_(2, curr_pos_id, k)
|
|
self.value_states.index_copy_(2, curr_pos_id, v)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Shared building blocks
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class SwiGLUMLP(nn.Module):
|
|
def __init__(self, embed_dim, hidden_dim, bias=True, dtype=None, device=None, operations=None):
|
|
super().__init__()
|
|
self.gate_proj = operations.Linear(embed_dim, hidden_dim, bias=bias, dtype=dtype, device=device)
|
|
self.up_proj = operations.Linear(embed_dim, hidden_dim, bias=bias, dtype=dtype, device=device)
|
|
self.down_proj = operations.Linear(hidden_dim, embed_dim, bias=bias, dtype=dtype, device=device)
|
|
|
|
def forward(self, x):
|
|
return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
|
|
|
|
|
|
class SelfAttentionWithRotaryEmbedding(nn.Module):
|
|
def __init__(self, embed_dim, num_heads, bias=True, eps=1e-6, dtype=None, device=None, operations=None):
|
|
super().__init__()
|
|
assert embed_dim % num_heads == 0
|
|
self.num_heads = num_heads
|
|
head_dim = embed_dim // num_heads
|
|
self.c_qk = operations.Linear(embed_dim, 2 * embed_dim, bias=False, dtype=dtype, device=device)
|
|
self.c_v = operations.Linear(embed_dim, embed_dim, bias=bias, dtype=dtype, device=device)
|
|
self.c_proj = operations.Linear(embed_dim, embed_dim, bias=bias, dtype=dtype, device=device)
|
|
self.q_norm = CubeRMSNorm(head_dim, dtype=dtype, device=device)
|
|
self.k_norm = CubeRMSNorm(head_dim, dtype=dtype, device=device)
|
|
|
|
def forward(self, x, freqs_cis, attn_mask=None, is_causal=False, kv_cache=None, curr_pos_id=None, decode=False):
|
|
b, l, d = x.shape
|
|
q, k = self.c_qk(x).chunk(2, dim=-1)
|
|
v = self.c_v(x)
|
|
q = q.view(b, l, self.num_heads, -1).transpose(1, 2)
|
|
k = k.view(b, l, self.num_heads, -1).transpose(1, 2)
|
|
v = v.view(b, l, self.num_heads, -1).transpose(1, 2)
|
|
q = self.q_norm(q)
|
|
k = self.k_norm(k)
|
|
if kv_cache is not None:
|
|
if not decode:
|
|
kv_cache.key_states[:, :, :k.shape[2], :].copy_(k)
|
|
kv_cache.value_states[:, :, :k.shape[2], :].copy_(v)
|
|
else:
|
|
kv_cache.update(curr_pos_id, k, v)
|
|
k = kv_cache.key_states
|
|
v = kv_cache.value_states
|
|
y = sdpa_with_rope(q, k, v, freqs_cis=freqs_cis, attn_mask=attn_mask,
|
|
curr_pos_id=curr_pos_id if decode else None, is_causal=is_causal)
|
|
y = y.transpose(1, 2).contiguous().view(b, l, d)
|
|
return self.c_proj(y)
|
|
|
|
|
|
class DecoderLayerWithRotaryEmbedding(nn.Module):
|
|
"""Single-stream decoder layer (shape tokens only)."""
|
|
|
|
def __init__(self, embed_dim, num_heads, bias=True, eps=1e-6, dtype=None, device=None, operations=None):
|
|
super().__init__()
|
|
self.ln_1 = CubeLayerNorm(embed_dim, eps=eps)
|
|
self.attn = SelfAttentionWithRotaryEmbedding(embed_dim, num_heads, bias=bias, eps=eps,
|
|
dtype=dtype, device=device, operations=operations)
|
|
self.ln_2 = CubeLayerNorm(embed_dim, eps=eps)
|
|
self.mlp = SwiGLUMLP(embed_dim, embed_dim * 4, bias=bias, dtype=dtype, device=device, operations=operations)
|
|
|
|
def forward(self, x, freqs_cis, attn_mask=None, is_causal=True, kv_cache=None, curr_pos_id=None, decode=False):
|
|
x = x + self.attn(self.ln_1(x), freqs_cis=freqs_cis, attn_mask=attn_mask, is_causal=is_causal,
|
|
kv_cache=kv_cache, curr_pos_id=curr_pos_id, decode=decode)
|
|
x = x + self.mlp(self.ln_2(x))
|
|
return x
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Dual-stream blocks (faithful to dual_stream_attention.py)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class DismantledPreAttention(nn.Module):
|
|
def __init__(self, embed_dim, num_heads, query=True, bias=True, dtype=None, device=None, operations=None):
|
|
super().__init__()
|
|
assert embed_dim % num_heads == 0
|
|
self.query = query
|
|
head_dim = embed_dim // num_heads
|
|
if query:
|
|
self.c_qk = operations.Linear(embed_dim, 2 * embed_dim, bias=False, dtype=dtype, device=device)
|
|
self.q_norm = CubeRMSNorm(head_dim, dtype=dtype, device=device)
|
|
else:
|
|
self.c_k = operations.Linear(embed_dim, embed_dim, bias=bias, dtype=dtype, device=device)
|
|
self.k_norm = CubeRMSNorm(head_dim, dtype=dtype, device=device)
|
|
self.c_v = operations.Linear(embed_dim, embed_dim, bias=bias, dtype=dtype, device=device)
|
|
self.num_heads = num_heads
|
|
|
|
def _to_mha(self, x):
|
|
return x.view(*x.shape[:2], self.num_heads, -1).transpose(1, 2)
|
|
|
|
def forward(self, x):
|
|
if self.query:
|
|
q, k = self.c_qk(x).chunk(2, dim=-1)
|
|
q = self.q_norm(self._to_mha(q))
|
|
else:
|
|
q = None
|
|
k = self.c_k(x)
|
|
k = self.k_norm(self._to_mha(k))
|
|
v = self._to_mha(self.c_v(x))
|
|
return (q, k, v)
|
|
|
|
|
|
class DismantledPostAttention(nn.Module):
|
|
def __init__(self, embed_dim, bias=True, eps=1e-6, dtype=None, device=None, operations=None):
|
|
super().__init__()
|
|
self.c_proj = operations.Linear(embed_dim, embed_dim, bias=bias, dtype=dtype, device=device)
|
|
self.ln_3 = CubeLayerNorm(embed_dim, eps=eps)
|
|
self.mlp = SwiGLUMLP(embed_dim, embed_dim * 4, bias=bias, dtype=dtype, device=device, operations=operations)
|
|
|
|
def forward(self, x, a):
|
|
x = x + self.c_proj(a)
|
|
x = x + self.mlp(self.ln_3(x))
|
|
return x
|
|
|
|
|
|
class DualStreamAttentionWithRotaryEmbedding(nn.Module):
|
|
def __init__(self, embed_dim, num_heads, cond_pre_only=False, bias=True, dtype=None, device=None, operations=None):
|
|
super().__init__()
|
|
self.cond_pre_only = cond_pre_only
|
|
self.pre_x = DismantledPreAttention(embed_dim, num_heads, query=True, bias=bias,
|
|
dtype=dtype, device=device, operations=operations)
|
|
self.pre_c = DismantledPreAttention(embed_dim, num_heads, query=not cond_pre_only, bias=bias,
|
|
dtype=dtype, device=device, operations=operations)
|
|
|
|
def forward(self, x, c, freqs_cis, attn_mask=None, is_causal=False, kv_cache=None, curr_pos_id=None, decode=False):
|
|
if kv_cache is None or not decode:
|
|
qkv_c = self.pre_c(c)
|
|
qkv_x = self.pre_x(x)
|
|
if self.cond_pre_only:
|
|
q = qkv_x[0]
|
|
else:
|
|
q = torch.cat([qkv_c[0], qkv_x[0]], dim=2)
|
|
k = torch.cat([qkv_c[1], qkv_x[1]], dim=2)
|
|
v = torch.cat([qkv_c[2], qkv_x[2]], dim=2)
|
|
else:
|
|
is_causal = False
|
|
q, k, v = self.pre_x(x)
|
|
|
|
if kv_cache is not None:
|
|
if not decode:
|
|
kv_cache.key_states[:, :, :k.shape[2], :].copy_(k)
|
|
kv_cache.value_states[:, :, :k.shape[2], :].copy_(v)
|
|
else:
|
|
kv_cache.update(curr_pos_id, k, v)
|
|
k = kv_cache.key_states
|
|
v = kv_cache.value_states
|
|
|
|
if attn_mask is not None:
|
|
if decode:
|
|
attn_mask = attn_mask[..., curr_pos_id, :]
|
|
else:
|
|
attn_mask = attn_mask[..., -q.shape[2]:, :]
|
|
|
|
y = sdpa_with_rope(q, k, v, freqs_cis=freqs_cis, attn_mask=attn_mask,
|
|
curr_pos_id=curr_pos_id if decode else None, is_causal=is_causal)
|
|
y = y.transpose(1, 2).contiguous().view(x.shape[0], -1, x.shape[2])
|
|
|
|
if y.shape[1] == x.shape[1]:
|
|
return y, None
|
|
y_c, y_x = torch.split(y, [c.shape[1], x.shape[1]], dim=1)
|
|
return y_x, y_c
|
|
|
|
|
|
class DualStreamDecoderLayerWithRotaryEmbedding(nn.Module):
|
|
def __init__(self, embed_dim, num_heads, cond_pre_only=False, bias=True, eps=1e-6,
|
|
dtype=None, device=None, operations=None):
|
|
super().__init__()
|
|
self.ln_1 = CubeLayerNorm(embed_dim, eps=eps)
|
|
self.ln_2 = CubeLayerNorm(embed_dim, eps=eps)
|
|
self.attn = DualStreamAttentionWithRotaryEmbedding(embed_dim, num_heads, cond_pre_only=cond_pre_only,
|
|
bias=bias, dtype=dtype, device=device, operations=operations)
|
|
self.post_1 = DismantledPostAttention(embed_dim, bias=bias, eps=eps, dtype=dtype, device=device, operations=operations)
|
|
if not cond_pre_only:
|
|
self.post_2 = DismantledPostAttention(embed_dim, bias=bias, eps=eps, dtype=dtype, device=device, operations=operations)
|
|
|
|
def forward(self, x, c, freqs_cis, attn_mask=None, is_causal=True, kv_cache=None, curr_pos_id=None, decode=False):
|
|
a_x, a_c = self.attn(
|
|
self.ln_1(x),
|
|
self.ln_2(c) if c is not None else None,
|
|
freqs_cis=freqs_cis, attn_mask=attn_mask, is_causal=is_causal,
|
|
kv_cache=kv_cache, curr_pos_id=curr_pos_id, decode=decode,
|
|
)
|
|
x = self.post_1(x, a_x)
|
|
if a_c is not None:
|
|
c = self.post_2(c, a_c)
|
|
else:
|
|
c = None
|
|
return x, c
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# DualStreamRoformer
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class DualStreamRoformer(nn.Module):
|
|
def __init__(
|
|
self,
|
|
n_layer=23,
|
|
n_single_layer=1,
|
|
rope_theta=10000,
|
|
n_head=12,
|
|
n_embd=1536,
|
|
bias=True,
|
|
eps=1e-6,
|
|
shape_model_vocab_size=16384,
|
|
shape_model_embed_dim=32,
|
|
text_model_embed_dim=768,
|
|
use_bbox=True,
|
|
image_model=None, # detection key; unused
|
|
dtype=None,
|
|
device=None,
|
|
operations=None,
|
|
):
|
|
super().__init__()
|
|
self.dtype = dtype
|
|
self.n_layer = n_layer
|
|
self.n_single_layer = n_single_layer
|
|
self.n_head = n_head
|
|
self.n_embd = n_embd
|
|
self.rope_theta = rope_theta
|
|
self.head_dim = n_embd // n_head
|
|
|
|
self.text_proj = operations.Linear(text_model_embed_dim, n_embd, bias=bias, dtype=dtype, device=device)
|
|
self.shape_proj = operations.Linear(shape_model_embed_dim, n_embd, bias=True, dtype=dtype, device=device)
|
|
|
|
self.vocab_size = shape_model_vocab_size
|
|
self.shape_bos_id = self.vocab_size
|
|
self.shape_eos_id = self.vocab_size + 1
|
|
self.padding_id = self.vocab_size + 2
|
|
self.vocab_size += 3
|
|
|
|
self.transformer = nn.ModuleDict(dict(
|
|
wte=operations.Embedding(self.vocab_size, n_embd, padding_idx=self.padding_id, dtype=dtype, device=device),
|
|
dual_blocks=nn.ModuleList([
|
|
DualStreamDecoderLayerWithRotaryEmbedding(
|
|
n_embd, n_head, cond_pre_only=(i == n_layer - 1), bias=bias, eps=eps,
|
|
dtype=dtype, device=device, operations=operations,
|
|
)
|
|
for i in range(n_layer)
|
|
]),
|
|
single_blocks=nn.ModuleList([
|
|
DecoderLayerWithRotaryEmbedding(n_embd, n_head, bias=bias, eps=eps,
|
|
dtype=dtype, device=device, operations=operations)
|
|
for _ in range(n_single_layer)
|
|
]),
|
|
ln_f=CubeLayerNorm(n_embd, eps=eps),
|
|
))
|
|
|
|
self.lm_head = operations.Linear(n_embd, self.vocab_size, bias=False, dtype=dtype, device=device)
|
|
|
|
self.use_bbox = use_bbox
|
|
if use_bbox:
|
|
self.bbox_proj = operations.Linear(3, n_embd, bias=True, dtype=dtype, device=device)
|
|
|
|
def encode_text(self, text_embed):
|
|
return self.text_proj(text_embed)
|
|
|
|
def encode_token(self, tokens):
|
|
return self.transformer.wte(tokens)
|
|
|
|
def init_kv_cache(self, batch_size, cond_len, max_shape_tokens, dtype, device):
|
|
max_all = cond_len + max_shape_tokens
|
|
kv = [
|
|
Cache(
|
|
torch.zeros((batch_size, self.n_head, max_all, self.head_dim), dtype=dtype, device=device),
|
|
torch.zeros((batch_size, self.n_head, max_all, self.head_dim), dtype=dtype, device=device),
|
|
)
|
|
for _ in range(len(self.transformer.dual_blocks))
|
|
]
|
|
kv += [
|
|
Cache(
|
|
torch.zeros((batch_size, self.n_head, max_shape_tokens, self.head_dim), dtype=dtype, device=device),
|
|
torch.zeros((batch_size, self.n_head, max_shape_tokens, self.head_dim), dtype=dtype, device=device),
|
|
)
|
|
for _ in range(len(self.transformer.single_blocks))
|
|
]
|
|
return kv
|
|
|
|
def forward(self, embed, cond, kv_cache=None, curr_pos_id=None, decode=False):
|
|
b, l = embed.shape[:2]
|
|
s = cond.shape[1]
|
|
device = embed.device
|
|
|
|
attn_mask = torch.tril(torch.ones(s + l, s + l, dtype=torch.bool, device=device))
|
|
|
|
position_ids = torch.arange(l, dtype=torch.long, device=device).unsqueeze(0).expand(b, -1)
|
|
s_freqs_cis = precompute_freqs_cis(self.head_dim, position_ids, theta=self.rope_theta)
|
|
|
|
position_ids = torch.cat([
|
|
torch.zeros([b, s], dtype=torch.long, device=device),
|
|
position_ids,
|
|
], dim=1)
|
|
d_freqs_cis = precompute_freqs_cis(self.head_dim, position_ids, theta=self.rope_theta)
|
|
|
|
if kv_cache is not None and decode:
|
|
embed = embed[:, curr_pos_id, :]
|
|
|
|
h = embed
|
|
c = cond
|
|
layer_idx = 0
|
|
for block in self.transformer.dual_blocks:
|
|
h, c = block(
|
|
h, c=c, freqs_cis=d_freqs_cis, attn_mask=attn_mask, is_causal=True,
|
|
kv_cache=kv_cache[layer_idx] if kv_cache is not None else None,
|
|
curr_pos_id=curr_pos_id + s if curr_pos_id is not None else None,
|
|
decode=decode,
|
|
)
|
|
layer_idx += 1
|
|
for block in self.transformer.single_blocks:
|
|
h = block(
|
|
h, freqs_cis=s_freqs_cis, attn_mask=None, is_causal=True,
|
|
kv_cache=kv_cache[layer_idx] if kv_cache is not None else None,
|
|
curr_pos_id=curr_pos_id, decode=decode,
|
|
)
|
|
layer_idx += 1
|
|
|
|
h = self.transformer.ln_f(h)
|
|
return self.lm_head(h)
|