mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-30 12:22:37 +08:00
Code cleanup
This commit is contained in:
parent
35987d7061
commit
80af032762
@ -475,30 +475,17 @@ def _compute_vision_2d_rope(head_dim, pixel_position_ids, theta=100.0, device=No
|
||||
def _apply_vision_2d_rope(x, freqs):
|
||||
"""Apply 2D RoPE (multidimensional) to vision query/key states.
|
||||
|
||||
Splits x and cos/sin into ndim=2 parts, applies rotate_half RoPE to each independently.
|
||||
Splits x and cos/sin into ndim=2 parts, applies 1D RoPE to each independently.
|
||||
|
||||
x: [batch, heads, seq, head_dim]
|
||||
freqs: (cos, sin) each [batch, seq, head_dim]
|
||||
"""
|
||||
cos = freqs[0].unsqueeze(1) # [batch, 1, seq, head_dim]
|
||||
sin = freqs[1].unsqueeze(1)
|
||||
|
||||
def rotate_half(t):
|
||||
t1 = t[..., :t.shape[-1]//2]
|
||||
t2 = t[..., t.shape[-1]//2:]
|
||||
return torch.cat((-t2, t1), dim=-1)
|
||||
|
||||
# Split into 2 parts (y and x dimensions)
|
||||
half = x.shape[-1] // 2
|
||||
x_parts = [x[..., :half], x[..., half:]]
|
||||
cos_parts = [cos[..., :half], cos[..., half:]]
|
||||
sin_parts = [sin[..., :half], sin[..., half:]]
|
||||
|
||||
rotated_parts = []
|
||||
for xp, cp, sp in zip(x_parts, cos_parts, sin_parts):
|
||||
rotated_parts.append((xp * cp) + (rotate_half(xp) * sp))
|
||||
|
||||
return torch.cat(rotated_parts, dim=-1)
|
||||
a = _apply_rotary_pos_emb(x[..., :half], (cos[..., :half], sin[..., :half]))
|
||||
b = _apply_rotary_pos_emb(x[..., half:], (cos[..., half:], sin[..., half:]))
|
||||
return torch.cat([a, b], dim=-1)
|
||||
|
||||
|
||||
class ClippedLinear(nn.Module):
|
||||
@ -622,10 +609,8 @@ class Gemma4PatchEmbedder(nn.Module):
|
||||
hidden_states = self.input_proj((2.0 * (patches - 0.5)).to(self.input_proj.weight.dtype))
|
||||
|
||||
clamped_positions = pixel_position_ids.clamp(min=0)
|
||||
one_hot = torch.nn.functional.one_hot(clamped_positions, num_classes=self.position_embedding_size)
|
||||
pos_table = comfy.model_management.cast_to_device(self.position_embedding_table, hidden_states.device, hidden_states.dtype)
|
||||
one_hot = one_hot.permute(0, 2, 1, 3).to(pos_table)
|
||||
position_embeddings = (one_hot @ pos_table).sum(dim=1)
|
||||
position_embeddings = pos_table[0][clamped_positions[..., 0]] + pos_table[1][clamped_positions[..., 1]]
|
||||
|
||||
# Zero out position embeddings for padding patches (matching HF)
|
||||
padding_positions = (pixel_position_ids == -1).all(dim=-1)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user