From c5549de6f9d87a199c99481dbcb78957c6548035 Mon Sep 17 00:00:00 2001 From: Icbears Date: Thu, 19 Feb 2026 09:54:48 +0800 Subject: [PATCH] Add files via upload --- comfy/ldm/qwen_image/model.py | 83 +++++++++++++++++++++++++---------- 1 file changed, 61 insertions(+), 22 deletions(-) diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 6eb744286..3ff20c81d 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -12,6 +12,7 @@ import comfy.ldm.common_dit import comfy.patcher_extension from comfy.ldm.flux.math import apply_rope1 + class GELU(nn.Module): def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None): super().__init__() @@ -150,35 +151,73 @@ class Attention(nn.Module): seq_txt = encoder_hidden_states.shape[1] # Project and reshape to BHND format (batch, heads, seq, dim) - img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous() - img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous() - img_value = self.to_v(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2) + if hidden_states.dtype not in (torch.bfloat16,): + # fp16 optimization: IMG_K, IMG_V, TXT_Q, TXT_V use fp16; IMG_Q and TXT_K stay original dtype + img_h = hidden_states.half() + txt_h = encoder_hidden_states.half() - txt_query = self.add_q_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous() - txt_key = self.add_k_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous() - txt_value = self.add_v_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2) + img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous() + img_key = self.to_k(img_h).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous() + img_value = self.to_v(img_h).view(batch_size, seq_img, self.heads, -1).transpose(1, 2) - img_query = self.norm_q(img_query) - img_key = self.norm_k(img_key) - txt_query = self.norm_added_q(txt_query) - txt_key = self.norm_added_k(txt_key) + txt_query = self.add_q_proj(txt_h).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous() + txt_key = self.add_k_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous() + txt_value = self.add_v_proj(txt_h).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2) - joint_query = torch.cat([txt_query, img_query], dim=2) - joint_key = torch.cat([txt_key, img_key], dim=2) - joint_value = torch.cat([txt_value, img_value], dim=2) + del img_h, txt_h - joint_query = apply_rope1(joint_query, image_rotary_emb) - joint_key = apply_rope1(joint_key, image_rotary_emb) + img_query = self.norm_q(img_query) + img_key = self.norm_k(img_key) + txt_query = self.norm_added_q(txt_query) + txt_key = self.norm_added_k(txt_key) - if encoder_hidden_states_mask is not None: - attn_mask = torch.zeros((batch_size, 1, seq_txt + seq_img), dtype=hidden_states.dtype, device=hidden_states.device) - attn_mask[:, 0, :seq_txt] = encoder_hidden_states_mask + joint_query = torch.cat([txt_query, img_query], dim=2) + joint_key = torch.cat([txt_key, img_key], dim=2) + joint_value = torch.cat([txt_value, img_value], dim=2) + + joint_query = apply_rope1(joint_query, image_rotary_emb) + joint_key = apply_rope1(joint_key, image_rotary_emb) + + if encoder_hidden_states_mask is not None: + attn_mask = torch.zeros((batch_size, 1, seq_txt + seq_img), dtype=hidden_states.dtype, device=hidden_states.device) + attn_mask[:, 0, :seq_txt] = encoder_hidden_states_mask + else: + attn_mask = None + + joint_hidden_states = optimized_attention_masked(joint_query.half(), joint_key.half(), joint_value.half(), self.heads, + attn_mask.half() if attn_mask is not None else None, transformer_options=transformer_options, + skip_reshape=True).float() else: - attn_mask = None + # bfloat16: use original dtype throughout + img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous() + img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous() + img_value = self.to_v(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2) - joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, - attn_mask, transformer_options=transformer_options, - skip_reshape=True) + txt_query = self.add_q_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous() + txt_key = self.add_k_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous() + txt_value = self.add_v_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2) + + img_query = self.norm_q(img_query) + img_key = self.norm_k(img_key) + txt_query = self.norm_added_q(txt_query) + txt_key = self.norm_added_k(txt_key) + + joint_query = torch.cat([txt_query, img_query], dim=2) + joint_key = torch.cat([txt_key, img_key], dim=2) + joint_value = torch.cat([txt_value, img_value], dim=2) + + joint_query = apply_rope1(joint_query, image_rotary_emb) + joint_key = apply_rope1(joint_key, image_rotary_emb) + + if encoder_hidden_states_mask is not None: + attn_mask = torch.zeros((batch_size, 1, seq_txt + seq_img), dtype=hidden_states.dtype, device=hidden_states.device) + attn_mask[:, 0, :seq_txt] = encoder_hidden_states_mask + else: + attn_mask = None + + joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, + attn_mask, transformer_options=transformer_options, + skip_reshape=True) txt_attn_output = joint_hidden_states[:, :seq_txt, :] img_attn_output = joint_hidden_states[:, seq_txt:, :]