"""Tokenizer and text-encoder wrappers that pair Gemma3-4B with Jina CLIP 2."""

import torch

import comfy.model_management
import comfy.text_encoders.jina_clip_2
import comfy.text_encoders.lumina2


class NewBieTokenizer:
    """Tokenize one prompt with both the Gemma3-4B and Jina CLIP 2 tokenizers."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        """Build the two wrapped tokenizers.

        Args:
            embedding_directory: Forwarded unchanged to both tokenizers.
            tokenizer_data: Mapping that must contain the keys
                ``"gemma_spiece_model"`` and ``"jina_spiece_model"``. Each
                value is re-exposed to its tokenizer under the key
                ``"spiece_model"``.

        Raises:
            KeyError: If either required key is absent from ``tokenizer_data``.
        """
        gemma_data = {"spiece_model": tokenizer_data["gemma_spiece_model"]}
        self.gemma = comfy.text_encoders.lumina2.Gemma3_4BTokenizer(
            embedding_directory=embedding_directory,
            tokenizer_data=gemma_data,
        )

        jina_data = {"spiece_model": tokenizer_data["jina_spiece_model"]}
        self.jina = comfy.text_encoders.jina_clip_2.JinaClip2Tokenizer(
            embedding_directory=embedding_directory,
            tokenizer_data=jina_data,
        )

    def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs):
        """Return ``{"gemma": ..., "jina": ...}`` with each tokenizer's result.

        All arguments are passed through unchanged to both tokenizers; the
        Gemma tokenizer is invoked first, then the Jina one.
        """
        return {
            "gemma": self.gemma.tokenize_with_weights(text, return_word_ids, **kwargs),
            "jina": self.jina.tokenize_with_weights(text, return_word_ids, **kwargs),
        }

    def untokenize(self, token_weight_pair):
        """Not supported for this combined tokenizer.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    def state_dict(self):
        """Return an empty dict; this wrapper reports no state of its own."""
        return {}


class NewBieTEModel(torch.nn.Module):
    """Text encoder that holds a Gemma3-4B model and a Jina CLIP 2 text model."""

    def __init__(self, dtype_gemma=None, device="cpu", dtype=None, model_options={}):
        """Create both sub-models.

        Args:
            dtype_gemma: Requested dtype for the Gemma model; the value actually
                used is whatever ``comfy.model_management.pick_weight_dtype``
                returns for ``(dtype_gemma, dtype, device)``.
            device: Passed to both sub-models.
            dtype: Used for the Jina model, and handed to ``pick_weight_dtype``
                alongside ``dtype_gemma``.
            model_options: Passed unchanged to both sub-models.
        """
        super().__init__()
        gemma_dtype = comfy.model_management.pick_weight_dtype(dtype_gemma, dtype, device)
        self.gemma = comfy.text_encoders.lumina2.Gemma3_4BModel(
            device=device,
            dtype=gemma_dtype,
            model_options=model_options,
        )
        self.jina = comfy.text_encoders.jina_clip_2.JinaClip2TextModel(
            device=device,
            dtype=dtype,
            model_options=model_options,
        )
        # A set, so identical dtypes collapse into a single entry.
        self.dtypes = {dtype, gemma_dtype}

    def set_clip_options(self, options):
        """Forward ``options`` to the Gemma model, then to the Jina model."""
        self.gemma.set_clip_options(options)
        self.jina.set_clip_options(options)

    def reset_clip_options(self):
        """Reset clip options on the Gemma model, then on the Jina model."""
        self.gemma.reset_clip_options()
        self.jina.reset_clip_options()

    def encode_token_weights(self, token_weight_pairs):
        """Encode with both models and return a mixed 3-tuple.

        Args:
            token_weight_pairs: Mapping with ``"gemma"`` and ``"jina"`` entries
                (the layout ``NewBieTokenizer.tokenize_with_weights`` builds).

        Returns:
            ``(first Gemma value, second Jina value, third Gemma value)``.
            Each encoder call yields three values; presumably these are
            (sequence output, pooled output, extra) -- confirm against the
            sub-model implementations. The remaining values are discarded.

        Raises:
            KeyError: If either entry is missing; both lookups happen before
                any encoding starts.
        """
        gemma_pairs = token_weight_pairs["gemma"]
        jina_pairs = token_weight_pairs["jina"]

        gemma_result = self.gemma.encode_token_weights(gemma_pairs)
        jina_result = self.jina.encode_token_weights(jina_pairs)

        # Unpack exactly three values per encoder, as the combined result
        # relies on that arity.
        gemma_out, _, gemma_extra = gemma_result
        _, jina_pooled, _ = jina_result
        return gemma_out, jina_pooled, gemma_extra

    def load_sd(self, sd):
        """Load ``sd`` into whichever sub-model it belongs to.

        A state dict containing ``"model.layers.0.self_attn.q_norm.weight"`` is
        routed to the Gemma model; anything else goes to the Jina model. The
        chosen sub-model's ``load_sd`` result is returned.
        """
        is_gemma_sd = "model.layers.0.self_attn.q_norm.weight" in sd
        target = self.gemma if is_gemma_sd else self.jina
        return target.load_sd(sd)


def te(dtype_llama=None, llama_quantization_metadata=None):
    """Return a ``NewBieTEModel`` subclass with the given settings baked in.

    Args:
        dtype_llama: Passed as ``dtype_gemma`` to ``NewBieTEModel.__init__``.
        llama_quantization_metadata: When not ``None``, stored under the
            ``"llama_quantization_metadata"`` key of a copy of the
            ``model_options`` given at construction time.

    Returns:
        The subclass itself (not an instance).
    """

    class NewBieTEModel_(NewBieTEModel):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            options = model_options
            if llama_quantization_metadata is not None:
                # Work on a copy so the caller's mapping is left untouched.
                options = model_options.copy()
                options["llama_quantization_metadata"] = llama_quantization_metadata
            super().__init__(
                dtype_gemma=dtype_llama,
                device=device,
                dtype=dtype,
                model_options=options,
            )

    return NewBieTEModel_