Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-11 06:40:48 +08:00)
Tokenizers are now shallow cloned when CLIP is cloned. This allows nodes to add vocab to the tokenizer, as some checkpoints and LoRAs may require.
This commit is contained in: parent 8741cb3ce8, commit 5a9055fe05
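
Why the tokenizer needs its own clone, as a minimal sketch: a plain copy.copy of the wrapper would still share one Hugging Face tokenizer object, so vocab added through one clone would leak into every other clone. Rebuilding the backing tokenizer from its pretrained path and replaying the tracked extra tokens keeps each clone independent, which is the pattern the SDTokenizer.clone hunk below follows. The VocabTokenizer class and the openai/clip-vit-base-patch32 checkpoint here are illustrative stand-ins, not part of this commit.

# Illustrative sketch of the clone-and-replay pattern; not ComfyUI's own code.
import copy
from typing import Sequence, Tuple

from transformers import CLIPTokenizer


class VocabTokenizer:
    """Toy wrapper standing in for SDTokenizer."""

    def __init__(self, pretrained_path: str = "openai/clip-vit-base-patch32"):
        self.pretrained_path = pretrained_path
        self.tokenizer = CLIPTokenizer.from_pretrained(pretrained_path)
        self.additional_tokens: Tuple[str, ...] = ()

    def add_tokens(self, tokens: Sequence[str]):
        # Track extra vocab so a clone can replay it onto a fresh tokenizer.
        self.additional_tokens += tuple(tokens)
        if len(tokens) > 0:
            self.tokenizer.add_tokens(list(tokens))

    def clone(self) -> "VocabTokenizer":
        # Shallow-copy the wrapper, but rebuild the Hugging Face tokenizer so
        # later add_tokens() calls cannot mutate the original's vocab.
        new = copy.copy(self)
        new.tokenizer = CLIPTokenizer.from_pretrained(self.pretrained_path)
        new.additional_tokens = ()  # replayed below onto the fresh tokenizer
        new.add_tokens(self.additional_tokens)
        return new


if __name__ == "__main__":
    base = VocabTokenizer()
    fork = base.clone()
    fork.add_tokens(["mynewtoken</w>"])
    print("mynewtoken</w>" in fork.tokenizer.get_vocab())  # True
    print("mynewtoken</w>" in base.tokenizer.get_vocab())  # False
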
@@ -113,7 +113,8 @@ class CLIP:
         n = CLIP(no_init=True)
         n.patcher = self.patcher.clone()
         n.cond_stage_model = self.cond_stage_model
-        n.tokenizer = self.tokenizer
+        # cloning the tokenizer allows the vocab updates to work more idiomatically
+        n.tokenizer = self.tokenizer.clone()
         n.layer_idx = self.layer_idx
         return n
@@ -1,16 +1,17 @@
 from __future__ import annotations

+import copy
 import importlib.resources as resources
 import json
 import logging
 import os
 import traceback
 import zipfile
-from typing import List
+from typing import Tuple, Sequence, TypeVar

 import torch
 from pkg_resources import resource_filename
-from transformers import CLIPTokenizer
+from transformers import CLIPTokenizer, PreTrainedTokenizerBase

 from . import clip_model
 from . import model_management
@@ -392,6 +393,9 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=None):
     return embed_out


+SDTokenizerT = TypeVar('SDTokenizerT', bound='SDTokenizer')
+
+
 class SDTokenizer:
     def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True, min_length=None):
         if tokenizer_path is None:
@@ -399,7 +403,9 @@ class SDTokenizer:
         if not os.path.exists(os.path.join(tokenizer_path, "tokenizer_config.json")):
             # package based
             tokenizer_path = resource_filename('comfy', 'sd1_tokenizer/')
-        self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path)
+        self.tokenizer_class = tokenizer_class
+        self.tokenizer_path = tokenizer_path
+        self.tokenizer: PreTrainedTokenizerBase = tokenizer_class.from_pretrained(tokenizer_path)
         self.max_length = max_length
         self.min_length = min_length
@@ -414,6 +420,7 @@ class SDTokenizer:
         self.end_token = empty[0]
         self.pad_with_end = pad_with_end
         self.pad_to_max_length = pad_to_max_length
+        self.additional_tokens: Tuple[str, ...] = ()
         self.add_tokens([])
         self.embedding_directory = embedding_directory
         self.max_word_length = 8
@@ -421,9 +428,17 @@ class SDTokenizer:
         self.embedding_size = embedding_size
         self.embedding_key = embedding_key

-    def add_tokens(self, tokens: List[str]):
+    def clone(self) -> SDTokenizerT:
+        sd_tokenizer = copy.copy(self)
+        # correctly copy additional vocab
+        sd_tokenizer.tokenizer = self.tokenizer_class.from_pretrained(self.tokenizer_path)
+        sd_tokenizer.add_tokens(sd_tokenizer.additional_tokens)
+        return sd_tokenizer
+
+    def add_tokens(self, tokens: Sequence[str]):
+        self.additional_tokens += tuple(tokens)
         if len(tokens) > 0:
-            self.tokenizer.add_tokens(tokens)
+            self.tokenizer.add_tokens(list(tokens))
             vocab = self.tokenizer.get_vocab()
             self.inv_vocab = {v: k for k, v in vocab.items()}
@@ -536,11 +551,14 @@ class SDTokenizer:
         return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))


+SD1TokenizerT = TypeVar("SD1TokenizerT", bound="SD1Tokenizer")
+
+
 class SD1Tokenizer:
     def __init__(self, embedding_directory=None, clip_name="l", tokenizer=SDTokenizer):
         self.clip_name = clip_name
         self.clip = "clip_{}".format(self.clip_name)
-        setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory))
+        self.sd_tokenizer = tokenizer(embedding_directory=embedding_directory)

     def tokenize_with_weights(self, text: str, return_word_ids=False):
         out = {}
@@ -554,6 +572,15 @@ class SD1Tokenizer:
     def sd_tokenizer(self) -> SDTokenizer:
         return getattr(self, self.clip)

+    @sd_tokenizer.setter
+    def sd_tokenizer(self, value):
+        setattr(self, self.clip, value)
+
+    def clone(self) -> SD1TokenizerT:
+        sd1_tokenizer = copy.copy(self)
+        sd1_tokenizer.sd_tokenizer = self.sd_tokenizer.clone()
+        return sd1_tokenizer
+

 class SD1ClipModel(torch.nn.Module):
     def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, textmodel_json_config=None, **kwargs):
comfy/web/extensions/.gitignore (vendored, 3 deletions)
@@ -1,3 +0,0 @@
-*
-!core/
-!logging.js.example
@@ -48,6 +48,7 @@ class TextDiffuserTokens(CustomNode):
     FUNCTION = "execute"

     def execute(self, clip: CLIP):
+        clip = clip.clone()
         if len(TextDiffuserTokens.TOKENS) == 0:
             for i in range(520):
                 TextDiffuserTokens.TOKENS.append(f'l{i}</w>')
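
For downstream use, a node would typically clone the CLIP first and then register its extra vocab on the clone, so the added tokens stay local to that copy. The attribute chain below (clip.tokenizer.sd_tokenizer.add_tokens) is pieced together from the hunks above and is only a hedged sketch of plausible usage; it is not the remainder of the truncated TextDiffuserTokens.execute.

# Hedged usage sketch; the clip.tokenizer.sd_tokenizer path is an assumption
# inferred from the diff, not confirmed by this commit.
def add_vocab_to_clip(clip, extra_tokens):
    clip = clip.clone()  # tokenizer is cloned along with the CLIP object
    clip.tokenizer.sd_tokenizer.add_tokens(list(extra_tokens))
    return clip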