diff --git a/comfy/text_encoders/gemma4.py b/comfy/text_encoders/gemma4.py index 8905f375f..7c3df9c09 100644 --- a/comfy/text_encoders/gemma4.py +++ b/comfy/text_encoders/gemma4.py @@ -687,7 +687,11 @@ class Gemma4VisionEncoder(nn.Module): x = self.patch_embedder(patches, position_ids) freqs = _compute_vision_2d_rope(self.head_dim, position_ids, device=pixel_values.device) freqs = tuple(t.to(x.dtype) for t in freqs) - mask = (~padding).unsqueeze(1).unsqueeze(2).expand(-1, 1, position_ids.shape[1], -1) if n_padding > 0 else None + if n_padding > 0: + mask = padding.unsqueeze(1).unsqueeze(2).expand(-1, 1, position_ids.shape[1], -1) + mask = torch.zeros_like(mask, dtype=x.dtype).masked_fill_(mask, torch.finfo(x.dtype).min) + else: + mask = None for layer in self.encoder.layers: x = layer(x, freqs, attention_mask=mask)