added attention_mask to QwenImageTransformerBlock.forward()
commit 9792606847
parent 16adfe2153
@@ -159,27 +159,21 @@ class Attention(nn.Module):
         joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
         joint_key = apply_rotary_emb(joint_key, image_rotary_emb)

-        # Apply EliGen attention mask if present
-        effective_mask = attention_mask
-        if transformer_options is not None:
-            eligen_mask = transformer_options.get("eligen_attention_mask", None)
-            if eligen_mask is not None:
-                effective_mask = eligen_mask
-
-                # Validate shape
-                expected_seq = joint_query.shape[1]
-                if eligen_mask.shape[-1] != expected_seq:
-                    raise ValueError(
-                        f"EliGen attention mask shape mismatch: {eligen_mask.shape} "
-                        f"doesn't match sequence length {expected_seq}"
-                    )
+        # Validate attention mask shape if provided
+        if attention_mask is not None:
+            expected_seq = joint_query.shape[1]
+            if attention_mask.shape[-1] != expected_seq:
+                raise ValueError(
+                    f"Attention mask shape mismatch: {attention_mask.shape} "
+                    f"doesn't match sequence length {expected_seq}"
+                )

         # Use ComfyUI's optimized attention
         joint_query = joint_query.flatten(start_dim=2)
         joint_key = joint_key.flatten(start_dim=2)
         joint_value = joint_value.flatten(start_dim=2)

-        joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, effective_mask, transformer_options=transformer_options)
+        joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask, transformer_options=transformer_options)

         txt_attn_output = joint_hidden_states[:, :seq_txt, :]
         img_attn_output = joint_hidden_states[:, seq_txt:, :]
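Note: after this hunk, the mask handed to optimized_attention_masked is whatever the caller passed in, and the guard only checks that its trailing dimension matches the joint text+image sequence length. A minimal standalone sketch of that contract (illustrative shapes, assuming the usual additive-mask convention of 0.0 for allowed and -inf for blocked positions):

    import torch

    batch, seq_txt, seq_img = 1, 20, 4096
    seq = seq_txt + seq_img  # joint sequence: text tokens then image tokens

    # Additive mask broadcastable against (batch, heads, seq, seq)
    mask = torch.zeros(batch, 1, seq, seq)
    mask[:, :, seq_txt:, :seq_txt] = float("-inf")  # e.g. block image->text

    assert mask.shape[-1] == seq  # the condition the new ValueError enforces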
@@ -246,6 +240,7 @@ class QwenImageTransformerBlock(nn.Module):
         encoder_hidden_states_mask: torch.Tensor,
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         transformer_options={},
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         img_mod_params = self.img_mod(temb)
@@ -262,6 +257,7 @@ class QwenImageTransformerBlock(nn.Module):
             hidden_states=img_modulated,
             encoder_hidden_states=txt_modulated,
             encoder_hidden_states_mask=encoder_hidden_states_mask,
+            attention_mask=attention_mask,
             image_rotary_emb=image_rotary_emb,
             transformer_options=transformer_options,
         )
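Note: attention_mask defaults to None, so existing callers are unaffected. A hypothetical call through the extended signature (tensor names and shapes are illustrative, not from the commit):

    # The block returns (txt, img), matching the block_wrap unpacking below.
    encoder_hidden_states, hidden_states = block(
        hidden_states=img_tokens,              # (batch, seq_img, dim)
        encoder_hidden_states=txt_tokens,      # (batch, seq_txt, dim)
        encoder_hidden_states_mask=txt_mask,   # (batch, seq_txt)
        temb=temb,                             # timestep conditioning
        image_rotary_emb=rope,                 # precomputed rotary embeddings
        attention_mask=eligen_attention_mask,  # or None for standard runs
        transformer_options={},
    )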
@@ -640,6 +636,9 @@ class QwenImageTransformer2DModel(nn.Module):
             hidden_states = torch.cat([hidden_states, kontext], dim=1)
             img_ids = torch.cat([img_ids, kontext_ids], dim=1)

+        # Initialize attention mask (None for standard generation)
+        eligen_attention_mask = None
+
         # Extract EliGen entity data
         entity_prompt_emb = kwargs.get("entity_prompt_emb", None)
         entity_prompt_emb_mask = kwargs.get("entity_prompt_emb_mask", None)
@@ -659,8 +658,8 @@ class QwenImageTransformer2DModel(nn.Module):

         if entity_prompt_emb is not None and entity_masks is not None and entity_prompt_emb_mask is not None and is_positive_cond:
             # EliGen path
-            height = int(orig_shape[-2] * 8)
-            width = int(orig_shape[-1] * 8)
+            height = int(orig_shape[-2] * self.LATENT_TO_PIXEL_RATIO)
+            width = int(orig_shape[-1] * self.LATENT_TO_PIXEL_RATIO)

             encoder_hidden_states, image_rotary_emb, eligen_attention_mask = self.process_entity_masks(
                 latents=x,
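Note: the literal 8 being replaced here is the VAE's spatial downscale factor (one latent cell per 8x8 pixel block). The new code assumes a class-level constant that the diff itself does not show; presumably something along these lines:

    class QwenImageTransformer2DModel(nn.Module):
        # Assumed definition; the commit only shows the usage sites.
        LATENT_TO_PIXEL_RATIO = 8  # pixels per latent unit on each axis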
@@ -678,10 +677,6 @@ class QwenImageTransformer2DModel(nn.Module):

             hidden_states = self.img_in(hidden_states)

-            if transformer_options is None:
-                transformer_options = {}
-            transformer_options["eligen_attention_mask"] = eligen_attention_mask
-
             del img_ids

         else:
@@ -713,9 +708,25 @@ class QwenImageTransformer2DModel(nn.Module):
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
-                    out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"])
+                    out["txt"], out["img"] = block(
+                        hidden_states=args["img"],
+                        encoder_hidden_states=args["txt"],
+                        encoder_hidden_states_mask=args.get("encoder_hidden_states_mask"),
+                        temb=args["vec"],
+                        image_rotary_emb=args["pe"],
+                        attention_mask=args.get("attention_mask"),
+                        transformer_options=args["transformer_options"]
+                    )
                     return out
-                out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
+                out = blocks_replace[("double_block", i)]({
+                    "img": hidden_states,
+                    "txt": encoder_hidden_states,
+                    "encoder_hidden_states_mask": encoder_hidden_states_mask,
+                    "attention_mask": eligen_attention_mask,
+                    "vec": temb,
+                    "pe": image_rotary_emb,
+                    "transformer_options": transformer_options
+                }, {"original_block": block_wrap})
                 hidden_states = out["img"]
                 encoder_hidden_states = out["txt"]
             else:
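Note: since the args dict handed to a ("double_block", i) replacement now carries "encoder_hidden_states_mask" and "attention_mask", a patch can inspect or rewrite the mask before delegating to the original block. A minimal sketch of such a patch (hypothetical, not part of the commit):

    def my_double_block_patch(args, extra):
        mask = args.get("attention_mask")  # EliGen mask, or None
        # ... optionally inspect or modify args here ...
        return extra["original_block"](args)  # delegate to block_wrap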
@@ -725,6 +736,7 @@ class QwenImageTransformer2DModel(nn.Module):
                     encoder_hidden_states_mask=encoder_hidden_states_mask,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
+                    attention_mask=eligen_attention_mask,
                     transformer_options=transformer_options,
                 )

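Note: taken together, the commit threads the EliGen mask through explicit arguments instead of the transformer_options side channel. Schematically:

    # process_entity_masks() -> eligen_attention_mask
    #   -> QwenImageTransformerBlock.forward(..., attention_mask=...)
    #     -> Attention.forward(..., attention_mask=...)
    #       -> optimized_attention_masked(q, k, v, heads, attention_mask)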