# mirror of https://github.com/comfyanonymous/ComfyUI.git
# synced 2026-01-09 22:00:49 +08:00
import copy
from pathlib import Path

import sentencepiece
import torch
class SPieceTokenizer:
|
|
@staticmethod
|
|
def from_pretrained(path, **kwargs):
|
|
return SPieceTokenizer(path, **kwargs)
|
|
|
|
def __init__(self, tokenizer_path: bytes | str | Path, add_bos=False, add_eos=True, **kwargs):
|
|
self.add_bos = add_bos
|
|
self.add_eos = add_eos
|
|
if torch.is_tensor(tokenizer_path):
|
|
tokenizer_path = tokenizer_path.numpy().tobytes()
|
|
|
|
construction_args = {
|
|
'add_bos': self.add_bos,
|
|
'add_eos': self.add_eos,
|
|
}
|
|
|
|
if isinstance(tokenizer_path, bytes):
|
|
construction_args["model_proto"] = tokenizer_path
|
|
else:
|
|
if not Path(tokenizer_path).is_file():
|
|
raise ValueError(f"invalid tokenizer {tokenizer_path}")
|
|
construction_args["model_file"] = tokenizer_path
|
|
|
|
self.tokenizer = sentencepiece.SentencePieceProcessor(**construction_args) # pylint: disable=unexpected-keyword-arg
|
|
|
|
self.end = self.tokenizer.eos_id()
|
|
self.eos_token_id = self.end
|
|
self.eos_token = self.tokenizer.id_to_piece(self.eos_token_id)
|
|
self._vocab = {
|
|
self.tokenizer.id_to_piece(i): i for i in range(self.tokenizer.get_piece_size())
|
|
}
|
|
|
|
def get_vocab(self):
|
|
return self._vocab
|
|
|
|
def __call__(self, string):
|
|
out = self.tokenizer.encode(string)
|
|
return {"input_ids": out}
|
|
|
|
def serialize_model(self):
|
|
return torch.ByteTensor(list(self.tokenizer.serialized_model_proto()))
|
|
|
|
def clone(self):
|
|
return copy.copy(self)
|