From 94bcb5701ee760404ceaf11ec308a2a03d95b242 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 25 Jun 2026 18:15:15 -0700 Subject: [PATCH] Cube3D: reuse shared Flux RoPE (comfy-kitchen optimized kernel) Replace cube's bespoke complex-number RoPE (torch.polar / view_as_complex) with ComfyUI's shared Flux rotary embedding (comfy.ldm.flux.math): * precompute_freqs_cis now returns Flux's real rotation freqs via rope(). * apply_rotary_emb applies them via apply_rope1, which at inference dispatches to comfy-kitchen's optimized apply_rope kernel (comfy.quant_ops.ck). q and k are still rotated separately to preserve the decode-time position asymmetry. The pairing convention (adjacent dims) and rotation math are identical, so token outputs are unchanged. The only numerical difference is that rope() computes the rotation angles in fp64 before casting to fp32 (cube's original used fp32), so output now matches upstream to fp32 rounding (~1e-6 on rotated q/k in a standalone check) rather than bit-for-bit. Greedy argmax token selection is unaffected. Deviation note: this is a deliberate, documented divergence from a strict upstream port, taken to gain the shared optimized kernel. Needs GPU parity re-validation on the 2x4090 box (kosin-X570-AORUS-ULTRA) before merge. Co-authored-by: Amp Amp-Thread-ID: https://ampcode.com/threads/T-019f013b-5892-71b9-af6b-c2ef28c67d2b --- comfy/ldm/cube/gpt.py | 49 ++++++++++++++++++++++++++++++------------- 1 file changed, 34 insertions(+), 15 deletions(-) diff --git a/comfy/ldm/cube/gpt.py b/comfy/ldm/cube/gpt.py index 421648a8d..6dd8e2680 100644 --- a/comfy/ldm/cube/gpt.py +++ b/comfy/ldm/cube/gpt.py @@ -8,12 +8,15 @@ This is an autoregressive transformer over discrete VQ shape tokens, conditioned CLIP text embeddings. It is NOT a diffusion model; it is driven by the dedicated `sample_cube` sampler (see comfy/k_diffusion/sampling.py), not KSampler. -The forward pass is kept faithful to upstream so token IDs match bit-for-bit: +The forward pass is kept faithful to upstream so token IDs match: * rope_theta = 10000 * per-head RMSNorm on Q and K * dual-stream (MM-DiT style) joint attention; last dual block is cond_pre_only * two separate RoPE frequency tensors (dual blocks offset cond tokens by S) * SwiGLU MLP, non-affine LayerNorm upcast to fp32 + +RoPE reuses ComfyUI's shared Flux rotary embedding (comfy.ldm.flux.math) so it +benefits from comfy-kitchen's optimized apply_rope kernel; see the RoPE section below. """ from typing import Optional @@ -22,6 +25,8 @@ import torch import torch.nn as nn import torch.nn.functional as F +import comfy.ldm.flux.math + # --------------------------------------------------------------------------- # Norms (faithful to cube3d/model/transformers/norm.py) @@ -55,23 +60,37 @@ class CubeRMSNorm(nn.Module): # --------------------------------------------------------------------------- -# RoPE (faithful to cube3d/model/transformers/rope.py) +# RoPE +# +# Reuses ComfyUI's shared Flux rotary embedding (comfy.ldm.flux.math): +# * rope() builds the real-valued rotation freqs (cos/-sin/sin/cos), shaped +# (B, L, head_dim/2, 2, 2); +# * apply_rope1() applies them and, at inference, dispatches to comfy-kitchen's +# optimized apply_rope kernel (comfy.quant_ops.ck). +# This replaces cube's original complex-number RoPE (torch.polar / view_as_complex), +# which was numerically equivalent but bypassed the kernel. The pairing convention is +# identical (adjacent dims), so the rotation math is the same; the only difference is +# rope() computes the angles in fp64 before casting to fp32, so outputs match upstream +# to fp32 rounding rather than bit-for-bit. # --------------------------------------------------------------------------- -def apply_rotary_emb(x, freqs_cis, curr_pos_id=None): - x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) - if curr_pos_id is None: - freqs_cis = freqs_cis[:, -x.shape[2]:].unsqueeze(1) - else: - freqs_cis = freqs_cis[:, curr_pos_id, :].unsqueeze(1) - y = torch.view_as_real(x_ * freqs_cis).flatten(3) - return y.type_as(x) - - def precompute_freqs_cis(dim, t, theta=10000.0): - freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=t.device) / dim)) - freqs = torch.outer(t.contiguous().view(-1), freqs).reshape(*t.shape, -1) - return torch.polar(torch.ones_like(freqs), freqs) + # t: (B, L) integer position ids. Returns Flux-style real rotation freqs shaped + # (B, L, dim/2, 2, 2) for comfy.ldm.flux.math.apply_rope1. + return comfy.ldm.flux.math.rope(t, dim, theta) + + +def apply_rotary_emb(x, freqs_cis, curr_pos_id=None): + # x: (B, num_heads, L, head_dim). Select the rotation freqs for x's positions, add + # the head-broadcast axis, then apply via the shared Flux/comfy-kitchen op. q (the + # new token[s]) and k (the full sequence) are rotated separately because their + # lengths/positions differ during decode. + if curr_pos_id is None: + freqs_cis = freqs_cis[:, -x.shape[2]:] + else: + freqs_cis = freqs_cis[:, curr_pos_id] + freqs_cis = freqs_cis.unsqueeze(1) + return comfy.ldm.flux.math.apply_rope1(x, freqs_cis) def sdpa_with_rope(q, k, v, freqs_cis, attn_mask=None, curr_pos_id=None, is_causal=False):