Add files via upload

This commit is contained in:
Icbears 2026-02-19 09:54:48 +08:00 committed by GitHub
parent 6d11cc7354
commit c5549de6f9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -12,6 +12,7 @@ import comfy.ldm.common_dit
import comfy.patcher_extension
from comfy.ldm.flux.math import apply_rope1
class GELU(nn.Module):
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None):
super().__init__()
@ -150,35 +151,73 @@ class Attention(nn.Module):
seq_txt = encoder_hidden_states.shape[1]
# Project and reshape to BHND format (batch, heads, seq, dim)
img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
img_value = self.to_v(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2)
if hidden_states.dtype not in (torch.bfloat16,):
# fp16 optimization: IMG_K, IMG_V, TXT_Q, TXT_V use fp16; IMG_Q and TXT_K stay original dtype
img_h = hidden_states.half()
txt_h = encoder_hidden_states.half()
txt_query = self.add_q_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
txt_key = self.add_k_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
txt_value = self.add_v_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2)
img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
img_key = self.to_k(img_h).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
img_value = self.to_v(img_h).view(batch_size, seq_img, self.heads, -1).transpose(1, 2)
img_query = self.norm_q(img_query)
img_key = self.norm_k(img_key)
txt_query = self.norm_added_q(txt_query)
txt_key = self.norm_added_k(txt_key)
txt_query = self.add_q_proj(txt_h).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
txt_key = self.add_k_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
txt_value = self.add_v_proj(txt_h).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2)
joint_query = torch.cat([txt_query, img_query], dim=2)
joint_key = torch.cat([txt_key, img_key], dim=2)
joint_value = torch.cat([txt_value, img_value], dim=2)
del img_h, txt_h
joint_query = apply_rope1(joint_query, image_rotary_emb)
joint_key = apply_rope1(joint_key, image_rotary_emb)
img_query = self.norm_q(img_query)
img_key = self.norm_k(img_key)
txt_query = self.norm_added_q(txt_query)
txt_key = self.norm_added_k(txt_key)
if encoder_hidden_states_mask is not None:
attn_mask = torch.zeros((batch_size, 1, seq_txt + seq_img), dtype=hidden_states.dtype, device=hidden_states.device)
attn_mask[:, 0, :seq_txt] = encoder_hidden_states_mask
joint_query = torch.cat([txt_query, img_query], dim=2)
joint_key = torch.cat([txt_key, img_key], dim=2)
joint_value = torch.cat([txt_value, img_value], dim=2)
joint_query = apply_rope1(joint_query, image_rotary_emb)
joint_key = apply_rope1(joint_key, image_rotary_emb)
if encoder_hidden_states_mask is not None:
attn_mask = torch.zeros((batch_size, 1, seq_txt + seq_img), dtype=hidden_states.dtype, device=hidden_states.device)
attn_mask[:, 0, :seq_txt] = encoder_hidden_states_mask
else:
attn_mask = None
joint_hidden_states = optimized_attention_masked(joint_query.half(), joint_key.half(), joint_value.half(), self.heads,
attn_mask.half() if attn_mask is not None else None, transformer_options=transformer_options,
skip_reshape=True).float()
else:
attn_mask = None
# bfloat16: use original dtype throughout
img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
img_value = self.to_v(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2)
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads,
attn_mask, transformer_options=transformer_options,
skip_reshape=True)
txt_query = self.add_q_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
txt_key = self.add_k_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
txt_value = self.add_v_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2)
img_query = self.norm_q(img_query)
img_key = self.norm_k(img_key)
txt_query = self.norm_added_q(txt_query)
txt_key = self.norm_added_k(txt_key)
joint_query = torch.cat([txt_query, img_query], dim=2)
joint_key = torch.cat([txt_key, img_key], dim=2)
joint_value = torch.cat([txt_value, img_value], dim=2)
joint_query = apply_rope1(joint_query, image_rotary_emb)
joint_key = apply_rope1(joint_key, image_rotary_emb)
if encoder_hidden_states_mask is not None:
attn_mask = torch.zeros((batch_size, 1, seq_txt + seq_img), dtype=hidden_states.dtype, device=hidden_states.device)
attn_mask[:, 0, :seq_txt] = encoder_hidden_states_mask
else:
attn_mask = None
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads,
attn_mask, transformer_options=transformer_options,
skip_reshape=True)
txt_attn_output = joint_hidden_states[:, :seq_txt, :]
img_attn_output = joint_hidden_states[:, seq_txt:, :]