mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-27 19:02:31 +08:00
Handle thinking tokens different only for Gemma4
This commit is contained in:
parent
e0cccbd4c9
commit
845eb14425
@ -1227,6 +1227,15 @@ class Gemma4SDTokenizer(Gemma4_Tokenizer, sd1_clip.SDTokenizer):
|
|||||||
self.tokenizer_json_data = tokenizer_json
|
self.tokenizer_json_data = tokenizer_json
|
||||||
super().__init__(tokenizer_json, pad_with_end=False, embedding_size=self.embedding_size, embedding_key='gemma4', tokenizer_class=_Gemma4Tokenizer, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_left=True, disable_weights=True, start_token=2, tokenizer_data=tokenizer_data)
|
super().__init__(tokenizer_json, pad_with_end=False, embedding_size=self.embedding_size, embedding_key='gemma4', tokenizer_class=_Gemma4Tokenizer, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_left=True, disable_weights=True, start_token=2, tokenizer_data=tokenizer_data)
|
||||||
|
|
||||||
|
def decode(self, token_ids, **kwargs):
|
||||||
|
text = super().decode(token_ids, skip_special_tokens=False)
|
||||||
|
# Translate thinking channel markers to standard <think>/</think> tags
|
||||||
|
text = text.replace("<|channel>thought\n", "<think>\n")
|
||||||
|
text = text.replace("<channel|>", "</think>")
|
||||||
|
# Strip remaining special tokens
|
||||||
|
text = text.replace("<turn|>", "").replace("<eos>", "").strip()
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
class Gemma4Tokenizer(sd1_clip.SD1Tokenizer):
|
class Gemma4Tokenizer(sd1_clip.SD1Tokenizer):
|
||||||
tokenizer_class = Gemma4SDTokenizer
|
tokenizer_class = Gemma4SDTokenizer
|
||||||
|
|||||||
@ -71,14 +71,7 @@ class TextGenerate(io.ComfyNode):
|
|||||||
seed=seed
|
seed=seed
|
||||||
)
|
)
|
||||||
|
|
||||||
generated_text = clip.decode(generated_ids, skip_special_tokens=not thinking)
|
generated_text = clip.decode(generated_ids)
|
||||||
|
|
||||||
if thinking:
|
|
||||||
# Translate Gemma4 thinking channel markers to standard <think>/</think> tags
|
|
||||||
generated_text = generated_text.replace("<|channel>thought\n", "<think>\n")
|
|
||||||
generated_text = generated_text.replace("<channel|>", "</think>")
|
|
||||||
# Strip remaining special tokens
|
|
||||||
generated_text = generated_text.replace("<turn|>", "").replace("<eos>", "").strip()
|
|
||||||
|
|
||||||
return io.NodeOutput(generated_text)
|
return io.NodeOutput(generated_text)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user