# Mirror of https://github.com/comfyanonymous/ComfyUI.git
# Synced 2025-12-22 20:40:49 +08:00 · 63 lines · 2.8 KiB · Python
import torch
|
|
|
|
import comfy.model_management
|
|
import comfy.text_encoders.jina_clip_2
|
|
import comfy.text_encoders.lumina2
|
|
|
|
class NewBieTokenizer:
    """Tokenizer that runs a Gemma 3 (4B) tokenizer and a Jina CLIP v2 tokenizer side by side.

    Every prompt is tokenized by both sub-tokenizers. The results are returned
    together under the ``"gemma"`` and ``"jina"`` keys, which are the keys
    ``NewBieTEModel.encode_token_weights`` reads.
    """

    def __init__(self, embedding_directory=None, tokenizer_data=None):
        """Build both sub-tokenizers.

        Args:
            embedding_directory: Forwarded unchanged to both sub-tokenizers.
            tokenizer_data: Mapping that must contain ``"gemma_spiece_model"``
                and ``"jina_spiece_model"``. Each value is handed to the
                matching sub-tokenizer as its ``"spiece_model"`` entry
                (presumably serialized SentencePiece model data; confirm
                against the sub-tokenizers). ``None`` is treated as an empty
                mapping.

        Raises:
            KeyError: If either spiece model entry is missing.
        """
        # Use None instead of a shared mutable ``{}`` default.
        if tokenizer_data is None:
            tokenizer_data = {}
        # Resolve both entries before constructing anything, so a missing key
        # fails fast and no sub-tokenizer is built in vain.
        gemma_spiece = tokenizer_data["gemma_spiece_model"]
        jina_spiece = tokenizer_data["jina_spiece_model"]
        self.gemma = comfy.text_encoders.lumina2.Gemma3_4BTokenizer(
            embedding_directory=embedding_directory,
            tokenizer_data={"spiece_model": gemma_spiece},
        )
        self.jina = comfy.text_encoders.jina_clip_2.JinaClip2Tokenizer(
            embedding_directory=embedding_directory,
            tokenizer_data={"spiece_model": jina_spiece},
        )

    def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs):
        """Tokenize ``text`` with both sub-tokenizers.

        ``return_word_ids`` and any extra keyword arguments are forwarded
        unchanged to each sub-tokenizer.

        Returns:
            dict with keys ``"gemma"`` and ``"jina"``. Each value is the
            corresponding sub-tokenizer's ``tokenize_with_weights`` output.
        """
        # Gemma is tokenized first, then Jina, matching the key order.
        return {
            "gemma": self.gemma.tokenize_with_weights(text, return_word_ids, **kwargs),
            "jina": self.jina.tokenize_with_weights(text, return_word_ids, **kwargs),
        }

    def untokenize(self, token_weight_pair):
        """Not supported for this combined tokenizer.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("NewBieTokenizer does not support untokenize")

    def state_dict(self):
        """Return this tokenizer's serializable state, which is always empty."""
        # NOTE(review): the sub-tokenizers' spiece model data is not exported
        # here. Confirm that no caller expects it to be round-tripped.
        return {}
|
|
|
|
class NewBieTEModel(torch.nn.Module):
    """Text encoder pairing a Gemma 3 (4B) model with a Jina CLIP v2 text model.

    ``encode_token_weights`` returns Gemma's main output and extra dict,
    combined with the pooled output of the Jina model.
    """

    def __init__(self, dtype_gemma=None, device="cpu", dtype=None, model_options=None):
        """Create both sub-models.

        Args:
            dtype_gemma: Requested Gemma weight dtype. It is resolved with
                ``comfy.model_management.pick_weight_dtype(dtype_gemma, dtype,
                device)``; ``dtype`` presumably acts as the fallback (confirm in
                ``pick_weight_dtype``).
            device: Forwarded to both sub-models.
            dtype: Weight dtype for the Jina model.
            model_options: Options dict forwarded to both sub-models. ``None``
                means no options.
        """
        super().__init__()
        # Use None instead of a shared mutable ``{}`` default; this dict is
        # handed to both sub-models.
        if model_options is None:
            model_options = {}
        dtype_gemma = comfy.model_management.pick_weight_dtype(dtype_gemma, dtype, device)
        self.gemma = comfy.text_encoders.lumina2.Gemma3_4BModel(device=device, dtype=dtype_gemma, model_options=model_options)
        self.jina = comfy.text_encoders.jina_clip_2.JinaClip2TextModel(device=device, dtype=dtype, model_options=model_options)
        # Weight dtypes used by the two sub-models (``dtype`` may be None).
        self.dtypes = {dtype, dtype_gemma}

    def set_clip_options(self, options):
        """Forward ``options`` to both sub-models."""
        self.gemma.set_clip_options(options)
        self.jina.set_clip_options(options)

    def reset_clip_options(self):
        """Reset the clip options of both sub-models."""
        self.gemma.reset_clip_options()
        self.jina.reset_clip_options()

    def encode_token_weights(self, token_weight_pairs):
        """Encode the paired token data produced by ``NewBieTokenizer``.

        Args:
            token_weight_pairs: dict with ``"gemma"`` and ``"jina"`` entries,
                each passed to the matching sub-model.

        Returns:
            ``(gemma_out, jina_pooled, gemma_extra)``.

        Raises:
            KeyError: If either entry is missing. Both keys are read before
                any encoding runs.
        """
        gemma_pairs = token_weight_pairs["gemma"]
        jina_pairs = token_weight_pairs["jina"]

        gemma_out, _gemma_pooled, gemma_extra = self.gemma.encode_token_weights(gemma_pairs)
        # Only Jina's pooled output is used. Its main output and extra dict are discarded.
        _jina_out, jina_pooled, _jina_extra = self.jina.encode_token_weights(jina_pairs)

        return gemma_out, jina_pooled, gemma_extra

    def load_sd(self, sd):
        """Load ``sd`` into whichever sub-model it belongs to.

        A state dict containing ``"model.layers.0.self_attn.q_norm.weight"`` is
        treated as Gemma weights; any other state dict goes to the Jina model.

        Returns:
            Whatever the chosen sub-model's ``load_sd`` returns.
        """
        if "model.layers.0.self_attn.q_norm.weight" in sd:
            return self.gemma.load_sd(sd)
        return self.jina.load_sd(sd)
|
|
|
|
def te(dtype_llama=None, llama_quantization_metadata=None):
    """Return a ``NewBieTEModel`` subclass with the Gemma settings pre-bound.

    The parameter names say ``llama``, but both values configure the Gemma
    sub-model of ``NewBieTEModel``.

    Args:
        dtype_llama: Passed to ``NewBieTEModel`` as ``dtype_gemma``.
        llama_quantization_metadata: When not ``None``, stored under
            ``"quantization_metadata"`` in a copy of each instance's
            ``model_options``.

    Returns:
        A class whose constructor takes ``(device="cpu", dtype=None,
        model_options=None)``.
    """
    class NewBieTEModel_(NewBieTEModel):
        def __init__(self, device="cpu", dtype=None, model_options=None):
            # Use None instead of a shared mutable ``{}`` default.
            if model_options is None:
                model_options = {}
            if llama_quantization_metadata is not None:
                # Copy first so the caller's dict is never mutated.
                model_options = model_options.copy()
                model_options["quantization_metadata"] = llama_quantization_metadata
            super().__init__(dtype_gemma=dtype_llama, device=device, dtype=dtype, model_options=model_options)

    return NewBieTEModel_
|