Tokenizers are now shallow-cloned when CLIP is cloned. This allows nodes to add vocab to the tokenizer, as some checkpoints and LoRAs may require.
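
A rough sketch of the intended usage follows; the node class, token strings, and the `clip.tokenizer.sd_tokenizer` access path are illustrative assumptions layered on the `clone()` and `add_tokens()` methods introduced in this commit, not code from the repository:

    # Hypothetical custom node: clone CLIP so the extra vocab only affects this copy.
    class AddVocabTokens:
        FUNCTION = "execute"

        def execute(self, clip):
            clip = clip.clone()  # clone() now also clones the underlying tokenizer
            extra_tokens = [f"mytoken{i}</w>" for i in range(16)]  # placeholder token strings
            # sd_tokenizer is the SD1Tokenizer property added in this commit;
            # add_tokens() records the strings and registers them with the HF tokenizer.
            clip.tokenizer.sd_tokenizer.add_tokens(extra_tokens)
            return (clip,)

Because the clone re-creates its inner CLIPTokenizer via from_pretrained() and then re-applies the stored additional_tokens, vocab added on one clone never leaks into the original instance.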

doctorpangloss 2024-05-16 12:24:50 -07:00
parent 8741cb3ce8
commit 5a9055fe05
4 changed files with 36 additions and 10 deletions


@ -113,7 +113,8 @@ class CLIP:
n = CLIP(no_init=True)
n.patcher = self.patcher.clone()
n.cond_stage_model = self.cond_stage_model
n.tokenizer = self.tokenizer
# cloning the tokenizer allows the vocab updates to work more idiomatically
n.tokenizer = self.tokenizer.clone()
n.layer_idx = self.layer_idx
return n


@ -1,16 +1,17 @@
from __future__ import annotations
import copy
import importlib.resources as resources
import json
import logging
import os
import traceback
import zipfile
from typing import List
from typing import Tuple, Sequence, TypeVar
import torch
from pkg_resources import resource_filename
from transformers import CLIPTokenizer
from transformers import CLIPTokenizer, PreTrainedTokenizerBase
from . import clip_model
from . import model_management
@ -392,6 +393,9 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
return embed_out
SDTokenizerT = TypeVar('SDTokenizerT', bound='SDTokenizer')
class SDTokenizer:
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True, min_length=None):
if tokenizer_path is None:
@ -399,7 +403,9 @@ class SDTokenizer:
if not os.path.exists(os.path.join(tokenizer_path, "tokenizer_config.json")):
# package based
tokenizer_path = resource_filename('comfy', 'sd1_tokenizer/')
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path)
self.tokenizer_class = tokenizer_class
self.tokenizer_path = tokenizer_path
self.tokenizer: PreTrainedTokenizerBase = tokenizer_class.from_pretrained(tokenizer_path)
self.max_length = max_length
self.min_length = min_length
@ -414,6 +420,7 @@ class SDTokenizer:
self.end_token = empty[0]
self.pad_with_end = pad_with_end
self.pad_to_max_length = pad_to_max_length
self.additional_tokens: Tuple[str, ...] = ()
self.add_tokens([])
self.embedding_directory = embedding_directory
self.max_word_length = 8
@ -421,9 +428,17 @@ class SDTokenizer:
self.embedding_size = embedding_size
self.embedding_key = embedding_key
def add_tokens(self, tokens: List[str]):
def clone(self) -> SDTokenizerT:
sd_tokenizer = copy.copy(self)
# correctly copy additional vocab
sd_tokenizer.tokenizer = self.tokenizer_class.from_pretrained(self.tokenizer_path)
sd_tokenizer.add_tokens(sd_tokenizer.additional_tokens)
return sd_tokenizer
def add_tokens(self, tokens: Sequence[str]):
self.additional_tokens += tuple(tokens)
if len(tokens) > 0:
self.tokenizer.add_tokens(tokens)
self.tokenizer.add_tokens(list(tokens))
vocab = self.tokenizer.get_vocab()
self.inv_vocab = {v: k for k, v in vocab.items()}
@ -536,11 +551,14 @@ class SDTokenizer:
return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))
SD1TokenizerT = TypeVar("SD1TokenizerT", bound="SD1Tokenizer")
class SD1Tokenizer:
def __init__(self, embedding_directory=None, clip_name="l", tokenizer=SDTokenizer):
self.clip_name = clip_name
self.clip = "clip_{}".format(self.clip_name)
setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory))
self.sd_tokenizer = tokenizer(embedding_directory=embedding_directory)
def tokenize_with_weights(self, text: str, return_word_ids=False):
out = {}
@ -554,6 +572,15 @@ class SD1Tokenizer:
def sd_tokenizer(self) -> SDTokenizer:
return getattr(self, self.clip)
@sd_tokenizer.setter
def sd_tokenizer(self, value):
setattr(self, self.clip, value)
def clone(self) -> SD1TokenizerT:
sd1_tokenizer = copy.copy(self)
sd1_tokenizer.sd_tokenizer = self.sd_tokenizer.clone()
return sd1_tokenizer
class SD1ClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, textmodel_json_config=None, **kwargs):


@ -1,3 +0,0 @@
*
!core/
!logging.js.example


@ -48,6 +48,7 @@ class TextDiffuserTokens(CustomNode):
FUNCTION = "execute"
def execute(self, clip: CLIP):
clip = clip.clone()
if len(TextDiffuserTokens.TOKENS) == 0:
for i in range(520):
TextDiffuserTokens.TOKENS.append(f'l{i}</w>')