Fix image encoder attention mask type

So it works with basic attention
This commit is contained in:
kijai 2026-04-13 23:29:27 +03:00
parent c857b6c657
commit e0cccbd4c9

View File

@ -687,7 +687,11 @@ class Gemma4VisionEncoder(nn.Module):
x = self.patch_embedder(patches, position_ids)
freqs = _compute_vision_2d_rope(self.head_dim, position_ids, device=pixel_values.device)
freqs = tuple(t.to(x.dtype) for t in freqs)
mask = (~padding).unsqueeze(1).unsqueeze(2).expand(-1, 1, position_ids.shape[1], -1) if n_padding > 0 else None
if n_padding > 0:
mask = padding.unsqueeze(1).unsqueeze(2).expand(-1, 1, position_ids.shape[1], -1)
mask = torch.zeros_like(mask, dtype=x.dtype).masked_fill_(mask, torch.finfo(x.dtype).min)
else:
mask = None
for layer in self.encoder.layers:
x = layer(x, freqs, attention_mask=mask)