ComfyUI/comfy/text_encoders/spiece_tokenizer.py
2025-09-18 13:42:05 -07:00

52 lines
1.6 KiB
Python

import copy
from pathlib import Path
import sentencepiece
import torch
class SPieceTokenizer:
    """Thin wrapper around ``sentencepiece.SentencePieceProcessor`` exposing a
    small subset of the Hugging Face tokenizer interface: ``from_pretrained``,
    ``__call__`` returning ``{"input_ids": ...}``, ``get_vocab``, and the
    ``eos_token`` / ``eos_token_id`` attributes.
    """

    @staticmethod
    def from_pretrained(path, **kwargs):
        """Alternate constructor mirroring the HF ``from_pretrained`` API."""
        return SPieceTokenizer(path, **kwargs)

    def __init__(self, tokenizer_path: "torch.Tensor | bytes | str | Path", add_bos=False, add_eos=True, **kwargs):
        """Build a SentencePieceProcessor from one of three model sources.

        Args:
            tokenizer_path: path to a SentencePiece ``.model`` file, the raw
                serialized model proto as ``bytes``, or a torch byte tensor
                holding the serialized proto (as produced by
                :meth:`serialize_model`).
            add_bos: whether the processor prepends a BOS token when encoding.
            add_eos: whether the processor appends an EOS token when encoding.

        Raises:
            ValueError: if a filesystem path is given but no file exists there.
        """
        self.add_bos = add_bos
        self.add_eos = add_eos
        if torch.is_tensor(tokenizer_path):
            # Byte tensor from serialize_model(); move to CPU first so tensors
            # living on an accelerator also round-trip correctly (.cpu() is a
            # no-op for tensors already on the CPU).
            tokenizer_path = tokenizer_path.cpu().numpy().tobytes()
        construction_args = {
            'add_bos': self.add_bos,
            'add_eos': self.add_eos,
        }
        if isinstance(tokenizer_path, bytes):
            construction_args["model_proto"] = tokenizer_path
        else:
            if not Path(tokenizer_path).is_file():
                raise ValueError(f"invalid tokenizer {tokenizer_path}")
            # str() so pathlib.Path inputs also work: the sentencepiece binding
            # expects a plain string for model_file.
            construction_args["model_file"] = str(tokenizer_path)
        self.tokenizer = sentencepiece.SentencePieceProcessor(**construction_args) # pylint: disable=unexpected-keyword-arg
        self.end = self.tokenizer.eos_id()
        self.eos_token_id = self.end
        self.eos_token = self.tokenizer.id_to_piece(self.eos_token_id)
        # piece -> id mapping, computed once at construction time.
        self._vocab = {
            self.tokenizer.id_to_piece(i): i for i in range(self.tokenizer.get_piece_size())
        }

    def get_vocab(self):
        """Return the piece -> id vocabulary mapping built at construction."""
        return self._vocab

    def __call__(self, string):
        """Encode *string* and return an HF-style ``{"input_ids": [...]}`` dict."""
        out = self.tokenizer.encode(string)
        return {"input_ids": out}

    def serialize_model(self):
        """Return the serialized model proto as a 1-D uint8 torch tensor."""
        return torch.ByteTensor(list(self.tokenizer.serialized_model_proto()))

    def clone(self):
        """Shallow copy; the underlying SentencePieceProcessor is shared."""
        return copy.copy(self)