Cube3D: reuse shared Flux RoPE (comfy-kitchen optimized kernel)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run

Replace cube's bespoke complex-number RoPE (torch.polar / view_as_complex) with
ComfyUI's shared Flux rotary embedding (comfy.ldm.flux.math):
  * precompute_freqs_cis now returns Flux's real rotation freqs via rope().
  * apply_rotary_emb applies them via apply_rope1, which at inference dispatches to
    comfy-kitchen's optimized apply_rope kernel (comfy.quant_ops.ck). q and k are
    still rotated separately to preserve the decode-time position asymmetry.

The pairing convention (adjacent dims) and rotation math are identical, so token
outputs are unchanged. The only numerical difference is that rope() computes the
rotation angles in fp64 before casting to fp32 (cube's original used fp32), so output
now matches upstream to fp32 rounding (~1e-6 on rotated q/k in a standalone check)
rather than bit-for-bit. Greedy argmax token selection is unaffected.

Deviation note: this is a deliberate, documented divergence from a strict upstream
port, taken to gain the shared optimized kernel. Needs GPU parity re-validation on the
2x4090 box (kosin-X570-AORUS-ULTRA) before merge.

Co-authored-by: Amp <amp@ampcode.com>
Amp-Thread-ID: https://ampcode.com/threads/T-019f013b-5892-71b9-af6b-c2ef28c67d2b
This commit is contained in:
Jedrzej Kosinski 2026-06-25 18:15:15 -07:00
parent e7f99168ae
commit 94bcb5701e

View File

@ -8,12 +8,15 @@ This is an autoregressive transformer over discrete VQ shape tokens, conditioned
CLIP text embeddings. It is NOT a diffusion model; it is driven by the dedicated
`sample_cube` sampler (see comfy/k_diffusion/sampling.py), not KSampler.
The forward pass is kept faithful to upstream so token IDs match bit-for-bit:
The forward pass is kept faithful to upstream so token IDs match:
* rope_theta = 10000
* per-head RMSNorm on Q and K
* dual-stream (MM-DiT style) joint attention; last dual block is cond_pre_only
* two separate RoPE frequency tensors (dual blocks offset cond tokens by S)
* SwiGLU MLP, non-affine LayerNorm upcast to fp32
RoPE reuses ComfyUI's shared Flux rotary embedding (comfy.ldm.flux.math) so it
benefits from comfy-kitchen's optimized apply_rope kernel; see the RoPE section below.
"""
from typing import Optional
@ -22,6 +25,8 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.ldm.flux.math
# ---------------------------------------------------------------------------
# Norms (faithful to cube3d/model/transformers/norm.py)
@ -55,23 +60,37 @@ class CubeRMSNorm(nn.Module):
# ---------------------------------------------------------------------------
# RoPE (faithful to cube3d/model/transformers/rope.py)
# RoPE
#
# Reuses ComfyUI's shared Flux rotary embedding (comfy.ldm.flux.math):
# * rope() builds the real-valued rotation freqs (cos/-sin/sin/cos), shaped
# (B, L, head_dim/2, 2, 2);
# * apply_rope1() applies them and, at inference, dispatches to comfy-kitchen's
# optimized apply_rope kernel (comfy.quant_ops.ck).
# This replaces cube's original complex-number RoPE (torch.polar / view_as_complex),
# which was numerically equivalent but bypassed the kernel. The pairing convention is
# identical (adjacent dims), so the rotation math is the same; the only difference is
# rope() computes the angles in fp64 before casting to fp32, so outputs match upstream
# to fp32 rounding rather than bit-for-bit.
# ---------------------------------------------------------------------------
def apply_rotary_emb(x, freqs_cis, curr_pos_id=None):
x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
if curr_pos_id is None:
freqs_cis = freqs_cis[:, -x.shape[2]:].unsqueeze(1)
else:
freqs_cis = freqs_cis[:, curr_pos_id, :].unsqueeze(1)
y = torch.view_as_real(x_ * freqs_cis).flatten(3)
return y.type_as(x)
def precompute_freqs_cis(dim, t, theta=10000.0):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=t.device) / dim))
freqs = torch.outer(t.contiguous().view(-1), freqs).reshape(*t.shape, -1)
return torch.polar(torch.ones_like(freqs), freqs)
# t: (B, L) integer position ids. Returns Flux-style real rotation freqs shaped
# (B, L, dim/2, 2, 2) for comfy.ldm.flux.math.apply_rope1.
return comfy.ldm.flux.math.rope(t, dim, theta)
def apply_rotary_emb(x, freqs_cis, curr_pos_id=None):
# x: (B, num_heads, L, head_dim). Select the rotation freqs for x's positions, add
# the head-broadcast axis, then apply via the shared Flux/comfy-kitchen op. q (the
# new token[s]) and k (the full sequence) are rotated separately because their
# lengths/positions differ during decode.
if curr_pos_id is None:
freqs_cis = freqs_cis[:, -x.shape[2]:]
else:
freqs_cis = freqs_cis[:, curr_pos_id]
freqs_cis = freqs_cis.unsqueeze(1)
return comfy.ldm.flux.math.apply_rope1(x, freqs_cis)
def sdpa_with_rope(q, k, v, freqs_cis, attn_mask=None, curr_pos_id=None, is_causal=False):