mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-23 04:50:49 +08:00
added attention_mask to QwenImageTransformerBlock.forward()
This commit is contained in:
parent
16adfe2153
commit
9792606847
@ -159,27 +159,21 @@ class Attention(nn.Module):
|
||||
joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
|
||||
joint_key = apply_rotary_emb(joint_key, image_rotary_emb)
|
||||
|
||||
# Apply EliGen attention mask if present
|
||||
effective_mask = attention_mask
|
||||
if transformer_options is not None:
|
||||
eligen_mask = transformer_options.get("eligen_attention_mask", None)
|
||||
if eligen_mask is not None:
|
||||
effective_mask = eligen_mask
|
||||
|
||||
# Validate shape
|
||||
expected_seq = joint_query.shape[1]
|
||||
if eligen_mask.shape[-1] != expected_seq:
|
||||
raise ValueError(
|
||||
f"EliGen attention mask shape mismatch: {eligen_mask.shape} "
|
||||
f"doesn't match sequence length {expected_seq}"
|
||||
)
|
||||
# Validate attention mask shape if provided
|
||||
if attention_mask is not None:
|
||||
expected_seq = joint_query.shape[1]
|
||||
if attention_mask.shape[-1] != expected_seq:
|
||||
raise ValueError(
|
||||
f"Attention mask shape mismatch: {attention_mask.shape} "
|
||||
f"doesn't match sequence length {expected_seq}"
|
||||
)
|
||||
|
||||
# Use ComfyUI's optimized attention
|
||||
joint_query = joint_query.flatten(start_dim=2)
|
||||
joint_key = joint_key.flatten(start_dim=2)
|
||||
joint_value = joint_value.flatten(start_dim=2)
|
||||
|
||||
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, effective_mask, transformer_options=transformer_options)
|
||||
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask, transformer_options=transformer_options)
|
||||
|
||||
txt_attn_output = joint_hidden_states[:, :seq_txt, :]
|
||||
img_attn_output = joint_hidden_states[:, seq_txt:, :]
|
||||
@ -246,6 +240,7 @@ class QwenImageTransformerBlock(nn.Module):
|
||||
encoder_hidden_states_mask: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
transformer_options={},
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
img_mod_params = self.img_mod(temb)
|
||||
@ -262,6 +257,7 @@ class QwenImageTransformerBlock(nn.Module):
|
||||
hidden_states=img_modulated,
|
||||
encoder_hidden_states=txt_modulated,
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
attention_mask=attention_mask,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
@ -640,6 +636,9 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
hidden_states = torch.cat([hidden_states, kontext], dim=1)
|
||||
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
||||
|
||||
# Initialize attention mask (None for standard generation)
|
||||
eligen_attention_mask = None
|
||||
|
||||
# Extract EliGen entity data
|
||||
entity_prompt_emb = kwargs.get("entity_prompt_emb", None)
|
||||
entity_prompt_emb_mask = kwargs.get("entity_prompt_emb_mask", None)
|
||||
@ -659,8 +658,8 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
|
||||
if entity_prompt_emb is not None and entity_masks is not None and entity_prompt_emb_mask is not None and is_positive_cond:
|
||||
# EliGen path
|
||||
height = int(orig_shape[-2] * 8)
|
||||
width = int(orig_shape[-1] * 8)
|
||||
height = int(orig_shape[-2] * self.LATENT_TO_PIXEL_RATIO)
|
||||
width = int(orig_shape[-1] * self.LATENT_TO_PIXEL_RATIO)
|
||||
|
||||
encoder_hidden_states, image_rotary_emb, eligen_attention_mask = self.process_entity_masks(
|
||||
latents=x,
|
||||
@ -678,10 +677,6 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
|
||||
hidden_states = self.img_in(hidden_states)
|
||||
|
||||
if transformer_options is None:
|
||||
transformer_options = {}
|
||||
transformer_options["eligen_attention_mask"] = eligen_attention_mask
|
||||
|
||||
del img_ids
|
||||
|
||||
else:
|
||||
@ -713,9 +708,25 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"])
|
||||
out["txt"], out["img"] = block(
|
||||
hidden_states=args["img"],
|
||||
encoder_hidden_states=args["txt"],
|
||||
encoder_hidden_states_mask=args.get("encoder_hidden_states_mask"),
|
||||
temb=args["vec"],
|
||||
image_rotary_emb=args["pe"],
|
||||
attention_mask=args.get("attention_mask"),
|
||||
transformer_options=args["transformer_options"]
|
||||
)
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({
|
||||
"img": hidden_states,
|
||||
"txt": encoder_hidden_states,
|
||||
"encoder_hidden_states_mask": encoder_hidden_states_mask,
|
||||
"attention_mask": eligen_attention_mask,
|
||||
"vec": temb,
|
||||
"pe": image_rotary_emb,
|
||||
"transformer_options": transformer_options
|
||||
}, {"original_block": block_wrap})
|
||||
hidden_states = out["img"]
|
||||
encoder_hidden_states = out["txt"]
|
||||
else:
|
||||
@ -725,6 +736,7 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
attention_mask=eligen_attention_mask,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user