mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-28 03:12:31 +08:00
Fix image encoder attention mask type
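Convert the boolean padding mask into an additive float mask in the activations' dtype (0 for valid positions, the dtype's most negative finite value for padded positions) so it works with basic attention.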
This commit is contained in:
parent
c857b6c657
commit
e0cccbd4c9
@ -687,7 +687,11 @@ class Gemma4VisionEncoder(nn.Module):
|
|||||||
x = self.patch_embedder(patches, position_ids)
|
x = self.patch_embedder(patches, position_ids)
|
||||||
freqs = _compute_vision_2d_rope(self.head_dim, position_ids, device=pixel_values.device)
|
freqs = _compute_vision_2d_rope(self.head_dim, position_ids, device=pixel_values.device)
|
||||||
freqs = tuple(t.to(x.dtype) for t in freqs)
|
freqs = tuple(t.to(x.dtype) for t in freqs)
|
||||||
mask = (~padding).unsqueeze(1).unsqueeze(2).expand(-1, 1, position_ids.shape[1], -1) if n_padding > 0 else None
|
if n_padding > 0:
|
||||||
|
mask = padding.unsqueeze(1).unsqueeze(2).expand(-1, 1, position_ids.shape[1], -1)
|
||||||
|
mask = torch.zeros_like(mask, dtype=x.dtype).masked_fill_(mask, torch.finfo(x.dtype).min)
|
||||||
|
else:
|
||||||
|
mask = None
|
||||||
|
|
||||||
for layer in self.encoder.layers:
|
for layer in self.encoder.layers:
|
||||||
x = layer(x, freqs, attention_mask=mask)
|
x = layer(x, freqs, attention_mask=mask)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user