Use Flux RoPE

This commit is contained in:
kijai 2026-05-23 22:50:19 +03:00
parent 2aba5bafca
commit b429028ad2

View File

@ -2,7 +2,7 @@
from __future__ import annotations
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple
import torch
import torch.nn as nn
@ -10,6 +10,8 @@ import torch.nn.functional as F
import comfy.ldm.flux.layers
import comfy.patcher_extension
from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.flux.math import apply_rope
from comfy.ldm.modules.attention import optimized_attention
@ -17,96 +19,39 @@ def _lens_time_proj(t: torch.Tensor, dim: int = 256) -> torch.Tensor:
return comfy.ldm.flux.layers.timestep_embedding(t, dim)
def apply_rotary_emb_lens(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
freqs_cis = freqs_cis.unsqueeze(1)
x_out = torch.view_as_real(x_complex * freqs_cis).flatten(3)
return x_out.type_as(x)
def _lens_position_ids(
frame: int, height: int, width: int, text_seq_len: int,
scale_rope: bool = True, device=None,
) -> torch.Tensor:
"""Lens axial (frame, h, w) position ids for joint image + text sequence.
With ``scale_rope=True`` h/w are centered around 0 (negative + positive
halves) and text starts at ``max(h//2, w//2)``. Result shape ``[seq, 3]``;
caller adds a batch dim for ``EmbedND``.
"""
if scale_rope:
h_pos = torch.cat([torch.arange(-(height - height // 2), 0, device=device),
torch.arange(0, height // 2, device=device)])
w_pos = torch.cat([torch.arange(-(width - width // 2), 0, device=device),
torch.arange(0, width // 2, device=device)])
text_start = max(height // 2, width // 2)
else:
h_pos = torch.arange(height, device=device)
w_pos = torch.arange(width, device=device)
text_start = max(height, width)
class LensEmbedRope(nn.Module):
"""Frame/H/W axial RoPE shared between image and text streams."""
f_pos = torch.arange(frame, device=device)
img_ids = torch.zeros(frame, height, width, 3, device=device)
img_ids[..., 0] = f_pos[:, None, None]
img_ids[..., 1] = h_pos[None, :, None]
img_ids[..., 2] = w_pos[None, None, :]
img_ids = img_ids.reshape(-1, 3)
def __init__(self, theta: int = 10000, axes_dim=(8, 28, 28), scale_rope: bool = True) -> None:
super().__init__()
self.theta = theta
self.axes_dim = list(axes_dim)
self.scale_rope = scale_rope
pos_index = torch.arange(4096)
neg_index = torch.arange(4096).flip(0) * -1 - 1
# Plain attributes (not buffers): register_buffer strips complex imag.
self.pos_freqs = torch.cat(
[self._rope_params(pos_index, d, theta) for d in self.axes_dim], dim=1
)
self.neg_freqs = torch.cat(
[self._rope_params(neg_index, d, theta) for d in self.axes_dim], dim=1
)
self.rope_cache: Dict[str, torch.Tensor] = {}
# Text positions replicate across all 3 axes (matches original packing).
txt_pos = torch.arange(text_start, text_start + text_seq_len, device=device).float()
txt_ids = txt_pos[:, None].expand(text_seq_len, 3)
@staticmethod
def _rope_params(index: torch.Tensor, dim: int, theta: int = 10000) -> torch.Tensor:
assert dim % 2 == 0
freqs = torch.outer(
index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).float().div(dim))
)
return torch.polar(torch.ones_like(freqs), freqs)
def forward(
self,
video_fhw: Union[List[Tuple[int, int, int]], Tuple[int, int, int]],
txt_seq_lens: Union[List[int], int],
device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
if self.pos_freqs.device != device:
self.pos_freqs = self.pos_freqs.to(device)
self.neg_freqs = self.neg_freqs.to(device)
if isinstance(video_fhw, list):
video_fhw = video_fhw[0]
if not isinstance(video_fhw, list):
video_fhw = [video_fhw]
if not isinstance(txt_seq_lens, list):
txt_seq_lens = [txt_seq_lens]
assert len(video_fhw) == 1, "video_fhw must have length 1"
vid_freqs = []
max_vid_index = 0
for idx, fhw in enumerate(video_fhw):
frame, height, width = fhw
rope_key = f"{idx}_{height}_{width}_{device}"
video_freq = self.rope_cache.get(rope_key)
if video_freq is None:
video_freq = self._compute_video_freqs(frame, height, width, idx=0).to(device)
self.rope_cache[rope_key] = video_freq
if self.scale_rope:
max_vid_index = max(height // 2, width // 2, max_vid_index)
else:
max_vid_index = max(height, width, max_vid_index)
vid_freqs.append(video_freq)
max_len = max(txt_seq_lens)
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
return torch.cat(vid_freqs, dim=0), txt_freqs
def _compute_video_freqs(self, frame: int, height: int, width: int, idx: int = 0) -> torch.Tensor:
seq_lens = frame * height * width
freqs_pos = self.pos_freqs.split([d // 2 for d in self.axes_dim], dim=1)
freqs_neg = self.neg_freqs.split([d // 2 for d in self.axes_dim], dim=1)
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
if self.scale_rope:
freqs_height = torch.cat(
[freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0
).view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = torch.cat(
[freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0
).view(1, 1, width, -1).expand(frame, height, width, -1)
else:
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
return freqs.clone().contiguous()
return torch.cat([img_ids, txt_ids], dim=0)
class _TimestepEmbedder(nn.Module):
@ -184,7 +129,7 @@ class LensJointAttention(nn.Module):
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
image_rotary_emb: Tuple[torch.Tensor, torch.Tensor],
freqs_cis: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
transformer_options: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
@ -207,19 +152,6 @@ class LensJointAttention(nn.Module):
txt_v = txt_v.contiguous()
del txt_qkv
img_freqs, txt_freqs = image_rotary_emb
if img_freqs.shape[0] < seq_img:
raise ValueError(f"Image RoPE length {img_freqs.shape[0]} < {seq_img}")
img_freqs = img_freqs[:seq_img]
img_q = apply_rotary_emb_lens(img_q, img_freqs)
img_k = apply_rotary_emb_lens(img_k, img_freqs)
if seq_txt > 0:
if txt_freqs.shape[0] < seq_txt:
raise ValueError(f"Text RoPE length {txt_freqs.shape[0]} < {seq_txt}")
txt_freqs = txt_freqs[:seq_txt]
txt_q = apply_rotary_emb_lens(txt_q, txt_freqs)
txt_k = apply_rotary_emb_lens(txt_k, txt_freqs)
# [B, S, H, D] → [B, H, S, D] for attention, dels to avoid VRAM peaks
q = torch.cat([img_q, txt_q], dim=1).transpose(1, 2)
del img_q, txt_q
@ -228,6 +160,8 @@ class LensJointAttention(nn.Module):
v = torch.cat([img_v, txt_v], dim=1).transpose(1, 2)
del img_v, txt_v
q, k = apply_rope(q, k, freqs_cis)
if attention_mask is not None:
expected = (bsz, 1, 1, seq_img + seq_txt)
if attention_mask.shape != expected:
@ -308,7 +242,7 @@ class LensTransformerBlock(nn.Module):
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Tuple[torch.Tensor, torch.Tensor],
freqs_cis: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
transformer_options: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
@ -321,7 +255,7 @@ class LensTransformerBlock(nn.Module):
img_attn, txt_attn = self.attn(
hidden_states=img_modulated,
encoder_hidden_states=txt_modulated,
image_rotary_emb=image_rotary_emb,
freqs_cis=freqs_cis,
attention_mask=attention_mask,
transformer_options=transformer_options,
)
@ -391,7 +325,7 @@ class LensTransformer2DModel(nn.Module):
self.selected_layer_index = list(selected_layer_index)
self.dtype = dtype
self.pos_embed = LensEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)
self.pos_embed = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope))
self.time_text_embed = LensTimestepProjEmbeddings(
embedding_dim=self.inner_dim, dtype=dtype, device=device, operations=operations
)
@ -502,7 +436,8 @@ class LensTransformer2DModel(nn.Module):
encoder_hidden_states = out["txt"]
temb = self.time_text_embed(timestep, hidden_states)
image_rotary_emb = self.pos_embed([(1, h, w)], [text_seq_len], device=hidden_states.device)
ids = _lens_position_ids(1, h, w, text_seq_len, device=hidden_states.device).unsqueeze(0)
freqs_cis = self.pos_embed(ids)
transformer_options["total_blocks"] = len(self.transformer_blocks)
transformer_options["block_type"] = "double"
@ -515,7 +450,7 @@ class LensTransformer2DModel(nn.Module):
hidden_states=args["img"],
encoder_hidden_states=args["txt"],
temb=args["vec"],
image_rotary_emb=args["pe"],
freqs_cis=args["pe"],
attention_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options"),
)
@ -525,7 +460,7 @@ class LensTransformer2DModel(nn.Module):
"img": hidden_states,
"txt": encoder_hidden_states,
"vec": temb,
"pe": image_rotary_emb,
"pe": freqs_cis,
"attn_mask": joint_mask,
"transformer_options": transformer_options,
},
@ -538,7 +473,7 @@ class LensTransformer2DModel(nn.Module):
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
freqs_cis=freqs_cis,
attention_mask=joint_mask,
transformer_options=transformer_options,
)