From 4cd881866bad0cde70273cc123d725693c1f2759 Mon Sep 17 00:00:00 2001 From: contentis Date: Wed, 5 Nov 2025 02:10:11 +0100 Subject: [PATCH] Use single apply_rope function across models (#10547) --- comfy/ldm/flux/layers.py | 4 +- comfy/ldm/flux/math.py | 10 +--- comfy/ldm/lightricks/model.py | 88 ++++++++++++++--------------------- comfy/ldm/qwen_image/model.py | 36 +++++++------- comfy/ldm/wan/model.py | 1 + 5 files changed, 59 insertions(+), 80 deletions(-) diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index ef21b416b..a3eab0470 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -195,8 +195,8 @@ class DoubleStreamBlock(nn.Module): txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:] # calculate the img bloks - img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img) - img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img) + img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img) + img += apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img) # calculate the txt bloks txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt) diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py index 8deda0d4a..158420290 100644 --- a/comfy/ldm/flux/math.py +++ b/comfy/ldm/flux/math.py @@ -7,15 +7,7 @@ import comfy.model_management def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor: - q_shape = q.shape - k_shape = k.shape - - if pe is not None: - q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2) - k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2) - q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v) - k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v) - + q, k = apply_rope(q, k, pe) heads = q.shape[1] x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options) return x diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py index def365ba7..5bcba998b 100644 --- a/comfy/ldm/lightricks/model.py +++ b/comfy/ldm/lightricks/model.py @@ -3,12 +3,11 @@ from torch import nn import comfy.patcher_extension import comfy.ldm.modules.attention import comfy.ldm.common_dit -from einops import rearrange import math from typing import Dict, Optional, Tuple from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords - +from comfy.ldm.flux.math import apply_rope1 def get_timestep_embedding( timesteps: torch.Tensor, @@ -238,20 +237,6 @@ class FeedForward(nn.Module): return self.net(x) -def apply_rotary_emb(input_tensor, freqs_cis): #TODO: remove duplicate funcs and pick the best/fastest one - cos_freqs = freqs_cis[0] - sin_freqs = freqs_cis[1] - - t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2) - t1, t2 = t_dup.unbind(dim=-1) - t_dup = torch.stack((-t2, t1), dim=-1) - input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)") - - out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs - - return out - - class CrossAttention(nn.Module): def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None): super().__init__() @@ -281,8 +266,8 @@ class CrossAttention(nn.Module): k = self.k_norm(k) if pe is not None: - q = apply_rotary_emb(q, pe) - k = apply_rotary_emb(k, pe) + q = apply_rope1(q.unsqueeze(1), pe).squeeze(1) + k = apply_rope1(k.unsqueeze(1), pe).squeeze(1) if mask is None: out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options) @@ -306,12 +291,17 @@ class BasicTransformerBlock(nn.Module): def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}): shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2) - x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, transformer_options=transformer_options) * gate_msa + norm_x = comfy.ldm.common_dit.rms_norm(x) + attn1_input = torch.addcmul(norm_x, norm_x, scale_msa).add_(shift_msa) + attn1_result = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options) + x.addcmul_(attn1_result, gate_msa) x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options) - y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp - x += self.ff(y) * gate_mlp + norm_x = comfy.ldm.common_dit.rms_norm(x) + y = torch.addcmul(norm_x, norm_x, scale_mlp).add_(shift_mlp) + ff_result = self.ff(y) + x.addcmul_(ff_result, gate_mlp) return x @@ -327,41 +317,35 @@ def get_fractional_positions(indices_grid, max_pos): def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]): - dtype = torch.float32 #self.dtype + dtype = torch.float32 + device = indices_grid.device + # Get fractional positions and compute frequency indices fractional_positions = get_fractional_positions(indices_grid, max_pos) + indices = theta ** torch.linspace(0, 1, dim // 6, device=device, dtype=dtype) * math.pi / 2 - start = 1 - end = theta - device = fractional_positions.device + # Compute frequencies and apply cos/sin + freqs = (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2) + cos_vals = freqs.cos().repeat_interleave(2, dim=-1) + sin_vals = freqs.sin().repeat_interleave(2, dim=-1) - indices = theta ** ( - torch.linspace( - math.log(start, theta), - math.log(end, theta), - dim // 6, - device=device, - dtype=dtype, - ) - ) - indices = indices.to(dtype=dtype) - - indices = indices * math.pi / 2 - - freqs = ( - (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)) - .transpose(-1, -2) - .flatten(2) - ) - - cos_freq = freqs.cos().repeat_interleave(2, dim=-1) - sin_freq = freqs.sin().repeat_interleave(2, dim=-1) + # Pad if dim is not divisible by 6 if dim % 6 != 0: - cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6]) - sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6]) - cos_freq = torch.cat([cos_padding, cos_freq], dim=-1) - sin_freq = torch.cat([sin_padding, sin_freq], dim=-1) - return cos_freq.to(out_dtype), sin_freq.to(out_dtype) + padding_size = dim % 6 + cos_vals = torch.cat([torch.ones_like(cos_vals[:, :, :padding_size]), cos_vals], dim=-1) + sin_vals = torch.cat([torch.zeros_like(sin_vals[:, :, :padding_size]), sin_vals], dim=-1) + + # Reshape and extract one value per pair (since repeat_interleave duplicates each value) + cos_vals = cos_vals.reshape(*cos_vals.shape[:2], -1, 2)[..., 0] # [B, N, dim//2] + sin_vals = sin_vals.reshape(*sin_vals.shape[:2], -1, 2)[..., 0] # [B, N, dim//2] + + # Build rotation matrix [[cos, -sin], [sin, cos]] and add heads dimension + freqs_cis = torch.stack([ + torch.stack([cos_vals, -sin_vals], dim=-1), + torch.stack([sin_vals, cos_vals], dim=-1) + ], dim=-2).unsqueeze(1) # [B, 1, N, dim//2, 2, 2] + + return freqs_cis.to(out_dtype) class LTXVModel(torch.nn.Module): @@ -501,7 +485,7 @@ class LTXVModel(torch.nn.Module): shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] x = self.norm_out(x) # Modulation - x = x * (1 + scale) + shift + x = torch.addcmul(x, x, scale).add_(shift) x = self.proj_out(x) x = self.patchifier.unpatchify( diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index b9f60c2b7..81d3ee7c0 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -10,6 +10,7 @@ from comfy.ldm.modules.attention import optimized_attention_masked from comfy.ldm.flux.layers import EmbedND import comfy.ldm.common_dit import comfy.patcher_extension +from comfy.ldm.flux.math import apply_rope1 class GELU(nn.Module): def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None): @@ -134,33 +135,34 @@ class Attention(nn.Module): image_rotary_emb: Optional[torch.Tensor] = None, transformer_options={}, ) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size = hidden_states.shape[0] + seq_img = hidden_states.shape[1] seq_txt = encoder_hidden_states.shape[1] - img_query = self.to_q(hidden_states).unflatten(-1, (self.heads, -1)) - img_key = self.to_k(hidden_states).unflatten(-1, (self.heads, -1)) - img_value = self.to_v(hidden_states).unflatten(-1, (self.heads, -1)) + # Project and reshape to BHND format (batch, heads, seq, dim) + img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous() + img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous() + img_value = self.to_v(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2) - txt_query = self.add_q_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1)) - txt_key = self.add_k_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1)) - txt_value = self.add_v_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1)) + txt_query = self.add_q_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous() + txt_key = self.add_k_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous() + txt_value = self.add_v_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2) img_query = self.norm_q(img_query) img_key = self.norm_k(img_key) txt_query = self.norm_added_q(txt_query) txt_key = self.norm_added_k(txt_key) - joint_query = torch.cat([txt_query, img_query], dim=1) - joint_key = torch.cat([txt_key, img_key], dim=1) - joint_value = torch.cat([txt_value, img_value], dim=1) + joint_query = torch.cat([txt_query, img_query], dim=2) + joint_key = torch.cat([txt_key, img_key], dim=2) + joint_value = torch.cat([txt_value, img_value], dim=2) - joint_query = apply_rotary_emb(joint_query, image_rotary_emb) - joint_key = apply_rotary_emb(joint_key, image_rotary_emb) + joint_query = apply_rope1(joint_query, image_rotary_emb) + joint_key = apply_rope1(joint_key, image_rotary_emb) - joint_query = joint_query.flatten(start_dim=2) - joint_key = joint_key.flatten(start_dim=2) - joint_value = joint_value.flatten(start_dim=2) - - joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask, transformer_options=transformer_options) + joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, + attention_mask, transformer_options=transformer_options, + skip_reshape=True) txt_attn_output = joint_hidden_states[:, :seq_txt, :] img_attn_output = joint_hidden_states[:, seq_txt:, :] @@ -413,7 +415,7 @@ class QwenImageTransformer2DModel(nn.Module): txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2)) txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) ids = torch.cat((txt_ids, img_ids), dim=1) - image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype) + image_rotary_emb = self.pe_embedder(ids).to(torch.float32).contiguous() del ids, txt_ids, img_ids hidden_states = self.img_in(hidden_states) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 5ec1511ce..a9d5e10d9 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -232,6 +232,7 @@ class WanAttentionBlock(nn.Module): # assert e[0].dtype == torch.float32 # self-attention + x = x.contiguous() # otherwise implicit in LayerNorm y = self.self_attn( torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)), freqs, transformer_options=transformer_options)