added attention_mask to QwenImageTransformerBlock.forward()

This commit is contained in:
nolan4 2025-11-04 19:55:24 -08:00
parent 16adfe2153
commit 9792606847

View File

@ -159,18 +159,12 @@ class Attention(nn.Module):
joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
joint_key = apply_rotary_emb(joint_key, image_rotary_emb)
# Apply EliGen attention mask if present
effective_mask = attention_mask
if transformer_options is not None:
eligen_mask = transformer_options.get("eligen_attention_mask", None)
if eligen_mask is not None:
effective_mask = eligen_mask
# Validate shape
# Validate attention mask shape if provided
if attention_mask is not None:
expected_seq = joint_query.shape[1]
if eligen_mask.shape[-1] != expected_seq:
if attention_mask.shape[-1] != expected_seq:
raise ValueError(
f"EliGen attention mask shape mismatch: {eligen_mask.shape} "
f"Attention mask shape mismatch: {attention_mask.shape} "
f"doesn't match sequence length {expected_seq}"
)
@ -179,7 +173,7 @@ class Attention(nn.Module):
joint_key = joint_key.flatten(start_dim=2)
joint_value = joint_value.flatten(start_dim=2)
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, effective_mask, transformer_options=transformer_options)
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask, transformer_options=transformer_options)
txt_attn_output = joint_hidden_states[:, :seq_txt, :]
img_attn_output = joint_hidden_states[:, seq_txt:, :]
@ -246,6 +240,7 @@ class QwenImageTransformerBlock(nn.Module):
encoder_hidden_states_mask: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
transformer_options={},
) -> Tuple[torch.Tensor, torch.Tensor]:
img_mod_params = self.img_mod(temb)
@ -262,6 +257,7 @@ class QwenImageTransformerBlock(nn.Module):
hidden_states=img_modulated,
encoder_hidden_states=txt_modulated,
encoder_hidden_states_mask=encoder_hidden_states_mask,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
transformer_options=transformer_options,
)
@ -640,6 +636,9 @@ class QwenImageTransformer2DModel(nn.Module):
hidden_states = torch.cat([hidden_states, kontext], dim=1)
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
# Initialize attention mask (None for standard generation)
eligen_attention_mask = None
# Extract EliGen entity data
entity_prompt_emb = kwargs.get("entity_prompt_emb", None)
entity_prompt_emb_mask = kwargs.get("entity_prompt_emb_mask", None)
@ -659,8 +658,8 @@ class QwenImageTransformer2DModel(nn.Module):
if entity_prompt_emb is not None and entity_masks is not None and entity_prompt_emb_mask is not None and is_positive_cond:
# EliGen path
height = int(orig_shape[-2] * 8)
width = int(orig_shape[-1] * 8)
height = int(orig_shape[-2] * self.LATENT_TO_PIXEL_RATIO)
width = int(orig_shape[-1] * self.LATENT_TO_PIXEL_RATIO)
encoder_hidden_states, image_rotary_emb, eligen_attention_mask = self.process_entity_masks(
latents=x,
@ -678,10 +677,6 @@ class QwenImageTransformer2DModel(nn.Module):
hidden_states = self.img_in(hidden_states)
if transformer_options is None:
transformer_options = {}
transformer_options["eligen_attention_mask"] = eligen_attention_mask
del img_ids
else:
@ -713,9 +708,25 @@ class QwenImageTransformer2DModel(nn.Module):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"])
out["txt"], out["img"] = block(
hidden_states=args["img"],
encoder_hidden_states=args["txt"],
encoder_hidden_states_mask=args.get("encoder_hidden_states_mask"),
temb=args["vec"],
image_rotary_emb=args["pe"],
attention_mask=args.get("attention_mask"),
transformer_options=args["transformer_options"]
)
return out
out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({
"img": hidden_states,
"txt": encoder_hidden_states,
"encoder_hidden_states_mask": encoder_hidden_states_mask,
"attention_mask": eligen_attention_mask,
"vec": temb,
"pe": image_rotary_emb,
"transformer_options": transformer_options
}, {"original_block": block_wrap})
hidden_states = out["img"]
encoder_hidden_states = out["txt"]
else:
@ -725,6 +736,7 @@ class QwenImageTransformer2DModel(nn.Module):
encoder_hidden_states_mask=encoder_hidden_states_mask,
temb=temb,
image_rotary_emb=image_rotary_emb,
attention_mask=eligen_attention_mask,
transformer_options=transformer_options,
)