# ComfyUI/comfy/text_encoders/flux.py
#
# File-viewer metadata from the original listing:
#   94 lines, 3.8 KiB
#   Python

import copy
import torch
from transformers import T5TokenizerFast
from .t5 import T5
from .. import sd1_clip, model_management
from ..component_model import files
class T5XXLModel(sd1_clip.SDClipModel):
    """T5-XXL text encoder exposed through the ``SDClipModel`` interface."""

    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options=None, textmodel_json_config=None):
        """Build the encoder around the ``T5`` model class.

        Args:
            device: Forwarded unchanged to ``SDClipModel``.
            layer: Forwarded unchanged to ``SDClipModel``.
            layer_idx: Forwarded unchanged to ``SDClipModel``.
            dtype: Forwarded unchanged to ``SDClipModel``.
            model_options: Optional dict of encoder options; ``None`` is
                treated as an empty dict.
            textmodel_json_config: Optional override for the model config.
                It is resolved through ``files.get_path_as_dict`` with
                ``"t5_config_xxl.json"`` from this package as the second
                argument (presumably the fallback when no override is
                given - confirm against ``files``).
        """
        if model_options is None:
            model_options = {}
        textmodel_json_config = files.get_path_as_dict(
            textmodel_json_config,
            "t5_config_xxl.json",
            package=__package__,
        )
        # model_options used to be accepted here but never passed on, so any
        # options supplied by the caller were silently ignored. Forward them,
        # matching how FluxClipModel hands the same dict to its CLIP-L model.
        super().__init__(
            device=device,
            layer=layer,
            layer_idx=layer_idx,
            textmodel_json_config=textmodel_json_config,
            dtype=dtype,
            special_tokens={"end": 1, "pad": 0},
            model_class=T5,
            model_options=model_options,
        )
class T5XXLTokenizer(sd1_clip.SDTokenizer):
    """``SDTokenizer`` configured for T5-XXL using ``T5TokenizerFast``."""

    def __init__(self, embedding_directory=None, tokenizer_data=None):
        """Set up the tokenizer from the bundled ``t5_tokenizer`` package.

        Args:
            embedding_directory: Forwarded unchanged to ``SDTokenizer``.
            tokenizer_data: Optional dict; ``None`` is normalised to an empty
                dict. It is not otherwise read by this constructor.
        """
        tokenizer_data = {} if tokenizer_data is None else tokenizer_data
        t5_tokenizer_dir = files.get_package_as_path("comfy.text_encoders.t5_tokenizer")
        # NOTE(review): min_length=256 with pad_to_max_length=False presumably
        # pads short prompts up to 256 tokens while leaving long prompts
        # effectively unbounded - confirm against SDTokenizer.
        super().__init__(
            t5_tokenizer_dir,
            embedding_directory=embedding_directory,
            pad_with_end=False,
            embedding_size=4096,
            embedding_key='t5xxl',
            tokenizer_class=T5TokenizerFast,
            has_start_token=False,
            pad_to_max_length=False,
            max_length=99999999,
            min_length=256,
        )
class FluxTokenizer:
    """Tokenizes one prompt with both a CLIP-L and a T5-XXL tokenizer."""

    def __init__(self, embedding_directory=None, tokenizer_data=None):
        """Create the two underlying tokenizers.

        Args:
            embedding_directory: Passed to both tokenizers.
            tokenizer_data: Optional dict. Its ``"clip_l_tokenizer_class"``
                entry, when present, replaces ``sd1_clip.SDTokenizer`` as the
                CLIP-L tokenizer class.
        """
        settings = {} if tokenizer_data is None else tokenizer_data
        make_clip_l = settings.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
        self.clip_l = make_clip_l(embedding_directory=embedding_directory)
        self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)

    def tokenize_with_weights(self, text: str, return_word_ids=False):
        """Return ``{"l": ..., "t5xxl": ...}`` with each tokenizer's output for ``text``."""
        clip_l_tokens = self.clip_l.tokenize_with_weights(text, return_word_ids)
        t5xxl_tokens = self.t5xxl.tokenize_with_weights(text, return_word_ids)
        return {"l": clip_l_tokens, "t5xxl": t5xxl_tokens}

    def untokenize(self, token_weight_pair):
        """Delegate untokenization to the CLIP-L tokenizer only."""
        return self.clip_l.untokenize(token_weight_pair)

    def state_dict(self):
        """Return an empty dict."""
        return dict()

    def clone(self):
        """Return a shallow copy; the underlying tokenizers are shared, not duplicated."""
        duplicate = copy.copy(self)
        return duplicate
