mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 13:19:23 +08:00
Vectorize repetition/presence penalty in BaseGenerate.sample_token
The per-token sampling penalties were applied with a nested Python loop over set(token_history) for each batch row. That loop grows with the generated sequence length and indexes the logits tensor with scalars, forcing a GPU->CPU sync on every decode step. Replace it with a single gather/scatter over the unique history tokens. The per-element arithmetic is unchanged, so the sampled logits are bit-for-bit identical, while the work runs entirely on-device and no longer scales with history length.
This commit is contained in:
parent
7cb784e0f4
commit
2bb8d10e78
@ -937,15 +937,21 @@ class BaseGenerate:
|
||||
return torch.argmax(logits, dim=-1, keepdim=True)
|
||||
|
||||
# Sampling mode
|
||||
if repetition_penalty != 1.0:
|
||||
for i in range(logits.shape[0]):
|
||||
for token_id in set(token_history):
|
||||
logits[i, token_id] *= repetition_penalty if logits[i, token_id] < 0 else 1/repetition_penalty
|
||||
|
||||
if presence_penalty is not None and presence_penalty != 0.0:
|
||||
for i in range(logits.shape[0]):
|
||||
for token_id in set(token_history):
|
||||
logits[i, token_id] -= presence_penalty
|
||||
apply_repetition_penalty = repetition_penalty != 1.0
|
||||
apply_presence_penalty = presence_penalty is not None and presence_penalty != 0.0
|
||||
if (apply_repetition_penalty or apply_presence_penalty) and token_history:
|
||||
# Vectorized equivalent of looping over set(token_history) for every batch row.
|
||||
# The original nested Python loop scales as O(len(history)) per generated token and
|
||||
# indexes the logits tensor with scalars, which forces a GPU->CPU sync each step.
|
||||
# Gathering the affected columns once and scattering them back keeps the per-element
|
||||
# arithmetic identical while running entirely on-device.
|
||||
unique_tokens = torch.as_tensor(sorted(set(token_history)), device=logits.device, dtype=torch.long)
|
||||
penalized = logits.index_select(1, unique_tokens)
|
||||
if apply_repetition_penalty:
|
||||
penalized = torch.where(penalized < 0, penalized * repetition_penalty, penalized * (1.0 / repetition_penalty))
|
||||
if apply_presence_penalty:
|
||||
penalized = penalized - presence_penalty
|
||||
logits.index_copy_(1, unique_tokens, penalized)
|
||||
|
||||
if temperature != 1.0:
|
||||
logits = logits / temperature
|
||||
|
||||
Loading…
Reference in New Issue
Block a user