Code cleanup

This commit is contained in:
kijai 2026-04-21 16:44:27 +03:00
parent 35987d7061
commit 80af032762

View File

@@ -475,30 +475,17 @@ def _compute_vision_2d_rope(head_dim, pixel_position_ids, theta=100.0, device=No
def _apply_vision_2d_rope(x, freqs):
    """Apply 2D RoPE (multidimensional) to vision query/key states.

    Splits x and cos/sin along the last dimension into ndim=2 parts (the y and
    x dimensions), applies 1D RoPE to each part independently, and
    concatenates the two rotated parts back along the last dimension.

    x: [batch, heads, seq, head_dim]
    freqs: (cos, sin) each [batch, seq, head_dim]

    Returns the concatenation of the two rotated halves along the last
    dimension.

    NOTE(review): the per-half rotation is delegated to `_apply_rotary_pos_emb`,
    which is defined elsewhere in this file. It is presumably the same
    rotate-half formulation `(t * cos) + (rotate_half(t) * sin)` that used to
    be inlined here -- confirm against the helper.
    """
    # Insert a singleton heads axis so cos/sin broadcast against x.
    cos = freqs[0].unsqueeze(1)  # [batch, 1, seq, head_dim]
    sin = freqs[1].unsqueeze(1)

    # First half of the last dimension is one spatial axis, second half the other.
    half = x.shape[-1] // 2
    a = _apply_rotary_pos_emb(x[..., :half], (cos[..., :half], sin[..., :half]))
    b = _apply_rotary_pos_emb(x[..., half:], (cos[..., half:], sin[..., half:]))
    return torch.cat([a, b], dim=-1)
class ClippedLinear(nn.Module):
@@ -622,10 +609,8 @@ class Gemma4PatchEmbedder(nn.Module):
hidden_states = self.input_proj((2.0 * (patches - 0.5)).to(self.input_proj.weight.dtype))
clamped_positions = pixel_position_ids.clamp(min=0)
one_hot = torch.nn.functional.one_hot(clamped_positions, num_classes=self.position_embedding_size)
pos_table = comfy.model_management.cast_to_device(self.position_embedding_table, hidden_states.device, hidden_states.dtype)
one_hot = one_hot.permute(0, 2, 1, 3).to(pos_table)
position_embeddings = (one_hot @ pos_table).sum(dim=1)
position_embeddings = pos_table[0][clamped_positions[..., 0]] + pos_table[1][clamped_positions[..., 1]]
# Zero out position embeddings for padding patches (matching HF)
padding_positions = (pixel_position_ids == -1).all(dim=-1)