mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-16 17:42:58 +08:00
Use single apply_rope function across models (#10547)
This commit is contained in:
parent
265adad858
commit
4cd881866b
@ -195,8 +195,8 @@ class DoubleStreamBlock(nn.Module):
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
|
||||
|
||||
# calculate the img bloks
|
||||
img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
|
||||
img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
|
||||
img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
|
||||
img += apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
|
||||
|
||||
# calculate the txt bloks
|
||||
txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt)
|
||||
|
||||
@ -7,15 +7,7 @@ import comfy.model_management
|
||||
|
||||
|
||||
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor:
|
||||
q_shape = q.shape
|
||||
k_shape = k.shape
|
||||
|
||||
if pe is not None:
|
||||
q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2)
|
||||
k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2)
|
||||
q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
|
||||
k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
|
||||
|
||||
q, k = apply_rope(q, k, pe)
|
||||
heads = q.shape[1]
|
||||
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options)
|
||||
return x
|
||||
|
||||
@ -3,12 +3,11 @@ from torch import nn
|
||||
import comfy.patcher_extension
|
||||
import comfy.ldm.modules.attention
|
||||
import comfy.ldm.common_dit
|
||||
from einops import rearrange
|
||||
import math
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
|
||||
|
||||
from comfy.ldm.flux.math import apply_rope1
|
||||
|
||||
def get_timestep_embedding(
|
||||
timesteps: torch.Tensor,
|
||||
@ -238,20 +237,6 @@ class FeedForward(nn.Module):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
def apply_rotary_emb(input_tensor, freqs_cis): #TODO: remove duplicate funcs and pick the best/fastest one
|
||||
cos_freqs = freqs_cis[0]
|
||||
sin_freqs = freqs_cis[1]
|
||||
|
||||
t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
|
||||
t1, t2 = t_dup.unbind(dim=-1)
|
||||
t_dup = torch.stack((-t2, t1), dim=-1)
|
||||
input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")
|
||||
|
||||
out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
@ -281,8 +266,8 @@ class CrossAttention(nn.Module):
|
||||
k = self.k_norm(k)
|
||||
|
||||
if pe is not None:
|
||||
q = apply_rotary_emb(q, pe)
|
||||
k = apply_rotary_emb(k, pe)
|
||||
q = apply_rope1(q.unsqueeze(1), pe).squeeze(1)
|
||||
k = apply_rope1(k.unsqueeze(1), pe).squeeze(1)
|
||||
|
||||
if mask is None:
|
||||
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||
@ -306,12 +291,17 @@ class BasicTransformerBlock(nn.Module):
|
||||
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
|
||||
|
||||
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, transformer_options=transformer_options) * gate_msa
|
||||
norm_x = comfy.ldm.common_dit.rms_norm(x)
|
||||
attn1_input = torch.addcmul(norm_x, norm_x, scale_msa).add_(shift_msa)
|
||||
attn1_result = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options)
|
||||
x.addcmul_(attn1_result, gate_msa)
|
||||
|
||||
x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)
|
||||
|
||||
y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp
|
||||
x += self.ff(y) * gate_mlp
|
||||
norm_x = comfy.ldm.common_dit.rms_norm(x)
|
||||
y = torch.addcmul(norm_x, norm_x, scale_mlp).add_(shift_mlp)
|
||||
ff_result = self.ff(y)
|
||||
x.addcmul_(ff_result, gate_mlp)
|
||||
|
||||
return x
|
||||
|
||||
@ -327,41 +317,35 @@ def get_fractional_positions(indices_grid, max_pos):
|
||||
|
||||
|
||||
def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]):
|
||||
dtype = torch.float32 #self.dtype
|
||||
dtype = torch.float32
|
||||
device = indices_grid.device
|
||||
|
||||
# Get fractional positions and compute frequency indices
|
||||
fractional_positions = get_fractional_positions(indices_grid, max_pos)
|
||||
indices = theta ** torch.linspace(0, 1, dim // 6, device=device, dtype=dtype) * math.pi / 2
|
||||
|
||||
start = 1
|
||||
end = theta
|
||||
device = fractional_positions.device
|
||||
# Compute frequencies and apply cos/sin
|
||||
freqs = (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2)
|
||||
cos_vals = freqs.cos().repeat_interleave(2, dim=-1)
|
||||
sin_vals = freqs.sin().repeat_interleave(2, dim=-1)
|
||||
|
||||
indices = theta ** (
|
||||
torch.linspace(
|
||||
math.log(start, theta),
|
||||
math.log(end, theta),
|
||||
dim // 6,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
indices = indices.to(dtype=dtype)
|
||||
|
||||
indices = indices * math.pi / 2
|
||||
|
||||
freqs = (
|
||||
(indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
|
||||
.transpose(-1, -2)
|
||||
.flatten(2)
|
||||
)
|
||||
|
||||
cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
|
||||
sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
|
||||
# Pad if dim is not divisible by 6
|
||||
if dim % 6 != 0:
|
||||
cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6])
|
||||
sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6])
|
||||
cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
|
||||
sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
|
||||
return cos_freq.to(out_dtype), sin_freq.to(out_dtype)
|
||||
padding_size = dim % 6
|
||||
cos_vals = torch.cat([torch.ones_like(cos_vals[:, :, :padding_size]), cos_vals], dim=-1)
|
||||
sin_vals = torch.cat([torch.zeros_like(sin_vals[:, :, :padding_size]), sin_vals], dim=-1)
|
||||
|
||||
# Reshape and extract one value per pair (since repeat_interleave duplicates each value)
|
||||
cos_vals = cos_vals.reshape(*cos_vals.shape[:2], -1, 2)[..., 0] # [B, N, dim//2]
|
||||
sin_vals = sin_vals.reshape(*sin_vals.shape[:2], -1, 2)[..., 0] # [B, N, dim//2]
|
||||
|
||||
# Build rotation matrix [[cos, -sin], [sin, cos]] and add heads dimension
|
||||
freqs_cis = torch.stack([
|
||||
torch.stack([cos_vals, -sin_vals], dim=-1),
|
||||
torch.stack([sin_vals, cos_vals], dim=-1)
|
||||
], dim=-2).unsqueeze(1) # [B, 1, N, dim//2, 2, 2]
|
||||
|
||||
return freqs_cis.to(out_dtype)
|
||||
|
||||
|
||||
class LTXVModel(torch.nn.Module):
|
||||
@ -501,7 +485,7 @@ class LTXVModel(torch.nn.Module):
|
||||
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
|
||||
x = self.norm_out(x)
|
||||
# Modulation
|
||||
x = x * (1 + scale) + shift
|
||||
x = torch.addcmul(x, x, scale).add_(shift)
|
||||
x = self.proj_out(x)
|
||||
|
||||
x = self.patchifier.unpatchify(
|
||||
|
||||
@ -10,6 +10,7 @@ from comfy.ldm.modules.attention import optimized_attention_masked
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.patcher_extension
|
||||
from comfy.ldm.flux.math import apply_rope1
|
||||
|
||||
class GELU(nn.Module):
|
||||
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None):
|
||||
@ -134,33 +135,34 @@ class Attention(nn.Module):
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
transformer_options={},
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
batch_size = hidden_states.shape[0]
|
||||
seq_img = hidden_states.shape[1]
|
||||
seq_txt = encoder_hidden_states.shape[1]
|
||||
|
||||
img_query = self.to_q(hidden_states).unflatten(-1, (self.heads, -1))
|
||||
img_key = self.to_k(hidden_states).unflatten(-1, (self.heads, -1))
|
||||
img_value = self.to_v(hidden_states).unflatten(-1, (self.heads, -1))
|
||||
# Project and reshape to BHND format (batch, heads, seq, dim)
|
||||
img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
|
||||
img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
|
||||
img_value = self.to_v(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2)
|
||||
|
||||
txt_query = self.add_q_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
|
||||
txt_key = self.add_k_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
|
||||
txt_value = self.add_v_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
|
||||
txt_query = self.add_q_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
|
||||
txt_key = self.add_k_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
|
||||
txt_value = self.add_v_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2)
|
||||
|
||||
img_query = self.norm_q(img_query)
|
||||
img_key = self.norm_k(img_key)
|
||||
txt_query = self.norm_added_q(txt_query)
|
||||
txt_key = self.norm_added_k(txt_key)
|
||||
|
||||
joint_query = torch.cat([txt_query, img_query], dim=1)
|
||||
joint_key = torch.cat([txt_key, img_key], dim=1)
|
||||
joint_value = torch.cat([txt_value, img_value], dim=1)
|
||||
joint_query = torch.cat([txt_query, img_query], dim=2)
|
||||
joint_key = torch.cat([txt_key, img_key], dim=2)
|
||||
joint_value = torch.cat([txt_value, img_value], dim=2)
|
||||
|
||||
joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
|
||||
joint_key = apply_rotary_emb(joint_key, image_rotary_emb)
|
||||
joint_query = apply_rope1(joint_query, image_rotary_emb)
|
||||
joint_key = apply_rope1(joint_key, image_rotary_emb)
|
||||
|
||||
joint_query = joint_query.flatten(start_dim=2)
|
||||
joint_key = joint_key.flatten(start_dim=2)
|
||||
joint_value = joint_value.flatten(start_dim=2)
|
||||
|
||||
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask, transformer_options=transformer_options)
|
||||
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads,
|
||||
attention_mask, transformer_options=transformer_options,
|
||||
skip_reshape=True)
|
||||
|
||||
txt_attn_output = joint_hidden_states[:, :seq_txt, :]
|
||||
img_attn_output = joint_hidden_states[:, seq_txt:, :]
|
||||
@ -413,7 +415,7 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
|
||||
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
|
||||
image_rotary_emb = self.pe_embedder(ids).to(torch.float32).contiguous()
|
||||
del ids, txt_ids, img_ids
|
||||
|
||||
hidden_states = self.img_in(hidden_states)
|
||||
|
||||
@ -232,6 +232,7 @@ class WanAttentionBlock(nn.Module):
|
||||
# assert e[0].dtype == torch.float32
|
||||
|
||||
# self-attention
|
||||
x = x.contiguous() # otherwise implicit in LayerNorm
|
||||
y = self.self_attn(
|
||||
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
|
||||
freqs, transformer_options=transformer_options)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user