import os
import re

import torch


class SPieceTokenizer:
    """Thin wrapper around a SentencePiece tokenizer with optional
    user-defined special tokens and BOS/EOS handling."""

    @staticmethod
    def from_pretrained(path, **kwargs):
        return SPieceTokenizer(path, **kwargs)

    def __init__(self, tokenizer_path, add_bos=False, add_eos=True, special_tokens=None):
        self.add_bos = add_bos
        self.add_eos = add_eos
        # Optional mapping of special-token string -> token id, e.g. {"<extra_id_0>": 32000}.
        self.special_tokens = special_tokens
        # Lazy import so sentencepiece is only required when a tokenizer is constructed.
        import sentencepiece
        # Allow the model to be passed as a serialized proto stored in a byte tensor.
        if torch.is_tensor(tokenizer_path):
            tokenizer_path = tokenizer_path.numpy().tobytes()

        if isinstance(tokenizer_path, bytes):
            self.tokenizer = sentencepiece.SentencePieceProcessor(
                model_proto=tokenizer_path, add_bos=self.add_bos, add_eos=self.add_eos
            )
        else:
            if not os.path.isfile(tokenizer_path):
                raise ValueError("invalid tokenizer")
            self.tokenizer = sentencepiece.SentencePieceProcessor(
                model_file=tokenizer_path, add_bos=self.add_bos, add_eos=self.add_eos
            )

    def get_vocab(self):
        """Return the full piece -> id mapping."""
        out = {}
        for i in range(self.tokenizer.get_piece_size()):
            out[self.tokenizer.id_to_piece(i)] = i
        return out

    def __call__(self, string):
        if self.special_tokens is not None:
            special_tokens_pattern = '|'.join(re.escape(token) for token in self.special_tokens.keys())
            if special_tokens_pattern and re.search(special_tokens_pattern, string):
                # Split on special tokens while keeping them as separate parts
                # (the capture group makes re.split return the delimiters too).
                parts = re.split(f'({special_tokens_pattern})', string)
                result = []
                for part in parts:
                    if not part:
                        continue
                    if part in self.special_tokens:
                        # Map the special token directly to its configured id.
                        result.append(self.special_tokens[part])
                    else:
                        # Encode plain text without per-part BOS/EOS so they are
                        # not inserted between segments.
                        result.extend(self.tokenizer.encode(part, add_bos=False, add_eos=False))
                # Apply BOS/EOS once for the whole sequence, matching the
                # behavior of the non-special-token path below.
                if self.add_bos:
                    result.insert(0, self.tokenizer.bos_id())
                if self.add_eos:
                    result.append(self.tokenizer.eos_id())
                return {"input_ids": result}
        out = self.tokenizer.encode(string)
        return {"input_ids": out}

    def decode(self, token_ids, skip_special_tokens=False):
        if skip_special_tokens and self.special_tokens:
            # Filter out user-defined special-token ids before decoding, since
            # they may fall outside the SentencePiece vocabulary.
            special_token_ids = set(self.special_tokens.values())
            token_ids = [tid for tid in token_ids if tid not in special_token_ids]
        return self.tokenizer.decode(token_ids)

    def serialize_model(self):
        # Return the serialized model proto as a byte tensor so it can be
        # stored and loaded alongside model weights.
        return torch.ByteTensor(list(self.tokenizer.serialized_model_proto()))
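

if __name__ == "__main__":
    # Minimal usage sketch. Assumptions not in the original code: a local
    # SentencePiece model at "tokenizer.model" and an example special-token
    # id of 32000; adjust both for your own setup.
    tok = SPieceTokenizer.from_pretrained(
        "tokenizer.model",  # hypothetical path to a SentencePiece model file
        special_tokens={"<extra_id_0>": 32000},  # hypothetical special token
    )
    ids = tok("translate English to German: <extra_id_0> hello")["input_ids"]
    print(ids)
    # Special-token ids are stripped before SentencePiece decodes the rest.
    print(tok.decode(ids, skip_special_tokens=True))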