From b429028ad21eeab25bfe963dffab00bf2335cdb1 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Sat, 23 May 2026 22:50:19 +0300 Subject: [PATCH] Use Flux RoPE --- comfy/ldm/lens/model.py | 151 ++++++++++++---------------------------- 1 file changed, 43 insertions(+), 108 deletions(-) diff --git a/comfy/ldm/lens/model.py b/comfy/ldm/lens/model.py index 8099ae26b..7bff7f6af 100644 --- a/comfy/ldm/lens/model.py +++ b/comfy/ldm/lens/model.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple import torch import torch.nn as nn @@ -10,6 +10,8 @@ import torch.nn.functional as F import comfy.ldm.flux.layers import comfy.patcher_extension +from comfy.ldm.flux.layers import EmbedND +from comfy.ldm.flux.math import apply_rope from comfy.ldm.modules.attention import optimized_attention @@ -17,96 +19,39 @@ def _lens_time_proj(t: torch.Tensor, dim: int = 256) -> torch.Tensor: return comfy.ldm.flux.layers.timestep_embedding(t, dim) -def apply_rotary_emb_lens(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: - x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) - freqs_cis = freqs_cis.unsqueeze(1) - x_out = torch.view_as_real(x_complex * freqs_cis).flatten(3) - return x_out.type_as(x) +def _lens_position_ids( + frame: int, height: int, width: int, text_seq_len: int, + scale_rope: bool = True, device=None, +) -> torch.Tensor: + """Lens axial (frame, h, w) position ids for joint image + text sequence. + With ``scale_rope=True`` h/w are centered around 0 (negative + positive + halves) and text starts at ``max(h//2, w//2)``. Result shape ``[seq, 3]``; + caller adds a batch dim for ``EmbedND``. + """ + if scale_rope: + h_pos = torch.cat([torch.arange(-(height - height // 2), 0, device=device), + torch.arange(0, height // 2, device=device)]) + w_pos = torch.cat([torch.arange(-(width - width // 2), 0, device=device), + torch.arange(0, width // 2, device=device)]) + text_start = max(height // 2, width // 2) + else: + h_pos = torch.arange(height, device=device) + w_pos = torch.arange(width, device=device) + text_start = max(height, width) -class LensEmbedRope(nn.Module): - """Frame/H/W axial RoPE shared between image and text streams.""" + f_pos = torch.arange(frame, device=device) + img_ids = torch.zeros(frame, height, width, 3, device=device) + img_ids[..., 0] = f_pos[:, None, None] + img_ids[..., 1] = h_pos[None, :, None] + img_ids[..., 2] = w_pos[None, None, :] + img_ids = img_ids.reshape(-1, 3) - def __init__(self, theta: int = 10000, axes_dim=(8, 28, 28), scale_rope: bool = True) -> None: - super().__init__() - self.theta = theta - self.axes_dim = list(axes_dim) - self.scale_rope = scale_rope - pos_index = torch.arange(4096) - neg_index = torch.arange(4096).flip(0) * -1 - 1 - # Plain attributes (not buffers): register_buffer strips complex imag. - self.pos_freqs = torch.cat( - [self._rope_params(pos_index, d, theta) for d in self.axes_dim], dim=1 - ) - self.neg_freqs = torch.cat( - [self._rope_params(neg_index, d, theta) for d in self.axes_dim], dim=1 - ) - self.rope_cache: Dict[str, torch.Tensor] = {} + # Text positions replicate across all 3 axes (matches original packing). + txt_pos = torch.arange(text_start, text_start + text_seq_len, device=device).float() + txt_ids = txt_pos[:, None].expand(text_seq_len, 3) - @staticmethod - def _rope_params(index: torch.Tensor, dim: int, theta: int = 10000) -> torch.Tensor: - assert dim % 2 == 0 - freqs = torch.outer( - index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).float().div(dim)) - ) - return torch.polar(torch.ones_like(freqs), freqs) - - def forward( - self, - video_fhw: Union[List[Tuple[int, int, int]], Tuple[int, int, int]], - txt_seq_lens: Union[List[int], int], - device: torch.device, - ) -> Tuple[torch.Tensor, torch.Tensor]: - if self.pos_freqs.device != device: - self.pos_freqs = self.pos_freqs.to(device) - self.neg_freqs = self.neg_freqs.to(device) - - if isinstance(video_fhw, list): - video_fhw = video_fhw[0] - if not isinstance(video_fhw, list): - video_fhw = [video_fhw] - if not isinstance(txt_seq_lens, list): - txt_seq_lens = [txt_seq_lens] - assert len(video_fhw) == 1, "video_fhw must have length 1" - - vid_freqs = [] - max_vid_index = 0 - for idx, fhw in enumerate(video_fhw): - frame, height, width = fhw - rope_key = f"{idx}_{height}_{width}_{device}" - video_freq = self.rope_cache.get(rope_key) - if video_freq is None: - video_freq = self._compute_video_freqs(frame, height, width, idx=0).to(device) - self.rope_cache[rope_key] = video_freq - if self.scale_rope: - max_vid_index = max(height // 2, width // 2, max_vid_index) - else: - max_vid_index = max(height, width, max_vid_index) - vid_freqs.append(video_freq) - - max_len = max(txt_seq_lens) - txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...] - return torch.cat(vid_freqs, dim=0), txt_freqs - - def _compute_video_freqs(self, frame: int, height: int, width: int, idx: int = 0) -> torch.Tensor: - seq_lens = frame * height * width - freqs_pos = self.pos_freqs.split([d // 2 for d in self.axes_dim], dim=1) - freqs_neg = self.neg_freqs.split([d // 2 for d in self.axes_dim], dim=1) - - freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) - if self.scale_rope: - freqs_height = torch.cat( - [freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0 - ).view(1, height, 1, -1).expand(frame, height, width, -1) - freqs_width = torch.cat( - [freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0 - ).view(1, 1, width, -1).expand(frame, height, width, -1) - else: - freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1) - freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1) - - freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1) - return freqs.clone().contiguous() + return torch.cat([img_ids, txt_ids], dim=0) class _TimestepEmbedder(nn.Module): @@ -184,7 +129,7 @@ class LensJointAttention(nn.Module): self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, - image_rotary_emb: Tuple[torch.Tensor, torch.Tensor], + freqs_cis: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, transformer_options: Optional[Dict[str, Any]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -207,19 +152,6 @@ class LensJointAttention(nn.Module): txt_v = txt_v.contiguous() del txt_qkv - img_freqs, txt_freqs = image_rotary_emb - if img_freqs.shape[0] < seq_img: - raise ValueError(f"Image RoPE length {img_freqs.shape[0]} < {seq_img}") - img_freqs = img_freqs[:seq_img] - img_q = apply_rotary_emb_lens(img_q, img_freqs) - img_k = apply_rotary_emb_lens(img_k, img_freqs) - if seq_txt > 0: - if txt_freqs.shape[0] < seq_txt: - raise ValueError(f"Text RoPE length {txt_freqs.shape[0]} < {seq_txt}") - txt_freqs = txt_freqs[:seq_txt] - txt_q = apply_rotary_emb_lens(txt_q, txt_freqs) - txt_k = apply_rotary_emb_lens(txt_k, txt_freqs) - # [B, S, H, D] → [B, H, S, D] for attention, dels to avoid VRAM peaks q = torch.cat([img_q, txt_q], dim=1).transpose(1, 2) del img_q, txt_q @@ -228,6 +160,8 @@ class LensJointAttention(nn.Module): v = torch.cat([img_v, txt_v], dim=1).transpose(1, 2) del img_v, txt_v + q, k = apply_rope(q, k, freqs_cis) + if attention_mask is not None: expected = (bsz, 1, 1, seq_img + seq_txt) if attention_mask.shape != expected: @@ -308,7 +242,7 @@ class LensTransformerBlock(nn.Module): hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor, - image_rotary_emb: Tuple[torch.Tensor, torch.Tensor], + freqs_cis: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, transformer_options: Optional[Dict[str, Any]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -321,7 +255,7 @@ class LensTransformerBlock(nn.Module): img_attn, txt_attn = self.attn( hidden_states=img_modulated, encoder_hidden_states=txt_modulated, - image_rotary_emb=image_rotary_emb, + freqs_cis=freqs_cis, attention_mask=attention_mask, transformer_options=transformer_options, ) @@ -391,7 +325,7 @@ class LensTransformer2DModel(nn.Module): self.selected_layer_index = list(selected_layer_index) self.dtype = dtype - self.pos_embed = LensEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True) + self.pos_embed = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope)) self.time_text_embed = LensTimestepProjEmbeddings( embedding_dim=self.inner_dim, dtype=dtype, device=device, operations=operations ) @@ -502,7 +436,8 @@ class LensTransformer2DModel(nn.Module): encoder_hidden_states = out["txt"] temb = self.time_text_embed(timestep, hidden_states) - image_rotary_emb = self.pos_embed([(1, h, w)], [text_seq_len], device=hidden_states.device) + ids = _lens_position_ids(1, h, w, text_seq_len, device=hidden_states.device).unsqueeze(0) + freqs_cis = self.pos_embed(ids) transformer_options["total_blocks"] = len(self.transformer_blocks) transformer_options["block_type"] = "double" @@ -515,7 +450,7 @@ class LensTransformer2DModel(nn.Module): hidden_states=args["img"], encoder_hidden_states=args["txt"], temb=args["vec"], - image_rotary_emb=args["pe"], + freqs_cis=args["pe"], attention_mask=args.get("attn_mask"), transformer_options=args.get("transformer_options"), ) @@ -525,7 +460,7 @@ class LensTransformer2DModel(nn.Module): "img": hidden_states, "txt": encoder_hidden_states, "vec": temb, - "pe": image_rotary_emb, + "pe": freqs_cis, "attn_mask": joint_mask, "transformer_options": transformer_options, }, @@ -538,7 +473,7 @@ class LensTransformer2DModel(nn.Module): hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb, - image_rotary_emb=image_rotary_emb, + freqs_cis=freqs_cis, attention_mask=joint_mask, transformer_options=transformer_options, )