diff --git a/comfy/text_encoders/gemma4.py b/comfy/text_encoders/gemma4.py index 78ad81741..68d67ef05 100644 --- a/comfy/text_encoders/gemma4.py +++ b/comfy/text_encoders/gemma4.py @@ -475,30 +475,17 @@ def _compute_vision_2d_rope(head_dim, pixel_position_ids, theta=100.0, device=No def _apply_vision_2d_rope(x, freqs): """Apply 2D RoPE (multidimensional) to vision query/key states. - Splits x and cos/sin into ndim=2 parts, applies rotate_half RoPE to each independently. + Splits x and cos/sin into ndim=2 parts, applies 1D RoPE to each independently. x: [batch, heads, seq, head_dim] freqs: (cos, sin) each [batch, seq, head_dim] """ cos = freqs[0].unsqueeze(1) # [batch, 1, seq, head_dim] sin = freqs[1].unsqueeze(1) - - def rotate_half(t): - t1 = t[..., :t.shape[-1]//2] - t2 = t[..., t.shape[-1]//2:] - return torch.cat((-t2, t1), dim=-1) - - # Split into 2 parts (y and x dimensions) half = x.shape[-1] // 2 - x_parts = [x[..., :half], x[..., half:]] - cos_parts = [cos[..., :half], cos[..., half:]] - sin_parts = [sin[..., :half], sin[..., half:]] - - rotated_parts = [] - for xp, cp, sp in zip(x_parts, cos_parts, sin_parts): - rotated_parts.append((xp * cp) + (rotate_half(xp) * sp)) - - return torch.cat(rotated_parts, dim=-1) + a = _apply_rotary_pos_emb(x[..., :half], (cos[..., :half], sin[..., :half])) + b = _apply_rotary_pos_emb(x[..., half:], (cos[..., half:], sin[..., half:])) + return torch.cat([a, b], dim=-1) class ClippedLinear(nn.Module): @@ -622,10 +609,8 @@ class Gemma4PatchEmbedder(nn.Module): hidden_states = self.input_proj((2.0 * (patches - 0.5)).to(self.input_proj.weight.dtype)) clamped_positions = pixel_position_ids.clamp(min=0) - one_hot = torch.nn.functional.one_hot(clamped_positions, num_classes=self.position_embedding_size) pos_table = comfy.model_management.cast_to_device(self.position_embedding_table, hidden_states.device, hidden_states.dtype) - one_hot = one_hot.permute(0, 2, 1, 3).to(pos_table) - position_embeddings = (one_hot @ pos_table).sum(dim=1) + position_embeddings = pos_table[0][clamped_positions[..., 0]] + pos_table[1][clamped_positions[..., 1]] # Zero out position embeddings for padding patches (matching HF) padding_positions = (pixel_position_ids == -1).all(dim=-1)