Use Flux RoPE

This commit is contained in:
kijai 2026-05-23 22:50:19 +03:00
parent 2aba5bafca
commit b429028ad2

View File

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Dict, Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -10,6 +10,8 @@ import torch.nn.functional as F
import comfy.ldm.flux.layers import comfy.ldm.flux.layers
import comfy.patcher_extension import comfy.patcher_extension
from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.flux.math import apply_rope
from comfy.ldm.modules.attention import optimized_attention from comfy.ldm.modules.attention import optimized_attention
@ -17,96 +19,39 @@ def _lens_time_proj(t: torch.Tensor, dim: int = 256) -> torch.Tensor:
return comfy.ldm.flux.layers.timestep_embedding(t, dim) return comfy.ldm.flux.layers.timestep_embedding(t, dim)
def _lens_position_ids(
    frame: int, height: int, width: int, text_seq_len: int,
    scale_rope: bool = True, device=None,
) -> torch.Tensor:
    """Build Lens axial (frame, h, w) position ids for a joint image + text sequence.

    Image tokens are enumerated in row-major ``(frame, height, width)`` order
    and are followed by ``text_seq_len`` text tokens.

    With ``scale_rope=True`` the h/w coordinates are centered around 0: for an
    axis of length ``n`` they run from ``-(n - n // 2)`` up to ``n // 2 - 1``,
    and text positions start at ``max(height // 2, width // 2)``. With
    ``scale_rope=False`` h/w count up from 0 and text starts at
    ``max(height, width)``. Frame positions always count up from 0.

    Args:
        frame: Number of frames.
        height: Number of token rows per frame.
        width: Number of token columns per frame.
        text_seq_len: Number of text tokens appended after the image tokens.
        scale_rope: Whether to center the h/w coordinates around 0.
        device: Device the ids are created on (``None`` = torch default).

    Returns:
        A float32 tensor of shape ``[frame * height * width + text_seq_len, 3]``.
        The caller adds a batch dim before passing it to ``EmbedND``.
    """

    def _axis_positions(length: int) -> torch.Tensor:
        # Centered layout: the negative half followed by the non-negative half
        # is a single contiguous integer range, so one arange is enough.
        if scale_rope:
            return torch.arange(-(length - length // 2), length // 2, device=device)
        return torch.arange(length, device=device)

    f_pos = torch.arange(frame, device=device)
    h_pos = _axis_positions(height)
    w_pos = _axis_positions(width)
    text_start = max(height // 2, width // 2) if scale_rope else max(height, width)

    # Pin float32 explicitly so the image ids always match the text ids below,
    # independent of any global ``torch.set_default_dtype`` setting.
    img_ids = torch.zeros(frame, height, width, 3, dtype=torch.float32, device=device)
    img_ids[..., 0] = f_pos[:, None, None]
    img_ids[..., 1] = h_pos[None, :, None]
    img_ids[..., 2] = w_pos[None, None, :]
    img_ids = img_ids.reshape(-1, 3)

    # Text positions replicate across all 3 axes (matches original packing).
    txt_pos = torch.arange(
        text_start, text_start + text_seq_len, dtype=torch.float32, device=device
    )
    txt_ids = txt_pos[:, None].expand(text_seq_len, 3)

    return torch.cat([img_ids, txt_ids], dim=0)
class _TimestepEmbedder(nn.Module): class _TimestepEmbedder(nn.Module):
@ -184,7 +129,7 @@ class LensJointAttention(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor,
image_rotary_emb: Tuple[torch.Tensor, torch.Tensor], freqs_cis: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
transformer_options: Optional[Dict[str, Any]] = None, transformer_options: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
@ -207,19 +152,6 @@ class LensJointAttention(nn.Module):
txt_v = txt_v.contiguous() txt_v = txt_v.contiguous()
del txt_qkv del txt_qkv
img_freqs, txt_freqs = image_rotary_emb
if img_freqs.shape[0] < seq_img:
raise ValueError(f"Image RoPE length {img_freqs.shape[0]} < {seq_img}")
img_freqs = img_freqs[:seq_img]
img_q = apply_rotary_emb_lens(img_q, img_freqs)
img_k = apply_rotary_emb_lens(img_k, img_freqs)
if seq_txt > 0:
if txt_freqs.shape[0] < seq_txt:
raise ValueError(f"Text RoPE length {txt_freqs.shape[0]} < {seq_txt}")
txt_freqs = txt_freqs[:seq_txt]
txt_q = apply_rotary_emb_lens(txt_q, txt_freqs)
txt_k = apply_rotary_emb_lens(txt_k, txt_freqs)
# [B, S, H, D] → [B, H, S, D] for attention, dels to avoid VRAM peaks # [B, S, H, D] → [B, H, S, D] for attention, dels to avoid VRAM peaks
q = torch.cat([img_q, txt_q], dim=1).transpose(1, 2) q = torch.cat([img_q, txt_q], dim=1).transpose(1, 2)
del img_q, txt_q del img_q, txt_q
@ -228,6 +160,8 @@ class LensJointAttention(nn.Module):
v = torch.cat([img_v, txt_v], dim=1).transpose(1, 2) v = torch.cat([img_v, txt_v], dim=1).transpose(1, 2)
del img_v, txt_v del img_v, txt_v
q, k = apply_rope(q, k, freqs_cis)
if attention_mask is not None: if attention_mask is not None:
expected = (bsz, 1, 1, seq_img + seq_txt) expected = (bsz, 1, 1, seq_img + seq_txt)
if attention_mask.shape != expected: if attention_mask.shape != expected:
@ -308,7 +242,7 @@ class LensTransformerBlock(nn.Module):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor,
temb: torch.Tensor, temb: torch.Tensor,
image_rotary_emb: Tuple[torch.Tensor, torch.Tensor], freqs_cis: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
transformer_options: Optional[Dict[str, Any]] = None, transformer_options: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
@ -321,7 +255,7 @@ class LensTransformerBlock(nn.Module):
img_attn, txt_attn = self.attn( img_attn, txt_attn = self.attn(
hidden_states=img_modulated, hidden_states=img_modulated,
encoder_hidden_states=txt_modulated, encoder_hidden_states=txt_modulated,
image_rotary_emb=image_rotary_emb, freqs_cis=freqs_cis,
attention_mask=attention_mask, attention_mask=attention_mask,
transformer_options=transformer_options, transformer_options=transformer_options,
) )
@ -391,7 +325,7 @@ class LensTransformer2DModel(nn.Module):
self.selected_layer_index = list(selected_layer_index) self.selected_layer_index = list(selected_layer_index)
self.dtype = dtype self.dtype = dtype
self.pos_embed = LensEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True) self.pos_embed = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope))
self.time_text_embed = LensTimestepProjEmbeddings( self.time_text_embed = LensTimestepProjEmbeddings(
embedding_dim=self.inner_dim, dtype=dtype, device=device, operations=operations embedding_dim=self.inner_dim, dtype=dtype, device=device, operations=operations
) )
@ -502,7 +436,8 @@ class LensTransformer2DModel(nn.Module):
encoder_hidden_states = out["txt"] encoder_hidden_states = out["txt"]
temb = self.time_text_embed(timestep, hidden_states) temb = self.time_text_embed(timestep, hidden_states)
image_rotary_emb = self.pos_embed([(1, h, w)], [text_seq_len], device=hidden_states.device) ids = _lens_position_ids(1, h, w, text_seq_len, device=hidden_states.device).unsqueeze(0)
freqs_cis = self.pos_embed(ids)
transformer_options["total_blocks"] = len(self.transformer_blocks) transformer_options["total_blocks"] = len(self.transformer_blocks)
transformer_options["block_type"] = "double" transformer_options["block_type"] = "double"
@ -515,7 +450,7 @@ class LensTransformer2DModel(nn.Module):
hidden_states=args["img"], hidden_states=args["img"],
encoder_hidden_states=args["txt"], encoder_hidden_states=args["txt"],
temb=args["vec"], temb=args["vec"],
image_rotary_emb=args["pe"], freqs_cis=args["pe"],
attention_mask=args.get("attn_mask"), attention_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options"), transformer_options=args.get("transformer_options"),
) )
@ -525,7 +460,7 @@ class LensTransformer2DModel(nn.Module):
"img": hidden_states, "img": hidden_states,
"txt": encoder_hidden_states, "txt": encoder_hidden_states,
"vec": temb, "vec": temb,
"pe": image_rotary_emb, "pe": freqs_cis,
"attn_mask": joint_mask, "attn_mask": joint_mask,
"transformer_options": transformer_options, "transformer_options": transformer_options,
}, },
@ -538,7 +473,7 @@ class LensTransformer2DModel(nn.Module):
hidden_states=hidden_states, hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
temb=temb, temb=temb,
image_rotary_emb=image_rotary_emb, freqs_cis=freqs_cis,
attention_mask=joint_mask, attention_mask=joint_mask,
transformer_options=transformer_options, transformer_options=transformer_options,
) )