mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-24 09:22:32 +08:00
llama: use a more efficient rope implementation (#12434)
Get rid of the cat and unary negation and inplace add-cmul the two halves of the rope. Precompute -sin once at the start of the model rather than every transformer block. This is slightly faster on both GPU and CPU bound setups.
This commit is contained in:
parent
117e214354
commit
ae79e33345
@ -355,13 +355,6 @@ class RMSNorm(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
def rotate_half(x):
|
|
||||||
"""Rotates half the hidden dims of the input."""
|
|
||||||
x1 = x[..., : x.shape[-1] // 2]
|
|
||||||
x2 = x[..., x.shape[-1] // 2 :]
|
|
||||||
return torch.cat((-x2, x1), dim=-1)
|
|
||||||
|
|
||||||
|
|
||||||
def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None):
|
def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None):
|
||||||
if not isinstance(theta, list):
|
if not isinstance(theta, list):
|
||||||
theta = [theta]
|
theta = [theta]
|
||||||
@ -390,20 +383,30 @@ def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_di
|
|||||||
else:
|
else:
|
||||||
cos = cos.unsqueeze(1)
|
cos = cos.unsqueeze(1)
|
||||||
sin = sin.unsqueeze(1)
|
sin = sin.unsqueeze(1)
|
||||||
out.append((cos, sin))
|
sin_split = sin.shape[-1] // 2
|
||||||
|
out.append((cos, sin[..., : sin_split], -sin[..., sin_split :]))
|
||||||
|
|
||||||
if len(out) == 1:
|
if len(out) == 1:
|
||||||
return out[0]
|
return out[0]
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def apply_rope(xq, xk, freqs_cis):
|
def apply_rope(xq, xk, freqs_cis):
|
||||||
org_dtype = xq.dtype
|
org_dtype = xq.dtype
|
||||||
cos = freqs_cis[0]
|
cos = freqs_cis[0]
|
||||||
sin = freqs_cis[1]
|
sin = freqs_cis[1]
|
||||||
q_embed = (xq * cos) + (rotate_half(xq) * sin)
|
nsin = freqs_cis[2]
|
||||||
k_embed = (xk * cos) + (rotate_half(xk) * sin)
|
|
||||||
|
q_embed = (xq * cos)
|
||||||
|
q_split = q_embed.shape[-1] // 2
|
||||||
|
q_embed[..., : q_split].addcmul_(xq[..., q_split :], nsin)
|
||||||
|
q_embed[..., q_split :].addcmul_(xq[..., : q_split], sin)
|
||||||
|
|
||||||
|
k_embed = (xk * cos)
|
||||||
|
k_split = k_embed.shape[-1] // 2
|
||||||
|
k_embed[..., : k_split].addcmul_(xk[..., k_split :], nsin)
|
||||||
|
k_embed[..., k_split :].addcmul_(xk[..., : k_split], sin)
|
||||||
|
|
||||||
return q_embed.to(org_dtype), k_embed.to(org_dtype)
|
return q_embed.to(org_dtype), k_embed.to(org_dtype)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user