From 9792606847d7799ae98f2156c73e915e58a52c44 Mon Sep 17 00:00:00 2001
From: nolan4
Date: Tue, 4 Nov 2025 19:55:24 -0800
Subject: [PATCH] added attention_mask to QwenImageTransformerBlock.forward()

---
 comfy/ldm/qwen_image/model.py | 58 +++++++++++++++++++++--------------
 1 file changed, 35 insertions(+), 23 deletions(-)

diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py
index 461dde58f..76ad3646e 100644
--- a/comfy/ldm/qwen_image/model.py
+++ b/comfy/ldm/qwen_image/model.py
@@ -159,27 +159,21 @@ class Attention(nn.Module):
         joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
         joint_key = apply_rotary_emb(joint_key, image_rotary_emb)
 
-        # Apply EliGen attention mask if present
-        effective_mask = attention_mask
-        if transformer_options is not None:
-            eligen_mask = transformer_options.get("eligen_attention_mask", None)
-            if eligen_mask is not None:
-                effective_mask = eligen_mask
-
-                # Validate shape
-                expected_seq = joint_query.shape[1]
-                if eligen_mask.shape[-1] != expected_seq:
-                    raise ValueError(
-                        f"EliGen attention mask shape mismatch: {eligen_mask.shape} "
-                        f"doesn't match sequence length {expected_seq}"
-                    )
+        # Validate attention mask shape if provided
+        if attention_mask is not None:
+            expected_seq = joint_query.shape[1]
+            if attention_mask.shape[-1] != expected_seq:
+                raise ValueError(
+                    f"Attention mask shape mismatch: {attention_mask.shape} "
+                    f"doesn't match sequence length {expected_seq}"
+                )
 
         # Use ComfyUI's optimized attention
         joint_query = joint_query.flatten(start_dim=2)
         joint_key = joint_key.flatten(start_dim=2)
         joint_value = joint_value.flatten(start_dim=2)
 
-        joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, effective_mask, transformer_options=transformer_options)
+        joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask, transformer_options=transformer_options)
 
         txt_attn_output = joint_hidden_states[:, :seq_txt, :]
         img_attn_output = joint_hidden_states[:, seq_txt:, :]
@@ -246,6 +240,7 @@ class QwenImageTransformerBlock(nn.Module):
         encoder_hidden_states_mask: torch.Tensor,
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         transformer_options={},
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         img_mod_params = self.img_mod(temb)
@@ -262,6 +257,7 @@ class QwenImageTransformerBlock(nn.Module):
             hidden_states=img_modulated,
             encoder_hidden_states=txt_modulated,
             encoder_hidden_states_mask=encoder_hidden_states_mask,
+            attention_mask=attention_mask,
             image_rotary_emb=image_rotary_emb,
             transformer_options=transformer_options,
         )
@@ -640,6 +636,9 @@ class QwenImageTransformer2DModel(nn.Module):
             hidden_states = torch.cat([hidden_states, kontext], dim=1)
             img_ids = torch.cat([img_ids, kontext_ids], dim=1)
 
+        # Initialize attention mask (None for standard generation)
+        eligen_attention_mask = None
+
         # Extract EliGen entity data
         entity_prompt_emb = kwargs.get("entity_prompt_emb", None)
         entity_prompt_emb_mask = kwargs.get("entity_prompt_emb_mask", None)
@@ -659,8 +658,8 @@ class QwenImageTransformer2DModel(nn.Module):
         if entity_prompt_emb is not None and entity_masks is not None and entity_prompt_emb_mask is not None and is_positive_cond:
             # EliGen path
-            height = int(orig_shape[-2] * 8)
-            width = int(orig_shape[-1] * 8)
+            height = int(orig_shape[-2] * self.LATENT_TO_PIXEL_RATIO)
+            width = int(orig_shape[-1] * self.LATENT_TO_PIXEL_RATIO)
             encoder_hidden_states, image_rotary_emb, eligen_attention_mask = self.process_entity_masks(
                 latents=x,
@@ -678,10 +677,6 @@ class QwenImageTransformer2DModel(nn.Module):
 
             hidden_states = self.img_in(hidden_states)
 
-            if transformer_options is None:
-                transformer_options = {}
-            transformer_options["eligen_attention_mask"] = eligen_attention_mask
-
             del img_ids
         else:
@@ -713,9 +708,25 @@ class QwenImageTransformer2DModel(nn.Module):
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
-                    out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"])
+                    out["txt"], out["img"] = block(
+                        hidden_states=args["img"],
+                        encoder_hidden_states=args["txt"],
+                        encoder_hidden_states_mask=args.get("encoder_hidden_states_mask"),
+                        temb=args["vec"],
+                        image_rotary_emb=args["pe"],
+                        attention_mask=args.get("attention_mask"),
+                        transformer_options=args["transformer_options"]
+                    )
                     return out
-                out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
+                out = blocks_replace[("double_block", i)]({
+                    "img": hidden_states,
+                    "txt": encoder_hidden_states,
+                    "encoder_hidden_states_mask": encoder_hidden_states_mask,
+                    "attention_mask": eligen_attention_mask,
+                    "vec": temb,
+                    "pe": image_rotary_emb,
+                    "transformer_options": transformer_options
+                }, {"original_block": block_wrap})
                 hidden_states = out["img"]
                 encoder_hidden_states = out["txt"]
             else:
@@ -725,6 +736,7 @@ class QwenImageTransformer2DModel(nn.Module):
                     encoder_hidden_states_mask=encoder_hidden_states_mask,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
+                    attention_mask=eligen_attention_mask,
                     transformer_options=transformer_options,
                 )
 
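---

Note: a minimal usage sketch of the new keyword argument, not part of the
patch. The sizes below are hypothetical, as is the boolean keep-mask
convention (True = position may be attended to); the one requirement the new
validation in Attention enforces is that the mask's last dimension equals the
joint (text + image) sequence length seen by the block:

    import torch

    batch, txt_len, img_len, dim = 1, 77, 1024, 3072  # hypothetical sizes
    joint_len = txt_len + img_len

    # Full-attention mask over the joint sequence; a real EliGen mask would
    # zero out cross-entity positions instead of keeping everything.
    attention_mask = torch.ones(batch, joint_len, joint_len, dtype=torch.bool)

    # block is a QwenImageTransformerBlock; it returns (txt, img), matching
    # the out["txt"], out["img"] unpacking in the wrapped path above.
    encoder_hidden_states, hidden_states = block(
        hidden_states=hidden_states,                  # (batch, img_len, dim)
        encoder_hidden_states=encoder_hidden_states,  # (batch, txt_len, dim)
        encoder_hidden_states_mask=encoder_hidden_states_mask,
        temb=temb,
        image_rotary_emb=image_rotary_emb,
        attention_mask=attention_mask,  # or None for standard generation
        transformer_options={},
    )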