mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-28 09:57:24 +08:00
Use Flux RoPE
This commit is contained in:
parent
2aba5bafca
commit
b429028ad2
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
from typing import Any, Dict, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -10,6 +10,8 @@ import torch.nn.functional as F
|
|||||||
|
|
||||||
import comfy.ldm.flux.layers
|
import comfy.ldm.flux.layers
|
||||||
import comfy.patcher_extension
|
import comfy.patcher_extension
|
||||||
|
from comfy.ldm.flux.layers import EmbedND
|
||||||
|
from comfy.ldm.flux.math import apply_rope
|
||||||
from comfy.ldm.modules.attention import optimized_attention
|
from comfy.ldm.modules.attention import optimized_attention
|
||||||
|
|
||||||
|
|
||||||
@ -17,96 +19,39 @@ def _lens_time_proj(t: torch.Tensor, dim: int = 256) -> torch.Tensor:
|
|||||||
return comfy.ldm.flux.layers.timestep_embedding(t, dim)
|
return comfy.ldm.flux.layers.timestep_embedding(t, dim)
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_emb_lens(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
def _lens_position_ids(
|
||||||
x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
frame: int, height: int, width: int, text_seq_len: int,
|
||||||
freqs_cis = freqs_cis.unsqueeze(1)
|
scale_rope: bool = True, device=None,
|
||||||
x_out = torch.view_as_real(x_complex * freqs_cis).flatten(3)
|
) -> torch.Tensor:
|
||||||
return x_out.type_as(x)
|
"""Lens axial (frame, h, w) position ids for joint image + text sequence.
|
||||||
|
|
||||||
|
With ``scale_rope=True`` h/w are centered around 0 (negative + positive
|
||||||
|
halves) and text starts at ``max(h//2, w//2)``. Result shape ``[seq, 3]``;
|
||||||
|
caller adds a batch dim for ``EmbedND``.
|
||||||
|
"""
|
||||||
|
if scale_rope:
|
||||||
|
h_pos = torch.cat([torch.arange(-(height - height // 2), 0, device=device),
|
||||||
|
torch.arange(0, height // 2, device=device)])
|
||||||
|
w_pos = torch.cat([torch.arange(-(width - width // 2), 0, device=device),
|
||||||
|
torch.arange(0, width // 2, device=device)])
|
||||||
|
text_start = max(height // 2, width // 2)
|
||||||
|
else:
|
||||||
|
h_pos = torch.arange(height, device=device)
|
||||||
|
w_pos = torch.arange(width, device=device)
|
||||||
|
text_start = max(height, width)
|
||||||
|
|
||||||
class LensEmbedRope(nn.Module):
|
f_pos = torch.arange(frame, device=device)
|
||||||
"""Frame/H/W axial RoPE shared between image and text streams."""
|
img_ids = torch.zeros(frame, height, width, 3, device=device)
|
||||||
|
img_ids[..., 0] = f_pos[:, None, None]
|
||||||
|
img_ids[..., 1] = h_pos[None, :, None]
|
||||||
|
img_ids[..., 2] = w_pos[None, None, :]
|
||||||
|
img_ids = img_ids.reshape(-1, 3)
|
||||||
|
|
||||||
def __init__(self, theta: int = 10000, axes_dim=(8, 28, 28), scale_rope: bool = True) -> None:
|
# Text positions replicate across all 3 axes (matches original packing).
|
||||||
super().__init__()
|
txt_pos = torch.arange(text_start, text_start + text_seq_len, device=device).float()
|
||||||
self.theta = theta
|
txt_ids = txt_pos[:, None].expand(text_seq_len, 3)
|
||||||
self.axes_dim = list(axes_dim)
|
|
||||||
self.scale_rope = scale_rope
|
|
||||||
pos_index = torch.arange(4096)
|
|
||||||
neg_index = torch.arange(4096).flip(0) * -1 - 1
|
|
||||||
# Plain attributes (not buffers): register_buffer strips complex imag.
|
|
||||||
self.pos_freqs = torch.cat(
|
|
||||||
[self._rope_params(pos_index, d, theta) for d in self.axes_dim], dim=1
|
|
||||||
)
|
|
||||||
self.neg_freqs = torch.cat(
|
|
||||||
[self._rope_params(neg_index, d, theta) for d in self.axes_dim], dim=1
|
|
||||||
)
|
|
||||||
self.rope_cache: Dict[str, torch.Tensor] = {}
|
|
||||||
|
|
||||||
@staticmethod
|
return torch.cat([img_ids, txt_ids], dim=0)
|
||||||
def _rope_params(index: torch.Tensor, dim: int, theta: int = 10000) -> torch.Tensor:
|
|
||||||
assert dim % 2 == 0
|
|
||||||
freqs = torch.outer(
|
|
||||||
index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).float().div(dim))
|
|
||||||
)
|
|
||||||
return torch.polar(torch.ones_like(freqs), freqs)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
video_fhw: Union[List[Tuple[int, int, int]], Tuple[int, int, int]],
|
|
||||||
txt_seq_lens: Union[List[int], int],
|
|
||||||
device: torch.device,
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
if self.pos_freqs.device != device:
|
|
||||||
self.pos_freqs = self.pos_freqs.to(device)
|
|
||||||
self.neg_freqs = self.neg_freqs.to(device)
|
|
||||||
|
|
||||||
if isinstance(video_fhw, list):
|
|
||||||
video_fhw = video_fhw[0]
|
|
||||||
if not isinstance(video_fhw, list):
|
|
||||||
video_fhw = [video_fhw]
|
|
||||||
if not isinstance(txt_seq_lens, list):
|
|
||||||
txt_seq_lens = [txt_seq_lens]
|
|
||||||
assert len(video_fhw) == 1, "video_fhw must have length 1"
|
|
||||||
|
|
||||||
vid_freqs = []
|
|
||||||
max_vid_index = 0
|
|
||||||
for idx, fhw in enumerate(video_fhw):
|
|
||||||
frame, height, width = fhw
|
|
||||||
rope_key = f"{idx}_{height}_{width}_{device}"
|
|
||||||
video_freq = self.rope_cache.get(rope_key)
|
|
||||||
if video_freq is None:
|
|
||||||
video_freq = self._compute_video_freqs(frame, height, width, idx=0).to(device)
|
|
||||||
self.rope_cache[rope_key] = video_freq
|
|
||||||
if self.scale_rope:
|
|
||||||
max_vid_index = max(height // 2, width // 2, max_vid_index)
|
|
||||||
else:
|
|
||||||
max_vid_index = max(height, width, max_vid_index)
|
|
||||||
vid_freqs.append(video_freq)
|
|
||||||
|
|
||||||
max_len = max(txt_seq_lens)
|
|
||||||
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
|
|
||||||
return torch.cat(vid_freqs, dim=0), txt_freqs
|
|
||||||
|
|
||||||
def _compute_video_freqs(self, frame: int, height: int, width: int, idx: int = 0) -> torch.Tensor:
|
|
||||||
seq_lens = frame * height * width
|
|
||||||
freqs_pos = self.pos_freqs.split([d // 2 for d in self.axes_dim], dim=1)
|
|
||||||
freqs_neg = self.neg_freqs.split([d // 2 for d in self.axes_dim], dim=1)
|
|
||||||
|
|
||||||
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
|
|
||||||
if self.scale_rope:
|
|
||||||
freqs_height = torch.cat(
|
|
||||||
[freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0
|
|
||||||
).view(1, height, 1, -1).expand(frame, height, width, -1)
|
|
||||||
freqs_width = torch.cat(
|
|
||||||
[freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0
|
|
||||||
).view(1, 1, width, -1).expand(frame, height, width, -1)
|
|
||||||
else:
|
|
||||||
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
|
|
||||||
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
|
|
||||||
|
|
||||||
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
|
|
||||||
return freqs.clone().contiguous()
|
|
||||||
|
|
||||||
|
|
||||||
class _TimestepEmbedder(nn.Module):
|
class _TimestepEmbedder(nn.Module):
|
||||||
@ -184,7 +129,7 @@ class LensJointAttention(nn.Module):
|
|||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
encoder_hidden_states: torch.Tensor,
|
encoder_hidden_states: torch.Tensor,
|
||||||
image_rotary_emb: Tuple[torch.Tensor, torch.Tensor],
|
freqs_cis: torch.Tensor,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
transformer_options: Optional[Dict[str, Any]] = None,
|
transformer_options: Optional[Dict[str, Any]] = None,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
@ -207,19 +152,6 @@ class LensJointAttention(nn.Module):
|
|||||||
txt_v = txt_v.contiguous()
|
txt_v = txt_v.contiguous()
|
||||||
del txt_qkv
|
del txt_qkv
|
||||||
|
|
||||||
img_freqs, txt_freqs = image_rotary_emb
|
|
||||||
if img_freqs.shape[0] < seq_img:
|
|
||||||
raise ValueError(f"Image RoPE length {img_freqs.shape[0]} < {seq_img}")
|
|
||||||
img_freqs = img_freqs[:seq_img]
|
|
||||||
img_q = apply_rotary_emb_lens(img_q, img_freqs)
|
|
||||||
img_k = apply_rotary_emb_lens(img_k, img_freqs)
|
|
||||||
if seq_txt > 0:
|
|
||||||
if txt_freqs.shape[0] < seq_txt:
|
|
||||||
raise ValueError(f"Text RoPE length {txt_freqs.shape[0]} < {seq_txt}")
|
|
||||||
txt_freqs = txt_freqs[:seq_txt]
|
|
||||||
txt_q = apply_rotary_emb_lens(txt_q, txt_freqs)
|
|
||||||
txt_k = apply_rotary_emb_lens(txt_k, txt_freqs)
|
|
||||||
|
|
||||||
# [B, S, H, D] → [B, H, S, D] for attention, dels to avoid VRAM peaks
|
# [B, S, H, D] → [B, H, S, D] for attention, dels to avoid VRAM peaks
|
||||||
q = torch.cat([img_q, txt_q], dim=1).transpose(1, 2)
|
q = torch.cat([img_q, txt_q], dim=1).transpose(1, 2)
|
||||||
del img_q, txt_q
|
del img_q, txt_q
|
||||||
@ -228,6 +160,8 @@ class LensJointAttention(nn.Module):
|
|||||||
v = torch.cat([img_v, txt_v], dim=1).transpose(1, 2)
|
v = torch.cat([img_v, txt_v], dim=1).transpose(1, 2)
|
||||||
del img_v, txt_v
|
del img_v, txt_v
|
||||||
|
|
||||||
|
q, k = apply_rope(q, k, freqs_cis)
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
expected = (bsz, 1, 1, seq_img + seq_txt)
|
expected = (bsz, 1, 1, seq_img + seq_txt)
|
||||||
if attention_mask.shape != expected:
|
if attention_mask.shape != expected:
|
||||||
@ -308,7 +242,7 @@ class LensTransformerBlock(nn.Module):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
encoder_hidden_states: torch.Tensor,
|
encoder_hidden_states: torch.Tensor,
|
||||||
temb: torch.Tensor,
|
temb: torch.Tensor,
|
||||||
image_rotary_emb: Tuple[torch.Tensor, torch.Tensor],
|
freqs_cis: torch.Tensor,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
transformer_options: Optional[Dict[str, Any]] = None,
|
transformer_options: Optional[Dict[str, Any]] = None,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
@ -321,7 +255,7 @@ class LensTransformerBlock(nn.Module):
|
|||||||
img_attn, txt_attn = self.attn(
|
img_attn, txt_attn = self.attn(
|
||||||
hidden_states=img_modulated,
|
hidden_states=img_modulated,
|
||||||
encoder_hidden_states=txt_modulated,
|
encoder_hidden_states=txt_modulated,
|
||||||
image_rotary_emb=image_rotary_emb,
|
freqs_cis=freqs_cis,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
transformer_options=transformer_options,
|
transformer_options=transformer_options,
|
||||||
)
|
)
|
||||||
@ -391,7 +325,7 @@ class LensTransformer2DModel(nn.Module):
|
|||||||
self.selected_layer_index = list(selected_layer_index)
|
self.selected_layer_index = list(selected_layer_index)
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
|
|
||||||
self.pos_embed = LensEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)
|
self.pos_embed = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope))
|
||||||
self.time_text_embed = LensTimestepProjEmbeddings(
|
self.time_text_embed = LensTimestepProjEmbeddings(
|
||||||
embedding_dim=self.inner_dim, dtype=dtype, device=device, operations=operations
|
embedding_dim=self.inner_dim, dtype=dtype, device=device, operations=operations
|
||||||
)
|
)
|
||||||
@ -502,7 +436,8 @@ class LensTransformer2DModel(nn.Module):
|
|||||||
encoder_hidden_states = out["txt"]
|
encoder_hidden_states = out["txt"]
|
||||||
|
|
||||||
temb = self.time_text_embed(timestep, hidden_states)
|
temb = self.time_text_embed(timestep, hidden_states)
|
||||||
image_rotary_emb = self.pos_embed([(1, h, w)], [text_seq_len], device=hidden_states.device)
|
ids = _lens_position_ids(1, h, w, text_seq_len, device=hidden_states.device).unsqueeze(0)
|
||||||
|
freqs_cis = self.pos_embed(ids)
|
||||||
|
|
||||||
transformer_options["total_blocks"] = len(self.transformer_blocks)
|
transformer_options["total_blocks"] = len(self.transformer_blocks)
|
||||||
transformer_options["block_type"] = "double"
|
transformer_options["block_type"] = "double"
|
||||||
@ -515,7 +450,7 @@ class LensTransformer2DModel(nn.Module):
|
|||||||
hidden_states=args["img"],
|
hidden_states=args["img"],
|
||||||
encoder_hidden_states=args["txt"],
|
encoder_hidden_states=args["txt"],
|
||||||
temb=args["vec"],
|
temb=args["vec"],
|
||||||
image_rotary_emb=args["pe"],
|
freqs_cis=args["pe"],
|
||||||
attention_mask=args.get("attn_mask"),
|
attention_mask=args.get("attn_mask"),
|
||||||
transformer_options=args.get("transformer_options"),
|
transformer_options=args.get("transformer_options"),
|
||||||
)
|
)
|
||||||
@ -525,7 +460,7 @@ class LensTransformer2DModel(nn.Module):
|
|||||||
"img": hidden_states,
|
"img": hidden_states,
|
||||||
"txt": encoder_hidden_states,
|
"txt": encoder_hidden_states,
|
||||||
"vec": temb,
|
"vec": temb,
|
||||||
"pe": image_rotary_emb,
|
"pe": freqs_cis,
|
||||||
"attn_mask": joint_mask,
|
"attn_mask": joint_mask,
|
||||||
"transformer_options": transformer_options,
|
"transformer_options": transformer_options,
|
||||||
},
|
},
|
||||||
@ -538,7 +473,7 @@ class LensTransformer2DModel(nn.Module):
|
|||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
temb=temb,
|
temb=temb,
|
||||||
image_rotary_emb=image_rotary_emb,
|
freqs_cis=freqs_cis,
|
||||||
attention_mask=joint_mask,
|
attention_mask=joint_mask,
|
||||||
transformer_options=transformer_options,
|
transformer_options=transformer_options,
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user