From 2bb8d10e781f438adbe23b9a9a165b052408898b Mon Sep 17 00:00:00 2001 From: nat-chan Date: Fri, 26 Jun 2026 18:43:13 +0900 Subject: [PATCH] Vectorize repetition/presence penalty in BaseGenerate.sample_token The per-token sampling penalties were applied with a nested Python loop over set(token_history) for each batch row. That loop grows with the generated sequence length and indexes the logits tensor with scalars, forcing a GPU->CPU sync on every decode step. Replace it with a single gather/scatter over the unique history tokens. The per-element arithmetic is unchanged, so the sampled logits are bit-for-bit identical, while the work runs entirely on-device and no longer scales with history length. --- comfy/text_encoders/llama.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index e9f38a9a2..1fe37db7c 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -937,15 +937,21 @@ class BaseGenerate: return torch.argmax(logits, dim=-1, keepdim=True) # Sampling mode - if repetition_penalty != 1.0: - for i in range(logits.shape[0]): - for token_id in set(token_history): - logits[i, token_id] *= repetition_penalty if logits[i, token_id] < 0 else 1/repetition_penalty - - if presence_penalty is not None and presence_penalty != 0.0: - for i in range(logits.shape[0]): - for token_id in set(token_history): - logits[i, token_id] -= presence_penalty + apply_repetition_penalty = repetition_penalty != 1.0 + apply_presence_penalty = presence_penalty is not None and presence_penalty != 0.0 + if (apply_repetition_penalty or apply_presence_penalty) and token_history: + # Vectorized equivalent of looping over set(token_history) for every batch row. + # The original nested Python loop scales as O(len(history)) per generated token and + # indexes the logits tensor with scalars, which forces a GPU->CPU sync each step. + # Gathering the affected columns once and scattering them back keeps the per-element + # arithmetic identical while running entirely on-device. + unique_tokens = torch.as_tensor(sorted(set(token_history)), device=logits.device, dtype=torch.long) + penalized = logits.index_select(1, unique_tokens) + if apply_repetition_penalty: + penalized = torch.where(penalized < 0, penalized * repetition_penalty, penalized * (1.0 / repetition_penalty)) + if apply_presence_penalty: + penalized = penalized - presence_penalty + logits.index_copy_(1, unique_tokens, penalized) if temperature != 1.0: logits = logits / temperature