mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-28 01:47:32 +08:00
Use Flux RoPE
This commit is contained in:
parent
2aba5bafca
commit
b429028ad2
@ -2,7 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -10,6 +10,8 @@ import torch.nn.functional as F
|
||||
|
||||
import comfy.ldm.flux.layers
|
||||
import comfy.patcher_extension
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
from comfy.ldm.flux.math import apply_rope
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
|
||||
@ -17,96 +19,39 @@ def _lens_time_proj(t: torch.Tensor, dim: int = 256) -> torch.Tensor:
|
||||
return comfy.ldm.flux.layers.timestep_embedding(t, dim)
|
||||
|
||||
|
||||
def apply_rotary_emb_lens(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
||||
x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
||||
freqs_cis = freqs_cis.unsqueeze(1)
|
||||
x_out = torch.view_as_real(x_complex * freqs_cis).flatten(3)
|
||||
return x_out.type_as(x)
|
||||
def _lens_position_ids(
|
||||
frame: int, height: int, width: int, text_seq_len: int,
|
||||
scale_rope: bool = True, device=None,
|
||||
) -> torch.Tensor:
|
||||
"""Lens axial (frame, h, w) position ids for joint image + text sequence.
|
||||
|
||||
With ``scale_rope=True`` h/w are centered around 0 (negative + positive
|
||||
halves) and text starts at ``max(h//2, w//2)``. Result shape ``[seq, 3]``;
|
||||
caller adds a batch dim for ``EmbedND``.
|
||||
"""
|
||||
if scale_rope:
|
||||
h_pos = torch.cat([torch.arange(-(height - height // 2), 0, device=device),
|
||||
torch.arange(0, height // 2, device=device)])
|
||||
w_pos = torch.cat([torch.arange(-(width - width // 2), 0, device=device),
|
||||
torch.arange(0, width // 2, device=device)])
|
||||
text_start = max(height // 2, width // 2)
|
||||
else:
|
||||
h_pos = torch.arange(height, device=device)
|
||||
w_pos = torch.arange(width, device=device)
|
||||
text_start = max(height, width)
|
||||
|
||||
class LensEmbedRope(nn.Module):
|
||||
"""Frame/H/W axial RoPE shared between image and text streams."""
|
||||
f_pos = torch.arange(frame, device=device)
|
||||
img_ids = torch.zeros(frame, height, width, 3, device=device)
|
||||
img_ids[..., 0] = f_pos[:, None, None]
|
||||
img_ids[..., 1] = h_pos[None, :, None]
|
||||
img_ids[..., 2] = w_pos[None, None, :]
|
||||
img_ids = img_ids.reshape(-1, 3)
|
||||
|
||||
def __init__(self, theta: int = 10000, axes_dim=(8, 28, 28), scale_rope: bool = True) -> None:
|
||||
super().__init__()
|
||||
self.theta = theta
|
||||
self.axes_dim = list(axes_dim)
|
||||
self.scale_rope = scale_rope
|
||||
pos_index = torch.arange(4096)
|
||||
neg_index = torch.arange(4096).flip(0) * -1 - 1
|
||||
# Plain attributes (not buffers): register_buffer strips complex imag.
|
||||
self.pos_freqs = torch.cat(
|
||||
[self._rope_params(pos_index, d, theta) for d in self.axes_dim], dim=1
|
||||
)
|
||||
self.neg_freqs = torch.cat(
|
||||
[self._rope_params(neg_index, d, theta) for d in self.axes_dim], dim=1
|
||||
)
|
||||
self.rope_cache: Dict[str, torch.Tensor] = {}
|
||||
# Text positions replicate across all 3 axes (matches original packing).
|
||||
txt_pos = torch.arange(text_start, text_start + text_seq_len, device=device).float()
|
||||
txt_ids = txt_pos[:, None].expand(text_seq_len, 3)
|
||||
|
||||
@staticmethod
|
||||
def _rope_params(index: torch.Tensor, dim: int, theta: int = 10000) -> torch.Tensor:
|
||||
assert dim % 2 == 0
|
||||
freqs = torch.outer(
|
||||
index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).float().div(dim))
|
||||
)
|
||||
return torch.polar(torch.ones_like(freqs), freqs)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
video_fhw: Union[List[Tuple[int, int, int]], Tuple[int, int, int]],
|
||||
txt_seq_lens: Union[List[int], int],
|
||||
device: torch.device,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if self.pos_freqs.device != device:
|
||||
self.pos_freqs = self.pos_freqs.to(device)
|
||||
self.neg_freqs = self.neg_freqs.to(device)
|
||||
|
||||
if isinstance(video_fhw, list):
|
||||
video_fhw = video_fhw[0]
|
||||
if not isinstance(video_fhw, list):
|
||||
video_fhw = [video_fhw]
|
||||
if not isinstance(txt_seq_lens, list):
|
||||
txt_seq_lens = [txt_seq_lens]
|
||||
assert len(video_fhw) == 1, "video_fhw must have length 1"
|
||||
|
||||
vid_freqs = []
|
||||
max_vid_index = 0
|
||||
for idx, fhw in enumerate(video_fhw):
|
||||
frame, height, width = fhw
|
||||
rope_key = f"{idx}_{height}_{width}_{device}"
|
||||
video_freq = self.rope_cache.get(rope_key)
|
||||
if video_freq is None:
|
||||
video_freq = self._compute_video_freqs(frame, height, width, idx=0).to(device)
|
||||
self.rope_cache[rope_key] = video_freq
|
||||
if self.scale_rope:
|
||||
max_vid_index = max(height // 2, width // 2, max_vid_index)
|
||||
else:
|
||||
max_vid_index = max(height, width, max_vid_index)
|
||||
vid_freqs.append(video_freq)
|
||||
|
||||
max_len = max(txt_seq_lens)
|
||||
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
|
||||
return torch.cat(vid_freqs, dim=0), txt_freqs
|
||||
|
||||
def _compute_video_freqs(self, frame: int, height: int, width: int, idx: int = 0) -> torch.Tensor:
|
||||
seq_lens = frame * height * width
|
||||
freqs_pos = self.pos_freqs.split([d // 2 for d in self.axes_dim], dim=1)
|
||||
freqs_neg = self.neg_freqs.split([d // 2 for d in self.axes_dim], dim=1)
|
||||
|
||||
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
|
||||
if self.scale_rope:
|
||||
freqs_height = torch.cat(
|
||||
[freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0
|
||||
).view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||
freqs_width = torch.cat(
|
||||
[freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0
|
||||
).view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||
else:
|
||||
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||
|
||||
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
|
||||
return freqs.clone().contiguous()
|
||||
return torch.cat([img_ids, txt_ids], dim=0)
|
||||
|
||||
|
||||
class _TimestepEmbedder(nn.Module):
|
||||
@ -184,7 +129,7 @@ class LensJointAttention(nn.Module):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
image_rotary_emb: Tuple[torch.Tensor, torch.Tensor],
|
||||
freqs_cis: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
transformer_options: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
@ -207,19 +152,6 @@ class LensJointAttention(nn.Module):
|
||||
txt_v = txt_v.contiguous()
|
||||
del txt_qkv
|
||||
|
||||
img_freqs, txt_freqs = image_rotary_emb
|
||||
if img_freqs.shape[0] < seq_img:
|
||||
raise ValueError(f"Image RoPE length {img_freqs.shape[0]} < {seq_img}")
|
||||
img_freqs = img_freqs[:seq_img]
|
||||
img_q = apply_rotary_emb_lens(img_q, img_freqs)
|
||||
img_k = apply_rotary_emb_lens(img_k, img_freqs)
|
||||
if seq_txt > 0:
|
||||
if txt_freqs.shape[0] < seq_txt:
|
||||
raise ValueError(f"Text RoPE length {txt_freqs.shape[0]} < {seq_txt}")
|
||||
txt_freqs = txt_freqs[:seq_txt]
|
||||
txt_q = apply_rotary_emb_lens(txt_q, txt_freqs)
|
||||
txt_k = apply_rotary_emb_lens(txt_k, txt_freqs)
|
||||
|
||||
# [B, S, H, D] → [B, H, S, D] for attention, dels to avoid VRAM peaks
|
||||
q = torch.cat([img_q, txt_q], dim=1).transpose(1, 2)
|
||||
del img_q, txt_q
|
||||
@ -228,6 +160,8 @@ class LensJointAttention(nn.Module):
|
||||
v = torch.cat([img_v, txt_v], dim=1).transpose(1, 2)
|
||||
del img_v, txt_v
|
||||
|
||||
q, k = apply_rope(q, k, freqs_cis)
|
||||
|
||||
if attention_mask is not None:
|
||||
expected = (bsz, 1, 1, seq_img + seq_txt)
|
||||
if attention_mask.shape != expected:
|
||||
@ -308,7 +242,7 @@ class LensTransformerBlock(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Tuple[torch.Tensor, torch.Tensor],
|
||||
freqs_cis: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
transformer_options: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
@ -321,7 +255,7 @@ class LensTransformerBlock(nn.Module):
|
||||
img_attn, txt_attn = self.attn(
|
||||
hidden_states=img_modulated,
|
||||
encoder_hidden_states=txt_modulated,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
freqs_cis=freqs_cis,
|
||||
attention_mask=attention_mask,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
@ -391,7 +325,7 @@ class LensTransformer2DModel(nn.Module):
|
||||
self.selected_layer_index = list(selected_layer_index)
|
||||
self.dtype = dtype
|
||||
|
||||
self.pos_embed = LensEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)
|
||||
self.pos_embed = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope))
|
||||
self.time_text_embed = LensTimestepProjEmbeddings(
|
||||
embedding_dim=self.inner_dim, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
@ -502,7 +436,8 @@ class LensTransformer2DModel(nn.Module):
|
||||
encoder_hidden_states = out["txt"]
|
||||
|
||||
temb = self.time_text_embed(timestep, hidden_states)
|
||||
image_rotary_emb = self.pos_embed([(1, h, w)], [text_seq_len], device=hidden_states.device)
|
||||
ids = _lens_position_ids(1, h, w, text_seq_len, device=hidden_states.device).unsqueeze(0)
|
||||
freqs_cis = self.pos_embed(ids)
|
||||
|
||||
transformer_options["total_blocks"] = len(self.transformer_blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
@ -515,7 +450,7 @@ class LensTransformer2DModel(nn.Module):
|
||||
hidden_states=args["img"],
|
||||
encoder_hidden_states=args["txt"],
|
||||
temb=args["vec"],
|
||||
image_rotary_emb=args["pe"],
|
||||
freqs_cis=args["pe"],
|
||||
attention_mask=args.get("attn_mask"),
|
||||
transformer_options=args.get("transformer_options"),
|
||||
)
|
||||
@ -525,7 +460,7 @@ class LensTransformer2DModel(nn.Module):
|
||||
"img": hidden_states,
|
||||
"txt": encoder_hidden_states,
|
||||
"vec": temb,
|
||||
"pe": image_rotary_emb,
|
||||
"pe": freqs_cis,
|
||||
"attn_mask": joint_mask,
|
||||
"transformer_options": transformer_options,
|
||||
},
|
||||
@ -538,7 +473,7 @@ class LensTransformer2DModel(nn.Module):
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
freqs_cis=freqs_cis,
|
||||
attention_mask=joint_mask,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user