mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-26 17:59:54 +08:00
Cube3D: reuse shared Flux RoPE (comfy-kitchen optimized kernel)
Replace cube's bespoke complex-number RoPE (torch.polar / view_as_complex) with
ComfyUI's shared Flux rotary embedding (comfy.ldm.flux.math):
* precompute_freqs_cis now returns Flux's real rotation freqs via rope().
* apply_rotary_emb applies them via apply_rope1, which at inference dispatches to
comfy-kitchen's optimized apply_rope kernel (comfy.quant_ops.ck). q and k are
still rotated separately to preserve the decode-time position asymmetry.
The pairing convention (adjacent dims) and rotation math are identical, so token
outputs are unchanged. The only numerical difference is that rope() computes the
rotation angles in fp64 before casting to fp32 (cube's original used fp32), so output
now matches upstream to fp32 rounding (~1e-6 on rotated q/k in a standalone check)
rather than bit-for-bit. Greedy argmax token selection is unaffected.
Deviation note: this is a deliberate, documented divergence from a strict upstream
port, taken to gain the shared optimized kernel. Needs GPU parity re-validation on the
2x4090 box (kosin-X570-AORUS-ULTRA) before merge.
Co-authored-by: Amp <amp@ampcode.com>
Amp-Thread-ID: https://ampcode.com/threads/T-019f013b-5892-71b9-af6b-c2ef28c67d2b
This commit is contained in:
parent
e7f99168ae
commit
94bcb5701e
@ -8,12 +8,15 @@ This is an autoregressive transformer over discrete VQ shape tokens, conditioned
|
||||
CLIP text embeddings. It is NOT a diffusion model; it is driven by the dedicated
|
||||
`sample_cube` sampler (see comfy/k_diffusion/sampling.py), not KSampler.
|
||||
|
||||
The forward pass is kept faithful to upstream so token IDs match bit-for-bit:
|
||||
The forward pass is kept faithful to upstream so token IDs match:
|
||||
* rope_theta = 10000
|
||||
* per-head RMSNorm on Q and K
|
||||
* dual-stream (MM-DiT style) joint attention; last dual block is cond_pre_only
|
||||
* two separate RoPE frequency tensors (dual blocks offset cond tokens by S)
|
||||
* SwiGLU MLP, non-affine LayerNorm upcast to fp32
|
||||
|
||||
RoPE reuses ComfyUI's shared Flux rotary embedding (comfy.ldm.flux.math) so it
|
||||
benefits from comfy-kitchen's optimized apply_rope kernel; see the RoPE section below.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
@ -22,6 +25,8 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.ldm.flux.math
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Norms (faithful to cube3d/model/transformers/norm.py)
|
||||
@ -55,23 +60,37 @@ class CubeRMSNorm(nn.Module):
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RoPE (faithful to cube3d/model/transformers/rope.py)
|
||||
# RoPE
|
||||
#
|
||||
# Reuses ComfyUI's shared Flux rotary embedding (comfy.ldm.flux.math):
|
||||
# * rope() builds the real-valued rotation freqs (cos/-sin/sin/cos), shaped
|
||||
# (B, L, head_dim/2, 2, 2);
|
||||
# * apply_rope1() applies them and, at inference, dispatches to comfy-kitchen's
|
||||
# optimized apply_rope kernel (comfy.quant_ops.ck).
|
||||
# This replaces cube's original complex-number RoPE (torch.polar / view_as_complex),
|
||||
# which was numerically equivalent but bypassed the kernel. The pairing convention is
|
||||
# identical (adjacent dims), so the rotation math is the same; the only difference is
|
||||
# rope() computes the angles in fp64 before casting to fp32, so outputs match upstream
|
||||
# to fp32 rounding rather than bit-for-bit.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def apply_rotary_emb(x, freqs_cis, curr_pos_id=None):
    """Rotate ``x`` by complex rotary-embedding phasors.

    Adjacent pairs along the last axis of ``x`` are viewed as complex numbers
    (computed in fp32), multiplied by the selected entries of ``freqs_cis``,
    and flattened back to the original layout.

    Args:
        x: 4-D tensor whose axis 2 is the sequence axis and whose last axis
            has even length (presumably (B, num_heads, L, head_dim) -- confirm
            against callers).
        freqs_cis: complex phasors indexed as (batch, position, pair), e.g.
            the output of ``torch.polar``.
        curr_pos_id: optional index into the position axis of ``freqs_cis``.
            When ``None``, the last ``x.shape[2]`` positions are used.

    Returns:
        The rotated tensor, cast back to the dtype/device of ``x``.
    """
    # Pair up adjacent channels and reinterpret each pair as one complex value.
    paired = x.float().reshape(*x.shape[:-1], -1, 2)
    x_complex = torch.view_as_complex(paired)

    if curr_pos_id is not None:
        phasors = freqs_cis[:, curr_pos_id, :]
    else:
        seq_len = x.shape[2]
        phasors = freqs_cis[:, -seq_len:]

    # Insert the axis that broadcasts the same rotation over axis 1 of x.
    phasors = phasors.unsqueeze(1)

    rotated = torch.view_as_real(x_complex * phasors).flatten(3)
    return rotated.type_as(x)
|
||||
|
||||
|
||||
def precompute_freqs_cis(dim, t, theta=10000.0):
    """Build the rotary-embedding rotation freqs for the given position ids.

    Delegates to ComfyUI's shared Flux implementation so the result can be
    consumed by ``comfy.ldm.flux.math.apply_rope1``. The previous body built
    complex phasors with ``torch.polar`` and returned them first, which left
    the Flux path unreachable and produced a format ``apply_rope1`` is not
    documented to accept.

    Args:
        dim: rotary dimension (the per-head channel count being rotated).
        t: position ids, documented in this file as an integer tensor of
            shape (B, L).
        theta: RoPE base; defaults to 10000.0.

    Returns:
        Real-valued rotation freqs, documented in this file as shaped
        (B, L, dim/2, 2, 2).
    """
    return comfy.ldm.flux.math.rope(t, dim, theta)
|
||||
|
||||
|
||||
def apply_rotary_emb(x, freqs_cis, curr_pos_id=None):
    """Apply rotary position embedding to ``x`` via the shared Flux op.

    Args:
        x: tensor laid out as (B, num_heads, L, head_dim).
        freqs_cis: precomputed rotation freqs whose axis 1 indexes position.
        curr_pos_id: optional index into the position axis of ``freqs_cis``.
            When ``None``, the trailing ``x.shape[2]`` positions are used.

    Returns:
        The result of ``comfy.ldm.flux.math.apply_rope1`` on ``x``.

    q (the new token[s]) and k (the full sequence) go through this function
    separately because their lengths/positions differ during decode.
    """
    if curr_pos_id is not None:
        rotation = freqs_cis[:, curr_pos_id]
    else:
        rotation = freqs_cis[:, -x.shape[2]:]

    # Add the axis that broadcasts one set of freqs across all heads.
    return comfy.ldm.flux.math.apply_rope1(x, rotation.unsqueeze(1))
|
||||
|
||||
|
||||
def sdpa_with_rope(q, k, v, freqs_cis, attn_mask=None, curr_pos_id=None, is_causal=False):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user