Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2025-12-16 17:42:58 +08:00
Use single apply_rope function across models (#10547)
This commit is contained in: parent 265adad858, commit 4cd881866b
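The diff consolidates rotary position embedding (RoPE) application behind the shared apply_rope / apply_rope1 helpers in comfy.ldm.flux.math: the Flux attention function now calls apply_rope instead of inlining the rotation, the LTXV and Qwen-Image models import apply_rope1 in place of their local apply_rotary_emb helpers, LTXV's precompute_freqs_cis returns a single tensor of 2x2 rotation matrices instead of separate cos/sin tensors, and several modulation steps are rewritten with fused torch.addcmul and in-place updates.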
@@ -195,8 +195,8 @@ class DoubleStreamBlock(nn.Module):
         txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
 
         # calculate the img bloks
-        img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
-        img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
+        img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
+        img += apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
 
         # calculate the txt bloks
         txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt)
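The only change in DoubleStreamBlock is that the two img residual updates switch from "img = img + ..." to the in-place "img += ...", matching the existing txt update below them; presumably this avoids allocating a fresh tensor for each residual add.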
@@ -7,15 +7,7 @@ import comfy.model_management
 
 
 def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor:
-    q_shape = q.shape
-    k_shape = k.shape
-
-    if pe is not None:
-        q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2)
-        k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2)
-        q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
-        k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
-
+    q, k = apply_rope(q, k, pe)
     heads = q.shape[1]
     x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options)
     return x
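The deleted lines spell out what the shared helper has to do: view the last dimension of q and k as pairs and contract each pair with the 2x2 rotation stored in pe. The following is a minimal sketch reconstructed from that deleted code, not a copy of the real apply_rope / apply_rope1 in comfy.ldm.flux.math, whose dtype handling and exact signatures may differ:

    import torch
    from torch import Tensor

    def apply_rope1_sketch(x: Tensor, freqs_cis: Tensor) -> Tensor:
        # x: [B, H, N, D]; freqs_cis: [B, 1, N, D//2, 2, 2] rotation matrices.
        x_shape = x.shape
        x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
        out = freqs_cis[..., 0] * x_[..., 0] + freqs_cis[..., 1] * x_[..., 1]
        return out.reshape(*x_shape).type_as(x)

    def apply_rope_sketch(q: Tensor, k: Tensor, freqs_cis: Tensor):
        # Mirrors the deleted inline block: rotate q and k with the same pe tensor.
        if freqs_cis is None:  # the old inline code skipped rotation when pe was None
            return q, k
        return apply_rope1_sketch(q, freqs_cis), apply_rope1_sketch(k, freqs_cis)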
@@ -3,12 +3,11 @@ from torch import nn
 import comfy.patcher_extension
 import comfy.ldm.modules.attention
 import comfy.ldm.common_dit
-from einops import rearrange
 import math
 from typing import Dict, Optional, Tuple
 
 from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
-
+from comfy.ldm.flux.math import apply_rope1
 
 def get_timestep_embedding(
     timesteps: torch.Tensor,
@@ -238,20 +237,6 @@ class FeedForward(nn.Module):
         return self.net(x)
 
 
-def apply_rotary_emb(input_tensor, freqs_cis): #TODO: remove duplicate funcs and pick the best/fastest one
-    cos_freqs = freqs_cis[0]
-    sin_freqs = freqs_cis[1]
-
-    t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
-    t1, t2 = t_dup.unbind(dim=-1)
-    t_dup = torch.stack((-t2, t1), dim=-1)
-    input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")
-
-    out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs
-
-    return out
-
-
 class CrossAttention(nn.Module):
     def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None):
         super().__init__()
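The deleted apply_rotary_emb rotated each adjacent pair (x_2i, x_2i+1) by an angle theta_i using separate cos/sin tensors: out_2i = x_2i * cos(theta_i) - x_2i+1 * sin(theta_i) and out_2i+1 = x_2i+1 * cos(theta_i) + x_2i * sin(theta_i). The apply_rope1 calls that replace it compute the same rotation from the 2x2 matrices that the reworked precompute_freqs_cis now produces; a small numerical check of that equivalence follows the precompute_freqs_cis hunk further down.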
@@ -281,8 +266,8 @@ class CrossAttention(nn.Module):
         k = self.k_norm(k)
 
         if pe is not None:
-            q = apply_rotary_emb(q, pe)
-            k = apply_rotary_emb(k, pe)
+            q = apply_rope1(q.unsqueeze(1), pe).squeeze(1)
+            k = apply_rope1(k.unsqueeze(1), pe).squeeze(1)
 
         if mask is None:
             out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
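The unsqueeze(1)/squeeze(1) pair is presumably there because apply_rope1 expects a tensor with an explicit heads dimension (batch, heads, seq, dim), while q and k in this CrossAttention are still (batch, seq, dim) when the rotation is applied, so a singleton heads axis is added for the call and removed afterwards.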
@@ -306,12 +291,17 @@ class BasicTransformerBlock(nn.Module):
     def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
 
-        x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, transformer_options=transformer_options) * gate_msa
+        norm_x = comfy.ldm.common_dit.rms_norm(x)
+        attn1_input = torch.addcmul(norm_x, norm_x, scale_msa).add_(shift_msa)
+        attn1_result = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options)
+        x.addcmul_(attn1_result, gate_msa)
 
         x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)
 
-        y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp
-        x += self.ff(y) * gate_mlp
+        norm_x = comfy.ldm.common_dit.rms_norm(x)
+        y = torch.addcmul(norm_x, norm_x, scale_mlp).add_(shift_mlp)
+        ff_result = self.ff(y)
+        x.addcmul_(ff_result, gate_mlp)
 
         return x
 
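torch.addcmul(a, b, c) computes a + b * c, so torch.addcmul(norm_x, norm_x, scale) equals norm_x * (1 + scale), and x.addcmul_(y, gate) is an in-place x += y * gate. A standalone sanity check of those identities (illustrative only, not code from the repository), assuming the modulation tensors broadcast over batch and sequence:

    import torch

    x = torch.randn(2, 16, 32)
    norm_x = torch.randn(2, 16, 32)
    scale, shift, gate = torch.randn(3, 1, 1, 32).unbind(0)
    attn_out = torch.randn(2, 16, 32)

    # torch.addcmul(a, b, c) == a + b * c, so this is norm_x * (1 + scale) + shift:
    modulated = torch.addcmul(norm_x, norm_x, scale).add_(shift)
    assert torch.allclose(modulated, norm_x * (1 + scale) + shift, atol=1e-5)

    # x.addcmul_(y, g) updates x in place to x + y * g:
    x_ref = x + attn_out * gate
    x.addcmul_(attn_out, gate)
    assert torch.allclose(x, x_ref, atol=1e-5)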
@@ -327,41 +317,35 @@ def get_fractional_positions(indices_grid, max_pos):
 
 
 def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]):
-    dtype = torch.float32 #self.dtype
+    dtype = torch.float32
+    device = indices_grid.device
 
+    # Get fractional positions and compute frequency indices
     fractional_positions = get_fractional_positions(indices_grid, max_pos)
+    indices = theta ** torch.linspace(0, 1, dim // 6, device=device, dtype=dtype) * math.pi / 2
 
-    start = 1
-    end = theta
-    device = fractional_positions.device
-
-    indices = theta ** (
-        torch.linspace(
-            math.log(start, theta),
-            math.log(end, theta),
-            dim // 6,
-            device=device,
-            dtype=dtype,
-        )
-    )
-    indices = indices.to(dtype=dtype)
-
-    indices = indices * math.pi / 2
-
-    freqs = (
-        (indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
-        .transpose(-1, -2)
-        .flatten(2)
-    )
-
-    cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
-    sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
+    # Compute frequencies and apply cos/sin
+    freqs = (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2)
+    cos_vals = freqs.cos().repeat_interleave(2, dim=-1)
+    sin_vals = freqs.sin().repeat_interleave(2, dim=-1)
+
+    # Pad if dim is not divisible by 6
     if dim % 6 != 0:
-        cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6])
-        sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6])
-        cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
-        sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
-    return cos_freq.to(out_dtype), sin_freq.to(out_dtype)
+        padding_size = dim % 6
+        cos_vals = torch.cat([torch.ones_like(cos_vals[:, :, :padding_size]), cos_vals], dim=-1)
+        sin_vals = torch.cat([torch.zeros_like(sin_vals[:, :, :padding_size]), sin_vals], dim=-1)
+
+    # Reshape and extract one value per pair (since repeat_interleave duplicates each value)
+    cos_vals = cos_vals.reshape(*cos_vals.shape[:2], -1, 2)[..., 0] # [B, N, dim//2]
+    sin_vals = sin_vals.reshape(*sin_vals.shape[:2], -1, 2)[..., 0] # [B, N, dim//2]
+
+    # Build rotation matrix [[cos, -sin], [sin, cos]] and add heads dimension
+    freqs_cis = torch.stack([
+        torch.stack([cos_vals, -sin_vals], dim=-1),
+        torch.stack([sin_vals, cos_vals], dim=-1)
+    ], dim=-2).unsqueeze(1) # [B, 1, N, dim//2, 2, 2]
+
+    return freqs_cis.to(out_dtype)
 
 
 class LTXVModel(torch.nn.Module):
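precompute_freqs_cis now returns one tensor of per-frequency 2x2 rotation matrices, shaped [B, 1, N, dim//2, 2, 2] per the inline comment, instead of the old (cos, sin) pair. A small self-contained check that the two formulations rotate a tensor identically, under the pair layout described in the new code's comments (the function names below are invented for the test and are not repository code):

    import torch

    def rotate_cos_sin(x, cos_f, sin_f):
        # Old formulation: pairs (x0, x1) -> (x0*cos - x1*sin, x1*cos + x0*sin).
        x1, x2 = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
        rot = torch.stack((-x2, x1), dim=-1).reshape(x.shape)
        return x * cos_f + rot * sin_f

    def rotate_matrix(x, freqs_cis):
        # New formulation: freqs_cis holds [[cos, -sin], [sin, cos]] per pair.
        x_ = x.reshape(*x.shape[:-1], -1, 1, 2)
        return (freqs_cis[..., 0] * x_[..., 0] + freqs_cis[..., 1] * x_[..., 1]).reshape(x.shape)

    B, N, D = 1, 4, 8
    x = torch.randn(B, N, D)
    angles = torch.randn(B, N, D // 2)
    cos_f = angles.cos().repeat_interleave(2, dim=-1)
    sin_f = angles.sin().repeat_interleave(2, dim=-1)
    freqs_cis = torch.stack([torch.stack([angles.cos(), -angles.sin()], dim=-1),
                             torch.stack([angles.sin(), angles.cos()], dim=-1)], dim=-2)
    assert torch.allclose(rotate_cos_sin(x, cos_f, sin_f), rotate_matrix(x, freqs_cis), atol=1e-5)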
@@ -501,7 +485,7 @@ class LTXVModel(torch.nn.Module):
         shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
         x = self.norm_out(x)
         # Modulation
-        x = x * (1 + scale) + shift
+        x = torch.addcmul(x, x, scale).add_(shift)
         x = self.proj_out(x)
 
         x = self.patchifier.unpatchify(
@@ -10,6 +10,7 @@ from comfy.ldm.modules.attention import optimized_attention_masked
 from comfy.ldm.flux.layers import EmbedND
 import comfy.ldm.common_dit
 import comfy.patcher_extension
+from comfy.ldm.flux.math import apply_rope1
 
 class GELU(nn.Module):
     def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None):
@@ -134,33 +135,34 @@ class Attention(nn.Module):
         image_rotary_emb: Optional[torch.Tensor] = None,
         transformer_options={},
     ) -> Tuple[torch.Tensor, torch.Tensor]:
+        batch_size = hidden_states.shape[0]
+        seq_img = hidden_states.shape[1]
         seq_txt = encoder_hidden_states.shape[1]
 
-        img_query = self.to_q(hidden_states).unflatten(-1, (self.heads, -1))
-        img_key = self.to_k(hidden_states).unflatten(-1, (self.heads, -1))
-        img_value = self.to_v(hidden_states).unflatten(-1, (self.heads, -1))
+        # Project and reshape to BHND format (batch, heads, seq, dim)
+        img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
+        img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
+        img_value = self.to_v(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2)
 
-        txt_query = self.add_q_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
-        txt_key = self.add_k_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
-        txt_value = self.add_v_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
+        txt_query = self.add_q_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
+        txt_key = self.add_k_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
+        txt_value = self.add_v_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2)
 
         img_query = self.norm_q(img_query)
         img_key = self.norm_k(img_key)
         txt_query = self.norm_added_q(txt_query)
         txt_key = self.norm_added_k(txt_key)
 
-        joint_query = torch.cat([txt_query, img_query], dim=1)
-        joint_key = torch.cat([txt_key, img_key], dim=1)
-        joint_value = torch.cat([txt_value, img_value], dim=1)
+        joint_query = torch.cat([txt_query, img_query], dim=2)
+        joint_key = torch.cat([txt_key, img_key], dim=2)
+        joint_value = torch.cat([txt_value, img_value], dim=2)
 
-        joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
-        joint_key = apply_rotary_emb(joint_key, image_rotary_emb)
+        joint_query = apply_rope1(joint_query, image_rotary_emb)
+        joint_key = apply_rope1(joint_key, image_rotary_emb)
 
-        joint_query = joint_query.flatten(start_dim=2)
-        joint_key = joint_key.flatten(start_dim=2)
-        joint_value = joint_value.flatten(start_dim=2)
-
-        joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask, transformer_options=transformer_options)
+        joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads,
+                                                         attention_mask, transformer_options=transformer_options,
+                                                         skip_reshape=True)
 
         txt_attn_output = joint_hidden_states[:, :seq_txt, :]
         img_attn_output = joint_hidden_states[:, seq_txt:, :]
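Queries, keys and values are now reshaped directly into a (batch, heads, seq, head_dim) layout, the text/image concatenation therefore moves from dim=1 to dim=2, RoPE goes through the shared apply_rope1, and optimized_attention_masked is called with skip_reshape=True because the heads are already split out. A tiny layout check (shapes and variable names invented for the illustration):

    import torch

    batch_size, seq_img, heads, head_dim = 2, 6, 4, 8
    hidden = torch.randn(batch_size, seq_img, heads * head_dim)

    # Old path: unflatten keeps [batch, seq, heads, head_dim].
    old_q = hidden.unflatten(-1, (heads, -1))

    # New path: view + transpose gives [batch, heads, seq, head_dim],
    # which is what skip_reshape=True attention consumes.
    new_q = hidden.view(batch_size, seq_img, heads, -1).transpose(1, 2).contiguous()

    assert torch.equal(old_q.permute(0, 2, 1, 3), new_q)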
@@ -413,7 +415,7 @@ class QwenImageTransformer2DModel(nn.Module):
         txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
         txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
         ids = torch.cat((txt_ids, img_ids), dim=1)
-        image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
+        image_rotary_emb = self.pe_embedder(ids).to(torch.float32).contiguous()
         del ids, txt_ids, img_ids
 
         hidden_states = self.img_in(hidden_states)
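image_rotary_emb is now kept as the raw pe_embedder output, cast to float32 and made contiguous, rather than being squeezed/unsqueezed and cast to the model dtype; presumably apply_rope1 performs the rotation in the pe dtype and casts the result back, as the inline Flux code removed earlier did with type_as.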
@@ -232,6 +232,7 @@ class WanAttentionBlock(nn.Module):
         # assert e[0].dtype == torch.float32
 
         # self-attention
+        x = x.contiguous() # otherwise implicit in LayerNorm
         y = self.self_attn(
             torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
             freqs, transformer_options=transformer_options)
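The Wan block only gains an explicit x = x.contiguous() before self-attention; per the inline comment, LayerNorm would otherwise make the tensor contiguous implicitly, so the copy is now done once and explicitly.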