class FluxClipModel(torch.nn.Module):
    """Flux text encoder that combines a CLIP-L model with a T5-XXL model."""

    def __init__(self, dtype_t5=None, device="cpu", dtype=None, model_options=None):
        """Instantiate both encoders.

        Args:
            dtype_t5: Requested T5 dtype; the value actually used is whatever
                ``model_management.pick_weight_dtype(dtype_t5, dtype, device)``
                returns.
            device: Passed to both encoders.
            dtype: Passed to the CLIP-L encoder.
            model_options: Optional dict passed to both encoders. Its
                ``"clip_l_class"`` entry, when present, replaces
                ``sd1_clip.SDClipModel`` as the CLIP-L class.
        """
        super().__init__()
        options = {} if model_options is None else model_options
        t5_dtype = model_management.pick_weight_dtype(dtype_t5, dtype, device)
        clip_l_cls = options.get("clip_l_class", sd1_clip.SDClipModel)
        self.clip_l = clip_l_cls(
            device=device,
            dtype=dtype,
            return_projected_pooled=False,
            model_options=options,
        )
        self.t5xxl = T5XXLModel(device=device, dtype=t5_dtype, model_options=options)
        # Distinct dtypes in use across the two encoders.
        self.dtypes = set((dtype, t5_dtype))

    def set_clip_options(self, options):
        """Apply ``options`` to the CLIP-L encoder, then to the T5-XXL encoder."""
        for encoder in (self.clip_l, self.t5xxl):
            encoder.set_clip_options(options)

    def reset_clip_options(self):
        """Reset options on the CLIP-L encoder, then on the T5-XXL encoder."""
        for encoder in (self.clip_l, self.t5xxl):
            encoder.reset_clip_options()

    def encode_token_weights(self, token_weight_pairs):
        """Encode a prompt tokenized under the ``"l"`` and ``"t5xxl"`` keys.

        Returns:
            A ``(t5_out, l_pooled)`` tuple: the first output of the T5-XXL
            encoder and the second (pooled) output of the CLIP-L encoder.
            The other two outputs are discarded.
        """
        clip_l_pairs = token_weight_pairs["l"]
        t5_pairs = token_weight_pairs["t5xxl"]
        # T5 runs before CLIP-L, as in the original ordering.
        t5_out, _ = self.t5xxl.encode_token_weights(t5_pairs)
        _, l_pooled = self.clip_l.encode_token_weights(clip_l_pairs)
        return t5_out, l_pooled

    def load_sd(self, sd):
        """Load ``sd`` into CLIP-L if it carries the CLIP-L marker key, otherwise into T5-XXL."""
        is_clip_l = "text_model.encoder.layers.1.mlp.fc1.weight" in sd
        target = self.clip_l if is_clip_l else self.t5xxl
        return target.load_sd(sd)
def flux_clip(dtype_t5=None):
    """Return a ``FluxClipModel`` subclass with ``dtype_t5`` pre-bound.

    The returned class takes only ``device``, ``dtype`` and ``model_options``;
    the ``dtype_t5`` given here is captured by closure and supplied to
    ``FluxClipModel.__init__`` on every instantiation.
    """

    class FluxClipModel_(FluxClipModel):
        def __init__(self, device="cpu", dtype=None, model_options=None):
            options = {} if model_options is None else model_options
            super().__init__(
                dtype_t5=dtype_t5,
                device=device,
                dtype=dtype,
                model_options=options,
            )

    return FluxClipModel_