# ComfyUI/comfy/text_encoders/newbie.py
# doctorpangloss 71698f4099 tests pass
# 2025-12-26 15:20:45 -08:00
#
# 73 lines
# 3.0 KiB
# Python

import torch
from ..model_management import pick_weight_dtype
from .jina_clip_2 import JinaClip2TextModel, JinaClip2Tokenizer
from .lumina2 import Gemma3_4BTokenizer, Gemma3_4BModel
class NewBieTokenizer:
    """Tokenize a prompt with both the Gemma3 4B and Jina CLIP 2 tokenizers.

    The two results are returned side by side in a dict keyed ``"gemma"`` and
    ``"jina"``.
    """

    def __init__(self, embedding_directory=None, tokenizer_data=None):
        """Create the two underlying tokenizers.

        Args:
            embedding_directory: Forwarded unchanged to both tokenizers.
            tokenizer_data: Mapping that must provide the keys
                ``"gemma_spiece_model"`` and ``"jina_spiece_model"``. Each
                value is handed to its tokenizer under the key
                ``"spiece_model"``.

        Raises:
            KeyError: If either key is missing. This includes the case where
                ``tokenizer_data`` is left as ``None``, because it is then
                replaced by an empty dict before the lookups.
        """
        data = {} if tokenizer_data is None else tokenizer_data

        # Each lookup happens immediately before its tokenizer is built, so a
        # missing Jina entry only surfaces after the Gemma tokenizer exists.
        gemma_data = {"spiece_model": data["gemma_spiece_model"]}
        self.gemma = Gemma3_4BTokenizer(
            embedding_directory=embedding_directory,
            tokenizer_data=gemma_data,
        )

        jina_data = {"spiece_model": data["jina_spiece_model"]}
        self.jina = JinaClip2Tokenizer(
            embedding_directory=embedding_directory,
            tokenizer_data=jina_data,
        )

    def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs):
        """Run ``text`` through both tokenizers.

        Args:
            text: Prompt to tokenize.
            return_word_ids: Passed positionally to both tokenizers.
            **kwargs: Passed through to both tokenizers.

        Returns:
            A dict with the Gemma result under ``"gemma"`` and the Jina result
            under ``"jina"``.
        """
        return {
            "gemma": self.gemma.tokenize_with_weights(text, return_word_ids, **kwargs),
            "jina": self.jina.tokenize_with_weights(text, return_word_ids, **kwargs),
        }

    def untokenize(self, token_weight_pair):
        """Not supported for this combined tokenizer; always raises."""
        raise NotImplementedError

    def state_dict(self):
        """Return an empty dict; this wrapper exposes no state of its own."""
        return dict()
class NewBieTEModel(torch.nn.Module):
    """Container module holding a Gemma3 4B model and a Jina CLIP 2 text model.

    Option changes are fanned out to both sub-models. Encoding combines values
    returned by each, and state dicts are routed to one of the two depending
    on their keys.
    """

    def __init__(self, dtype_gemma=None, device="cpu", dtype=None, model_options=None):
        """Build both sub-models.

        Args:
            dtype_gemma: Requested dtype for the Gemma model. The dtype that
                is actually used is whatever ``pick_weight_dtype`` returns for
                ``(dtype_gemma, dtype, device)``.
            device: Device passed to both sub-models.
            dtype: Dtype passed to the Jina model.
            model_options: Options passed to both sub-models; ``None`` is
                replaced by an empty dict.
        """
        super().__init__()
        options = {} if model_options is None else model_options
        gemma_dtype = pick_weight_dtype(dtype_gemma, dtype, device)

        self.gemma = Gemma3_4BModel(device=device, dtype=gemma_dtype, model_options=options)
        self.jina = JinaClip2TextModel(device=device, dtype=dtype, model_options=options)
        # Set of the dtypes in use (collapses to one entry when they match).
        self.dtypes = {dtype, gemma_dtype}

    def set_clip_options(self, options):
        """Forward ``options`` to the Gemma model, then to the Jina model."""
        for encoder_name in ("gemma", "jina"):
            getattr(self, encoder_name).set_clip_options(options)

    def reset_clip_options(self):
        """Reset clip options on the Gemma model, then on the Jina model."""
        for encoder_name in ("gemma", "jina"):
            getattr(self, encoder_name).reset_clip_options()

    def encode_token_weights(self, token_weight_pairs):
        """Encode with both sub-models and merge their results.

        Args:
            token_weight_pairs: Mapping with the entries ``"gemma"`` and
                ``"jina"``, each passed to the matching sub-model.

        Returns:
            A 3-tuple made of the first and third values returned by the
            Gemma model and, in the middle, the second value returned by the
            Jina model (its pooled output, going by the original local
            naming). The remaining values are discarded.
        """
        # Both entries are looked up before any encoding starts.
        gemma_pairs, jina_pairs = token_weight_pairs["gemma"], token_weight_pairs["jina"]

        gemma_out, _, gemma_extra = self.gemma.encode_token_weights(gemma_pairs)
        _, jina_pooled, _ = self.jina.encode_token_weights(jina_pairs)
        return gemma_out, jina_pooled, gemma_extra

    def load_sd(self, sd):
        """Load ``sd`` into whichever sub-model it belongs to.

        A state dict containing ``"model.layers.0.self_attn.q_norm.weight"``
        is given to the Gemma model; any other state dict goes to the Jina
        model. The chosen sub-model's ``load_sd`` result is returned.
        """
        is_gemma_sd = "model.layers.0.self_attn.q_norm.weight" in sd
        target = self.gemma if is_gemma_sd else self.jina
        return target.load_sd(sd)
def te(dtype_llama=None, llama_quantization_metadata=None):
    """Return a ``NewBieTEModel`` subclass with Gemma settings pre-bound.

    Args:
        dtype_llama: Passed to ``NewBieTEModel`` as ``dtype_gemma``.
        llama_quantization_metadata: When not ``None``, stored in the model
            options under ``"llama_quantization_metadata"``.

    Returns:
        A class whose constructor takes only ``device``, ``dtype`` and
        ``model_options``.
    """

    class NewBieTEModel_(NewBieTEModel):
        def __init__(self, device="cpu", dtype=None, model_options=None):
            base_options = {} if model_options is None else model_options
            if llama_quantization_metadata is None:
                options = base_options
            else:
                # Work on a copy so the caller's dict is never mutated.
                options = base_options.copy()
                options["llama_quantization_metadata"] = llama_quantization_metadata
            super().__init__(
                dtype_gemma=dtype_llama,
                device=device,
                dtype=dtype,
                model_options=options,
            )

    return NewBieTEModel_