This commit is contained in:
mathbbN 2026-07-02 10:50:42 -04:00 committed by GitHub
commit 44252e994e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -937,15 +937,21 @@ class BaseGenerate:
return torch.argmax(logits, dim=-1, keepdim=True) return torch.argmax(logits, dim=-1, keepdim=True)
# Sampling mode # Sampling mode
if repetition_penalty != 1.0: apply_repetition_penalty = repetition_penalty != 1.0
for i in range(logits.shape[0]): apply_presence_penalty = presence_penalty is not None and presence_penalty != 0.0
for token_id in set(token_history): if (apply_repetition_penalty or apply_presence_penalty) and token_history:
logits[i, token_id] *= repetition_penalty if logits[i, token_id] < 0 else 1/repetition_penalty # Vectorized equivalent of looping over set(token_history) for every batch row.
# The original nested Python loop scales as O(len(history)) per generated token and
if presence_penalty is not None and presence_penalty != 0.0: # indexes the logits tensor with scalars, which forces a GPU->CPU sync each step.
for i in range(logits.shape[0]): # Gathering the affected columns once and scattering them back keeps the per-element
for token_id in set(token_history): # arithmetic identical while running entirely on-device.
logits[i, token_id] -= presence_penalty unique_tokens = torch.as_tensor(sorted(set(token_history)), device=logits.device, dtype=torch.long)
penalized = logits.index_select(1, unique_tokens)
if apply_repetition_penalty:
penalized = torch.where(penalized < 0, penalized * repetition_penalty, penalized * (1.0 / repetition_penalty))
if apply_presence_penalty:
penalized = penalized - presence_penalty
logits.index_copy_(1, unique_tokens, penalized)
if temperature != 1.0: if temperature != 1.0:
logits = logits / temperature logits = logits / temperature