Use Flux RoPE

This commit is contained in:
kijai 2026-05-23 22:50:19 +03:00
parent 2aba5bafca
commit b429028ad2

View File

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Dict, Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -10,6 +10,8 @@ import torch.nn.functional as F
import comfy.ldm.flux.layers import comfy.ldm.flux.layers
import comfy.patcher_extension import comfy.patcher_extension
from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.flux.math import apply_rope
from comfy.ldm.modules.attention import optimized_attention from comfy.ldm.modules.attention import optimized_attention
@ -17,96 +19,39 @@ def _lens_time_proj(t: torch.Tensor, dim: int = 256) -> torch.Tensor:
return comfy.ldm.flux.layers.timestep_embedding(t, dim) return comfy.ldm.flux.layers.timestep_embedding(t, dim)
def apply_rotary_emb_lens(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: def _lens_position_ids(
x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) frame: int, height: int, width: int, text_seq_len: int,
freqs_cis = freqs_cis.unsqueeze(1) scale_rope: bool = True, device=None,
x_out = torch.view_as_real(x_complex * freqs_cis).flatten(3) ) -> torch.Tensor:
return x_out.type_as(x) """Lens axial (frame, h, w) position ids for joint image + text sequence.
With ``scale_rope=True`` h/w are centered around 0 (negative + positive
halves) and text starts at ``max(h//2, w//2)``. Result shape ``[seq, 3]``;
caller adds a batch dim for ``EmbedND``.
"""
if scale_rope:
h_pos = torch.cat([torch.arange(-(height - height // 2), 0, device=device),
torch.arange(0, height // 2, device=device)])
w_pos = torch.cat([torch.arange(-(width - width // 2), 0, device=device),
torch.arange(0, width // 2, device=device)])
text_start = max(height // 2, width // 2)
else:
h_pos = torch.arange(height, device=device)
w_pos = torch.arange(width, device=device)
text_start = max(height, width)
class LensEmbedRope(nn.Module): f_pos = torch.arange(frame, device=device)
"""Frame/H/W axial RoPE shared between image and text streams.""" img_ids = torch.zeros(frame, height, width, 3, device=device)
img_ids[..., 0] = f_pos[:, None, None]
img_ids[..., 1] = h_pos[None, :, None]
img_ids[..., 2] = w_pos[None, None, :]
img_ids = img_ids.reshape(-1, 3)
def __init__(self, theta: int = 10000, axes_dim=(8, 28, 28), scale_rope: bool = True) -> None: # Text positions replicate across all 3 axes (matches original packing).
super().__init__() txt_pos = torch.arange(text_start, text_start + text_seq_len, device=device).float()
self.theta = theta txt_ids = txt_pos[:, None].expand(text_seq_len, 3)
self.axes_dim = list(axes_dim)
self.scale_rope = scale_rope
pos_index = torch.arange(4096)
neg_index = torch.arange(4096).flip(0) * -1 - 1
# Plain attributes (not buffers): register_buffer strips complex imag.
self.pos_freqs = torch.cat(
[self._rope_params(pos_index, d, theta) for d in self.axes_dim], dim=1
)
self.neg_freqs = torch.cat(
[self._rope_params(neg_index, d, theta) for d in self.axes_dim], dim=1
)
self.rope_cache: Dict[str, torch.Tensor] = {}
@staticmethod return torch.cat([img_ids, txt_ids], dim=0)
def _rope_params(index: torch.Tensor, dim: int, theta: int = 10000) -> torch.Tensor:
assert dim % 2 == 0
freqs = torch.outer(
index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).float().div(dim))
)
return torch.polar(torch.ones_like(freqs), freqs)
def forward(
self,
video_fhw: Union[List[Tuple[int, int, int]], Tuple[int, int, int]],
txt_seq_lens: Union[List[int], int],
device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
if self.pos_freqs.device != device:
self.pos_freqs = self.pos_freqs.to(device)
self.neg_freqs = self.neg_freqs.to(device)
if isinstance(video_fhw, list):
video_fhw = video_fhw[0]
if not isinstance(video_fhw, list):
video_fhw = [video_fhw]
if not isinstance(txt_seq_lens, list):
txt_seq_lens = [txt_seq_lens]
assert len(video_fhw) == 1, "video_fhw must have length 1"
vid_freqs = []
max_vid_index = 0
for idx, fhw in enumerate(video_fhw):
frame, height, width = fhw
rope_key = f"{idx}_{height}_{width}_{device}"
video_freq = self.rope_cache.get(rope_key)
if video_freq is None:
video_freq = self._compute_video_freqs(frame, height, width, idx=0).to(device)
self.rope_cache[rope_key] = video_freq
if self.scale_rope:
max_vid_index = max(height // 2, width // 2, max_vid_index)
else:
max_vid_index = max(height, width, max_vid_index)
vid_freqs.append(video_freq)
max_len = max(txt_seq_lens)
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
return torch.cat(vid_freqs, dim=0), txt_freqs
def _compute_video_freqs(self, frame: int, height: int, width: int, idx: int = 0) -> torch.Tensor:
seq_lens = frame * height * width
freqs_pos = self.pos_freqs.split([d // 2 for d in self.axes_dim], dim=1)
freqs_neg = self.neg_freqs.split([d // 2 for d in self.axes_dim], dim=1)
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
if self.scale_rope:
freqs_height = torch.cat(
[freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0
).view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = torch.cat(
[freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0
).view(1, 1, width, -1).expand(frame, height, width, -1)
else:
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
return freqs.clone().contiguous()
class _TimestepEmbedder(nn.Module): class _TimestepEmbedder(nn.Module):
@ -184,7 +129,7 @@ class LensJointAttention(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor,
image_rotary_emb: Tuple[torch.Tensor, torch.Tensor], freqs_cis: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
transformer_options: Optional[Dict[str, Any]] = None, transformer_options: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
@ -207,19 +152,6 @@ class LensJointAttention(nn.Module):
txt_v = txt_v.contiguous() txt_v = txt_v.contiguous()
del txt_qkv del txt_qkv
img_freqs, txt_freqs = image_rotary_emb
if img_freqs.shape[0] < seq_img:
raise ValueError(f"Image RoPE length {img_freqs.shape[0]} < {seq_img}")
img_freqs = img_freqs[:seq_img]
img_q = apply_rotary_emb_lens(img_q, img_freqs)
img_k = apply_rotary_emb_lens(img_k, img_freqs)
if seq_txt > 0:
if txt_freqs.shape[0] < seq_txt:
raise ValueError(f"Text RoPE length {txt_freqs.shape[0]} < {seq_txt}")
txt_freqs = txt_freqs[:seq_txt]
txt_q = apply_rotary_emb_lens(txt_q, txt_freqs)
txt_k = apply_rotary_emb_lens(txt_k, txt_freqs)
# [B, S, H, D] → [B, H, S, D] for attention, dels to avoid VRAM peaks # [B, S, H, D] → [B, H, S, D] for attention, dels to avoid VRAM peaks
q = torch.cat([img_q, txt_q], dim=1).transpose(1, 2) q = torch.cat([img_q, txt_q], dim=1).transpose(1, 2)
del img_q, txt_q del img_q, txt_q
@ -228,6 +160,8 @@ class LensJointAttention(nn.Module):
v = torch.cat([img_v, txt_v], dim=1).transpose(1, 2) v = torch.cat([img_v, txt_v], dim=1).transpose(1, 2)
del img_v, txt_v del img_v, txt_v
q, k = apply_rope(q, k, freqs_cis)
if attention_mask is not None: if attention_mask is not None:
expected = (bsz, 1, 1, seq_img + seq_txt) expected = (bsz, 1, 1, seq_img + seq_txt)
if attention_mask.shape != expected: if attention_mask.shape != expected:
@ -308,7 +242,7 @@ class LensTransformerBlock(nn.Module):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor,
temb: torch.Tensor, temb: torch.Tensor,
image_rotary_emb: Tuple[torch.Tensor, torch.Tensor], freqs_cis: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
transformer_options: Optional[Dict[str, Any]] = None, transformer_options: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
@ -321,7 +255,7 @@ class LensTransformerBlock(nn.Module):
img_attn, txt_attn = self.attn( img_attn, txt_attn = self.attn(
hidden_states=img_modulated, hidden_states=img_modulated,
encoder_hidden_states=txt_modulated, encoder_hidden_states=txt_modulated,
image_rotary_emb=image_rotary_emb, freqs_cis=freqs_cis,
attention_mask=attention_mask, attention_mask=attention_mask,
transformer_options=transformer_options, transformer_options=transformer_options,
) )
@ -391,7 +325,7 @@ class LensTransformer2DModel(nn.Module):
self.selected_layer_index = list(selected_layer_index) self.selected_layer_index = list(selected_layer_index)
self.dtype = dtype self.dtype = dtype
self.pos_embed = LensEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True) self.pos_embed = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope))
self.time_text_embed = LensTimestepProjEmbeddings( self.time_text_embed = LensTimestepProjEmbeddings(
embedding_dim=self.inner_dim, dtype=dtype, device=device, operations=operations embedding_dim=self.inner_dim, dtype=dtype, device=device, operations=operations
) )
@ -502,7 +436,8 @@ class LensTransformer2DModel(nn.Module):
encoder_hidden_states = out["txt"] encoder_hidden_states = out["txt"]
temb = self.time_text_embed(timestep, hidden_states) temb = self.time_text_embed(timestep, hidden_states)
image_rotary_emb = self.pos_embed([(1, h, w)], [text_seq_len], device=hidden_states.device) ids = _lens_position_ids(1, h, w, text_seq_len, device=hidden_states.device).unsqueeze(0)
freqs_cis = self.pos_embed(ids)
transformer_options["total_blocks"] = len(self.transformer_blocks) transformer_options["total_blocks"] = len(self.transformer_blocks)
transformer_options["block_type"] = "double" transformer_options["block_type"] = "double"
@ -515,7 +450,7 @@ class LensTransformer2DModel(nn.Module):
hidden_states=args["img"], hidden_states=args["img"],
encoder_hidden_states=args["txt"], encoder_hidden_states=args["txt"],
temb=args["vec"], temb=args["vec"],
image_rotary_emb=args["pe"], freqs_cis=args["pe"],
attention_mask=args.get("attn_mask"), attention_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options"), transformer_options=args.get("transformer_options"),
) )
@ -525,7 +460,7 @@ class LensTransformer2DModel(nn.Module):
"img": hidden_states, "img": hidden_states,
"txt": encoder_hidden_states, "txt": encoder_hidden_states,
"vec": temb, "vec": temb,
"pe": image_rotary_emb, "pe": freqs_cis,
"attn_mask": joint_mask, "attn_mask": joint_mask,
"transformer_options": transformer_options, "transformer_options": transformer_options,
}, },
@ -538,7 +473,7 @@ class LensTransformer2DModel(nn.Module):
hidden_states=hidden_states, hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
temb=temb, temb=temb,
image_rotary_emb=image_rotary_emb, freqs_cis=freqs_cis,
attention_mask=joint_mask, attention_mask=joint_mask,
transformer_options=transformer_options, transformer_options=transformer_options,
) )