Tokenizers are now shallow-cloned whenever a CLIP object is cloned. This lets nodes add vocabulary to their own copy of the tokenizer without mutating the shared instance, as some checkpoints and LoRAs require.

doctorpangloss 2024-05-16 12:24:50 -07:00
parent 8741cb3ce8
commit 5a9055fe05
4 changed files with 36 additions and 10 deletions
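
The intended downstream usage, sketched very roughly below, assuming the SD1 tokenizer layout shown in this diff. The token string and the tokenize/encode calls are assumptions based on the usual ComfyUI CLIP helpers and are not part of this commit; the point is only that a node clones the CLIP object and grows its own copy's vocabulary, leaving the shared tokenizer untouched.

    # minimal sketch, not part of this commit; names are illustrative
    clip = clip.clone()                                     # the tokenizer is now cloned along with the patcher
    clip.tokenizer.sd_tokenizer.add_tokens(["<glyph_0>"])   # hypothetical token; only this copy's vocab grows
    tokens = clip.tokenize("a sign that reads <glyph_0>")   # assumes the usual CLIP.tokenize helper
    cond = clip.encode_from_tokens(tokens)                  # assumes the usual encode_from_tokens helper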

View File

@@ -113,7 +113,8 @@ class CLIP:
         n = CLIP(no_init=True)
         n.patcher = self.patcher.clone()
         n.cond_stage_model = self.cond_stage_model
-        n.tokenizer = self.tokenizer
+        # cloning the tokenizer allows the vocab updates to work more idiomatically
+        n.tokenizer = self.tokenizer.clone()
         n.layer_idx = self.layer_idx
         return n

View File

@@ -1,16 +1,17 @@
 from __future__ import annotations

+import copy
 import importlib.resources as resources
 import json
 import logging
 import os
 import traceback
 import zipfile
-from typing import List
+from typing import Tuple, Sequence, TypeVar

 import torch
 from pkg_resources import resource_filename
-from transformers import CLIPTokenizer
+from transformers import CLIPTokenizer, PreTrainedTokenizerBase

 from . import clip_model
 from . import model_management
@@ -392,6 +393,9 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
     return embed_out


+SDTokenizerT = TypeVar('SDTokenizerT', bound='SDTokenizer')
+
+
 class SDTokenizer:
     def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True, min_length=None):
         if tokenizer_path is None:
@@ -399,7 +403,9 @@ class SDTokenizer:
             if not os.path.exists(os.path.join(tokenizer_path, "tokenizer_config.json")):
                 # package based
                 tokenizer_path = resource_filename('comfy', 'sd1_tokenizer/')
-        self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path)
+        self.tokenizer_class = tokenizer_class
+        self.tokenizer_path = tokenizer_path
+        self.tokenizer: PreTrainedTokenizerBase = tokenizer_class.from_pretrained(tokenizer_path)
         self.max_length = max_length
         self.min_length = min_length
@@ -414,6 +420,7 @@ class SDTokenizer:
         self.end_token = empty[0]
         self.pad_with_end = pad_with_end
         self.pad_to_max_length = pad_to_max_length
+        self.additional_tokens: Tuple[str, ...] = ()
         self.add_tokens([])
         self.embedding_directory = embedding_directory
         self.max_word_length = 8
@@ -421,9 +428,17 @@ class SDTokenizer:
         self.embedding_size = embedding_size
         self.embedding_key = embedding_key

-    def add_tokens(self, tokens: List[str]):
+    def clone(self) -> SDTokenizerT:
+        sd_tokenizer = copy.copy(self)
+        # correctly copy additional vocab
+        sd_tokenizer.tokenizer = self.tokenizer_class.from_pretrained(self.tokenizer_path)
+        sd_tokenizer.add_tokens(sd_tokenizer.additional_tokens)
+        return sd_tokenizer
+
+    def add_tokens(self, tokens: Sequence[str]):
+        self.additional_tokens += tuple(tokens)
         if len(tokens) > 0:
-            self.tokenizer.add_tokens(tokens)
+            self.tokenizer.add_tokens(list(tokens))
         vocab = self.tokenizer.get_vocab()
         self.inv_vocab = {v: k for k, v in vocab.items()}
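
The clone() above deliberately re-loads the Hugging Face tokenizer from tokenizer_path and replays additional_tokens, rather than shallow-copying the existing instance, because a copy.copy of a transformers tokenizer generally still shares its added-token tables with the original. A standalone sketch of the difference, with an illustrative checkpoint id and token names:

    import copy
    from transformers import CLIPTokenizer

    base = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")  # example checkpoint id

    aliased = copy.copy(base)                 # shallow copy: added-token tables are shared
    aliased.add_tokens(["<token_b>"])
    print("<token_b>" in base.get_vocab())    # typically True: the addition leaked into the original

    fresh = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
    fresh.add_tokens(["<token_a>"])           # replay recorded additions on a fresh instance instead
    print("<token_b>" in fresh.get_vocab())   # False: the fresh instance is independent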
@@ -536,11 +551,14 @@ class SDTokenizer:
         return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))


+SD1TokenizerT = TypeVar("SD1TokenizerT", bound="SD1Tokenizer")
+
+
 class SD1Tokenizer:
     def __init__(self, embedding_directory=None, clip_name="l", tokenizer=SDTokenizer):
         self.clip_name = clip_name
         self.clip = "clip_{}".format(self.clip_name)
-        setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory))
+        self.sd_tokenizer = tokenizer(embedding_directory=embedding_directory)

     def tokenize_with_weights(self, text: str, return_word_ids=False):
         out = {}
@@ -554,6 +572,15 @@ class SD1Tokenizer:
     def sd_tokenizer(self) -> SDTokenizer:
         return getattr(self, self.clip)

+    @sd_tokenizer.setter
+    def sd_tokenizer(self, value):
+        setattr(self, self.clip, value)
+
+    def clone(self) -> SD1TokenizerT:
+        sd1_tokenizer = copy.copy(self)
+        sd1_tokenizer.sd_tokenizer = self.sd_tokenizer.clone()
+        return sd1_tokenizer
+

 class SD1ClipModel(torch.nn.Module):
     def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, textmodel_json_config=None, **kwargs):
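
With the setter in place, copy.copy plus one reassignment gives every SD1Tokenizer clone its own SDTokenizer while everything else stays shared. An illustrative check of the intended independence, assuming an already constructed SD1Tokenizer instance named t:

    t2 = t.clone()
    t2.sd_tokenizer.add_tokens(["<extra>"])                       # hypothetical token
    assert "<extra>" in t2.sd_tokenizer.tokenizer.get_vocab()     # the clone sees it
    assert "<extra>" not in t.sd_tokenizer.tokenizer.get_vocab()  # the original does not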

View File

@@ -1,3 +0,0 @@
-*
-!core/
-!logging.js.example

View File

@@ -48,6 +48,7 @@ class TextDiffuserTokens(CustomNode):
     FUNCTION = "execute"

     def execute(self, clip: CLIP):
+        clip = clip.clone()
         if len(TextDiffuserTokens.TOKENS) == 0:
             for i in range(520):
                 TextDiffuserTokens.TOKENS.append(f'l{i}</w>')
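
The hunk is cut off above; presumably the node goes on to register TextDiffuserTokens.TOKENS on the cloned clip's tokenizer. A hedged sketch of that continuation, which is not shown in this diff:

    # hypothetical continuation, for illustration only
    clip.tokenizer.sd_tokenizer.add_tokens(TextDiffuserTokens.TOKENS)
    return (clip,)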