From 5a9055fe054a474d49736a6e6a32b8d8508d183a Mon Sep 17 00:00:00 2001
From: doctorpangloss <@hiddenswitch.com>
Date: Thu, 16 May 2024 12:24:50 -0700
Subject: [PATCH] Tokenizers are now shallow cloned when CLIP is cloned.

This allows nodes to add vocab to the tokenizer, as some checkpoints and
LoRAs may require.
---
 comfy/sd.py                               |  3 +-
 comfy/sd1_clip.py                         | 39 +++++++++++++++++++----
 comfy/web/extensions/.gitignore           |  3 --
 comfy_extras/nodes/nodes_textdiffusers.py |  1 +
 4 files changed, 36 insertions(+), 10 deletions(-)

diff --git a/comfy/sd.py b/comfy/sd.py
index b108a5775..46dedff2c 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -113,7 +113,8 @@ class CLIP:
         n = CLIP(no_init=True)
         n.patcher = self.patcher.clone()
         n.cond_stage_model = self.cond_stage_model
-        n.tokenizer = self.tokenizer
+        # cloning the tokenizer allows the vocab updates to work more idiomatically
+        n.tokenizer = self.tokenizer.clone()
         n.layer_idx = self.layer_idx
         return n
 
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index 6afc9cdff..ef134cc59 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -1,16 +1,17 @@
 from __future__ import annotations
 
+import copy
 import importlib.resources as resources
 import json
 import logging
 import os
 import traceback
 import zipfile
-from typing import List
+from typing import Tuple, Sequence, TypeVar
 
 import torch
 from pkg_resources import resource_filename
-from transformers import CLIPTokenizer
+from transformers import CLIPTokenizer, PreTrainedTokenizerBase
 
 from . import clip_model
 from . import model_management
@@ -392,6 +393,9 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
     return embed_out
 
 
+SDTokenizerT = TypeVar('SDTokenizerT', bound='SDTokenizer')
+
+
 class SDTokenizer:
     def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True, min_length=None):
         if tokenizer_path is None:
@@ -399,7 +403,9 @@ class SDTokenizer:
         if not os.path.exists(os.path.join(tokenizer_path, "tokenizer_config.json")):
             # package based
             tokenizer_path = resource_filename('comfy', 'sd1_tokenizer/')
-        self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path)
+        self.tokenizer_class = tokenizer_class
+        self.tokenizer_path = tokenizer_path
+        self.tokenizer: PreTrainedTokenizerBase = tokenizer_class.from_pretrained(tokenizer_path)
         self.max_length = max_length
         self.min_length = min_length
 
@@ -414,6 +420,7 @@ class SDTokenizer:
             self.end_token = empty[0]
         self.pad_with_end = pad_with_end
         self.pad_to_max_length = pad_to_max_length
+        self.additional_tokens: Tuple[str, ...] = ()
         self.add_tokens([])
         self.embedding_directory = embedding_directory
         self.max_word_length = 8
@@ -421,9 +428,17 @@ class SDTokenizer:
         self.embedding_size = embedding_size
         self.embedding_key = embedding_key
 
-    def add_tokens(self, tokens: List[str]):
+    def clone(self) -> SDTokenizerT:
+        sd_tokenizer = copy.copy(self)
+        # correctly copy additional vocab
+        sd_tokenizer.tokenizer = self.tokenizer_class.from_pretrained(self.tokenizer_path)
+        sd_tokenizer.add_tokens(sd_tokenizer.additional_tokens)
+        return sd_tokenizer
+
+    def add_tokens(self, tokens: Sequence[str]):
+        self.additional_tokens += tuple(tokens)
         if len(tokens) > 0:
-            self.tokenizer.add_tokens(tokens)
+            self.tokenizer.add_tokens(list(tokens))
         vocab = self.tokenizer.get_vocab()
         self.inv_vocab = {v: k for k, v in vocab.items()}
 
@@ -536,11 +551,14 @@ class SDTokenizer:
         return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))
 
 
+SD1TokenizerT = TypeVar("SD1TokenizerT", bound="SD1Tokenizer")
+
+
 class SD1Tokenizer:
     def __init__(self, embedding_directory=None, clip_name="l", tokenizer=SDTokenizer):
         self.clip_name = clip_name
         self.clip = "clip_{}".format(self.clip_name)
-        setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory))
+        self.sd_tokenizer = tokenizer(embedding_directory=embedding_directory)
 
     def tokenize_with_weights(self, text: str, return_word_ids=False):
         out = {}
@@ -554,6 +572,15 @@ class SD1Tokenizer:
     def sd_tokenizer(self) -> SDTokenizer:
         return getattr(self, self.clip)
 
+    @sd_tokenizer.setter
+    def sd_tokenizer(self, value):
+        setattr(self, self.clip, value)
+
+    def clone(self) -> SD1TokenizerT:
+        sd1_tokenizer = copy.copy(self)
+        sd1_tokenizer.sd_tokenizer = self.sd_tokenizer.clone()
+        return sd1_tokenizer
+
 
 class SD1ClipModel(torch.nn.Module):
     def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, textmodel_json_config=None, **kwargs):
diff --git a/comfy/web/extensions/.gitignore b/comfy/web/extensions/.gitignore
index 1c75447a9..e69de29bb 100644
--- a/comfy/web/extensions/.gitignore
+++ b/comfy/web/extensions/.gitignore
@@ -1,3 +0,0 @@
-*
-!core/
-!logging.js.example
\ No newline at end of file
diff --git a/comfy_extras/nodes/nodes_textdiffusers.py b/comfy_extras/nodes/nodes_textdiffusers.py
index 13d977afb..bd253f9c5 100644
--- a/comfy_extras/nodes/nodes_textdiffusers.py
+++ b/comfy_extras/nodes/nodes_textdiffusers.py
@@ -48,6 +48,7 @@ class TextDiffuserTokens(CustomNode):
     FUNCTION = "execute"
 
     def execute(self, clip: CLIP):
+        clip = clip.clone()
        if len(TextDiffuserTokens.TOKENS) == 0:
             for i in range(520):
                 TextDiffuserTokens.TOKENS.append(f'l{i}')
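
A minimal usage sketch of what this change enables, under two assumptions: `clip` is a comfy.sd.CLIP whose .tokenizer is an SD1Tokenizer, as in the diff above, and the helper name `add_vocab` is purely illustrative (it is not part of this commit).

# Illustrative sketch, not part of the commit. Assumes `clip` is a comfy.sd.CLIP
# whose .tokenizer is an SD1Tokenizer; `add_vocab` is a hypothetical helper.
from typing import Sequence

from comfy.sd import CLIP


def add_vocab(clip: CLIP, new_tokens: Sequence[str]) -> CLIP:
    clip = clip.clone()  # the tokenizer is cloned too, so the caller's CLIP stays untouched
    # SD1Tokenizer.sd_tokenizer exposes the wrapped SDTokenizer, which records
    # the extra vocab in additional_tokens
    clip.tokenizer.sd_tokenizer.add_tokens(new_tokens)
    return clip


# Usage inside a node's execute(), e.g. with the TextDiffuser glyph tokens:
#     clip = add_vocab(clip, [f'l{i}' for i in range(520)])
# A later clip.clone() keeps the added vocab, because SDTokenizer.clone()
# reloads the base tokenizer with from_pretrained() and replays additional_tokens.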