ComfyUI/comfy/text_encoders/spiece_tokenizer.py

49 lines
1.5 KiB
Python

import copy
from pathlib import Path
import sentencepiece
import torch
class SPieceTokenizer:
@staticmethod
def from_pretrained(path, **kwargs):
return SPieceTokenizer(path, **kwargs)
def __init__(self, tokenizer_path: bytes | str | Path, add_bos=False, add_eos=True):
self.add_bos = add_bos
self.add_eos = add_eos
if torch.is_tensor(tokenizer_path):
tokenizer_path = tokenizer_path.numpy().tobytes()
construction_args = {
'add_bos': self.add_bos,
'add_eos': self.add_eos
}
if isinstance(tokenizer_path, bytes):
construction_args["model_proto"] = tokenizer_path
else:
construction_args["model_file"] = tokenizer_path
self.tokenizer = sentencepiece.SentencePieceProcessor(**construction_args) # pylint: disable=unexpected-keyword-arg
self.end = self.tokenizer.eos_id()
self.eos_token_id = self.end
self.eos_token = self.tokenizer.id_to_piece(self.eos_token_id) # pylint: disable=no-member
self._vocab = {
self.tokenizer.id_to_piece(i): i for i in range(self.tokenizer.get_piece_size()) # pylint: disable=no-member
}
def get_vocab(self):
return self._vocab
def __call__(self, string):
out = self.tokenizer.encode(string) # pylint: disable=no-member
return {"input_ids": out}
def serialize_model(self):
return torch.ByteTensor(list(self.tokenizer.serialized_model_proto()))
def clone(self):
return copy.copy(self)