Handle thinking tokens different only for Gemma4

This commit is contained in:
kijai 2026-04-13 23:40:52 +03:00
parent e0cccbd4c9
commit 845eb14425
2 changed files with 10 additions and 8 deletions

View File

@ -1227,6 +1227,15 @@ class Gemma4SDTokenizer(Gemma4_Tokenizer, sd1_clip.SDTokenizer):
self.tokenizer_json_data = tokenizer_json
super().__init__(tokenizer_json, pad_with_end=False, embedding_size=self.embedding_size, embedding_key='gemma4', tokenizer_class=_Gemma4Tokenizer, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_left=True, disable_weights=True, start_token=2, tokenizer_data=tokenizer_data)
def decode(self, token_ids, **kwargs):
text = super().decode(token_ids, skip_special_tokens=False)
# Translate thinking channel markers to standard <think>/</think> tags
text = text.replace("<|channel>thought\n", "<think>\n")
text = text.replace("<channel|>", "</think>")
# Strip remaining special tokens
text = text.replace("<turn|>", "").replace("<eos>", "").strip()
return text
class Gemma4Tokenizer(sd1_clip.SD1Tokenizer):
tokenizer_class = Gemma4SDTokenizer

View File

@ -71,14 +71,7 @@ class TextGenerate(io.ComfyNode):
seed=seed
)
generated_text = clip.decode(generated_ids, skip_special_tokens=not thinking)
if thinking:
# Translate Gemma4 thinking channel markers to standard <think>/</think> tags
generated_text = generated_text.replace("<|channel>thought\n", "<think>\n")
generated_text = generated_text.replace("<channel|>", "</think>")
# Strip remaining special tokens
generated_text = generated_text.replace("<turn|>", "").replace("<eos>", "").strip()
generated_text = clip.decode(generated_ids)
return io.NodeOutput(generated_text)