diff --git a/comfy/ldm/cube/gpt.py b/comfy/ldm/cube/gpt.py index 421648a8d..6dd8e2680 100644 --- a/comfy/ldm/cube/gpt.py +++ b/comfy/ldm/cube/gpt.py @@ -8,12 +8,15 @@ This is an autoregressive transformer over discrete VQ shape tokens, conditioned CLIP text embeddings. It is NOT a diffusion model; it is driven by the dedicated `sample_cube` sampler (see comfy/k_diffusion/sampling.py), not KSampler. -The forward pass is kept faithful to upstream so token IDs match bit-for-bit: +The forward pass is kept faithful to upstream so token IDs match: * rope_theta = 10000 * per-head RMSNorm on Q and K * dual-stream (MM-DiT style) joint attention; last dual block is cond_pre_only * two separate RoPE frequency tensors (dual blocks offset cond tokens by S) * SwiGLU MLP, non-affine LayerNorm upcast to fp32 + +RoPE reuses ComfyUI's shared Flux rotary embedding (comfy.ldm.flux.math) so it +benefits from comfy-kitchen's optimized apply_rope kernel; see the RoPE section below. """ from typing import Optional @@ -22,6 +25,8 @@ import torch import torch.nn as nn import torch.nn.functional as F +import comfy.ldm.flux.math + # --------------------------------------------------------------------------- # Norms (faithful to cube3d/model/transformers/norm.py) @@ -55,23 +60,37 @@ class CubeRMSNorm(nn.Module): # --------------------------------------------------------------------------- -# RoPE (faithful to cube3d/model/transformers/rope.py) +# RoPE +# +# Reuses ComfyUI's shared Flux rotary embedding (comfy.ldm.flux.math): +# * rope() builds the real-valued rotation freqs (cos/-sin/sin/cos), shaped +# (B, L, head_dim/2, 2, 2); +# * apply_rope1() applies them and, at inference, dispatches to comfy-kitchen's +# optimized apply_rope kernel (comfy.quant_ops.ck). +# This replaces cube's original complex-number RoPE (torch.polar / view_as_complex), +# which was numerically equivalent but bypassed the kernel. 
# The pairing convention is identical (adjacent dims), so the rotation math is the
# same; the only difference is that rope() computes the angles in fp64 before casting
# to fp32, so outputs match upstream to fp32 rounding rather than bit-for-bit.
# ---------------------------------------------------------------------------


def precompute_freqs_cis(dim, t, theta=10000.0):
    """Build the rotary-embedding rotation table for the position ids ``t``.

    Thin wrapper over ``comfy.ldm.flux.math.rope``. Note that ``rope`` takes the
    positions first and the dimension second, i.e. the reverse of this
    function's own ``(dim, t)`` parameter order.

    Args:
        dim: Rotary dimension, forwarded to ``rope`` as its second argument.
        t: Position ids; expected to be a ``(B, L)`` integer tensor.
        theta: RoPE base, forwarded to ``rope`` unchanged.

    Returns:
        The value returned by ``rope``: Flux-style real rotation freqs, expected
        to be shaped ``(B, L, dim/2, 2, 2)`` for use with
        ``comfy.ldm.flux.math.apply_rope1``.
    """
    rotation_table = comfy.ldm.flux.math.rope(t, dim, theta)
    return rotation_table


def apply_rotary_emb(x, freqs_cis, curr_pos_id=None):
    """Rotate ``x`` with the slice of ``freqs_cis`` that matches its positions.

    q (the new token[s]) and k (the full sequence) are rotated through separate
    calls because their lengths/positions differ during decode.

    Args:
        x: Tensor to rotate; expected layout ``(B, num_heads, L, head_dim)``.
        freqs_cis: Rotation table as produced by ``precompute_freqs_cis``,
            indexed along dim 1 by position.
        curr_pos_id: Optional index into dim 1 of ``freqs_cis``. When ``None``,
            the trailing ``x.shape[2]`` positions are used instead.

    Returns:
        The result of ``comfy.ldm.flux.math.apply_rope1`` applied to ``x`` and
        the selected freqs.
    """
    if curr_pos_id is not None:
        # Explicit position(s) supplied by the caller.
        selected = freqs_cis[:, curr_pos_id]
    else:
        # No explicit position: take the last `seq_len` entries of the table.
        seq_len = x.shape[2]
        selected = freqs_cis[:, -seq_len:]
    # unsqueeze(1) adds the head-broadcast axis so one table serves every head.
    return comfy.ldm.flux.math.apply_rope1(x, selected.unsqueeze(1))