From 845eb1442552ea63ebd30b8003204fa10d43ab49 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Mon, 13 Apr 2026 23:40:52 +0300
Subject: [PATCH] Handle thinking tokens different only for Gemma4

---
 comfy/text_encoders/gemma4.py | 9 +++++++++
 comfy_extras/nodes_textgen.py | 9 +--------
 2 files changed, 10 insertions(+), 8 deletions(-)

diff --git a/comfy/text_encoders/gemma4.py b/comfy/text_encoders/gemma4.py
index 7c3df9c09..78ad81741 100644
--- a/comfy/text_encoders/gemma4.py
+++ b/comfy/text_encoders/gemma4.py
@@ -1227,6 +1227,15 @@ class Gemma4SDTokenizer(Gemma4_Tokenizer, sd1_clip.SDTokenizer):
         self.tokenizer_json_data = tokenizer_json
         super().__init__(tokenizer_json, pad_with_end=False, embedding_size=self.embedding_size, embedding_key='gemma4', tokenizer_class=_Gemma4Tokenizer, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_left=True, disable_weights=True, start_token=2, tokenizer_data=tokenizer_data)
 
+    def decode(self, token_ids, **kwargs):
+        text = super().decode(token_ids, skip_special_tokens=False)
+        # Translate thinking channel markers to standard <think>/</think> tags
+        text = text.replace("<|channel>thought\n", "<think>\n")
+        text = text.replace("<end_of_thought>", "</think>")
+        # Strip remaining special tokens
+        text = text.replace("<start_of_turn>", "").replace("<end_of_turn>", "").strip()
+        return text
+
 
 class Gemma4Tokenizer(sd1_clip.SD1Tokenizer):
     tokenizer_class = Gemma4SDTokenizer
diff --git a/comfy_extras/nodes_textgen.py b/comfy_extras/nodes_textgen.py
index 0d4cf3a2b..ec81159d3 100644
--- a/comfy_extras/nodes_textgen.py
+++ b/comfy_extras/nodes_textgen.py
@@ -71,14 +71,7 @@ class TextGenerate(io.ComfyNode):
             seed=seed
         )
 
-        generated_text = clip.decode(generated_ids, skip_special_tokens=not thinking)
-
-        if thinking:
-            # Translate Gemma4 thinking channel markers to standard <think>/</think> tags
-            generated_text = generated_text.replace("<|channel>thought\n", "<think>\n")
-            generated_text = generated_text.replace("<end_of_thought>", "</think>")
-            # Strip remaining special tokens
-            generated_text = generated_text.replace("<start_of_turn>", "").replace("<end_of_turn>", "").strip()
+        generated_text = clip.decode(generated_ids)
 
         return io.NodeOutput(generated_text)