From 4c432c11ed6f83466b8ff02569872925753a3c44 Mon Sep 17 00:00:00 2001 From: woctordho Date: Sat, 20 Dec 2025 13:57:22 +0800 Subject: [PATCH] Implement Jina CLIP v2 and NewBie dual CLIP (#11415) * Implement Jina CLIP v2 * Support quantized Gemma in NewBie dual CLIP --- comfy/model_base.py | 2 +- comfy/model_detection.py | 3 +- comfy/sd.py | 20 +++ comfy/text_encoders/jina_clip_2.py | 219 +++++++++++++++++++++++++++++ comfy/text_encoders/newbie.py | 62 ++++++++ nodes.py | 4 +- 6 files changed, 306 insertions(+), 4 deletions(-) create mode 100644 comfy/text_encoders/jina_clip_2.py create mode 100644 comfy/text_encoders/newbie.py diff --git a/comfy/model_base.py b/comfy/model_base.py index 6b8a8454d..c4f3c0639 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1110,7 +1110,7 @@ class Lumina2(BaseModel): if 'num_tokens' not in out: out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1]) - clip_text_pooled = kwargs["pooled_output"] # Newbie + clip_text_pooled = kwargs.get("pooled_output", None) # NewBie if clip_text_pooled is not None: out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 84fd409fd..539e296ed 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -430,8 +430,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["rope_theta"] = 10000.0 dit_config["ffn_dim_multiplier"] = 4.0 ctd_weight = state_dict.get('{}clip_text_pooled_proj.0.weight'.format(key_prefix), None) - if ctd_weight is not None: + if ctd_weight is not None: # NewBie dit_config["clip_text_dim"] = ctd_weight.shape[0] + # NewBie also sets axes_lens = [1024, 512, 512] but it's not used in ComfyUI elif dit_config["dim"] == 3840: # Z image dit_config["n_heads"] = 30 dit_config["n_kv_heads"] = 30 diff --git a/comfy/sd.py b/comfy/sd.py index c2a9728f3..7de7dd9c6 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -55,6 +55,8 @@ import comfy.text_encoders.hunyuan_image import comfy.text_encoders.z_image import comfy.text_encoders.ovis import comfy.text_encoders.kandinsky5 +import comfy.text_encoders.jina_clip_2 +import comfy.text_encoders.newbie import comfy.model_patcher import comfy.lora @@ -1008,6 +1010,7 @@ class CLIPType(Enum): OVIS = 21 KANDINSKY5 = 22 KANDINSKY5_IMAGE = 23 + NEWBIE = 24 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): @@ -1038,6 +1041,7 @@ class TEModel(Enum): MISTRAL3_24B_PRUNED_FLUX2 = 15 QWEN3_4B = 16 QWEN3_2B = 17 + JINA_CLIP_2 = 18 def detect_te_model(sd): @@ -1047,6 +1051,8 @@ def detect_te_model(sd): return TEModel.CLIP_H if "text_model.encoder.layers.0.mlp.fc1.weight" in sd: return TEModel.CLIP_L + if "model.encoder.layers.0.mixer.Wqkv.weight" in sd: + return TEModel.JINA_CLIP_2 if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd: weight = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"] if weight.shape[-1] == 4096: @@ -1207,6 +1213,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip elif te_model == TEModel.QWEN3_2B: clip_target.clip = comfy.text_encoders.ovis.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.ovis.OvisTokenizer + elif te_model == TEModel.JINA_CLIP_2: + clip_target.clip = comfy.text_encoders.jina_clip_2.JinaClip2TextModelWrapper + clip_target.tokenizer = comfy.text_encoders.jina_clip_2.JinaClip2TokenizerWrapper else: # clip_l if clip_type == CLIPType.SD3: @@ -1262,6 +1271,17 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip elif clip_type == CLIPType.KANDINSKY5_IMAGE: clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage + elif clip_type == CLIPType.NEWBIE: + clip_target.clip = comfy.text_encoders.newbie.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.newbie.NewBieTokenizer + if "model.layers.0.self_attn.q_norm.weight" in clip_data[0]: + clip_data_gemma = clip_data[0] + clip_data_jina = clip_data[1] + else: + clip_data_gemma = clip_data[1] + clip_data_jina = clip_data[0] + tokenizer_data["gemma_spiece_model"] = clip_data_gemma.get("spiece_model", None) + tokenizer_data["jina_spiece_model"] = clip_data_jina.get("spiece_model", None) else: clip_target.clip = sdxl_clip.SDXLClipModel clip_target.tokenizer = sdxl_clip.SDXLTokenizer diff --git a/comfy/text_encoders/jina_clip_2.py b/comfy/text_encoders/jina_clip_2.py new file mode 100644 index 000000000..0cffb6d16 --- /dev/null +++ b/comfy/text_encoders/jina_clip_2.py @@ -0,0 +1,219 @@ +# Jina CLIP v2 and Jina Embeddings v3 both use their modified XLM-RoBERTa architecture. Reference implementation: +# Jina CLIP v2 (both text and vision): https://huggingface.co/jinaai/jina-clip-implementation/blob/39e6a55ae971b59bea6e44675d237c99762e7ee2/modeling_clip.py +# Jina XLM-RoBERTa (text only): http://huggingface.co/jinaai/xlm-roberta-flash-implementation/blob/2b6bc3f30750b3a9648fe9b63448c09920efe9be/modeling_xlm_roberta.py + +from dataclasses import dataclass + +import torch +from torch import nn as nn +from torch.nn import functional as F + +import comfy.model_management +import comfy.ops +from comfy import sd1_clip +from .spiece_tokenizer import SPieceTokenizer + +class JinaClip2Tokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer = tokenizer_data.get("spiece_model", None) + # The official NewBie uses max_length=8000, but Jina Embeddings v3 actually supports 8192 + super().__init__(tokenizer, pad_with_end=False, embedding_size=1024, embedding_key='jina_clip_2', tokenizer_class=SPieceTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=False, max_length=8192, min_length=1, pad_token=1, end_token=2, tokenizer_args={"add_bos": True, "add_eos": True}, tokenizer_data=tokenizer_data) + + def state_dict(self): + return {"spiece_model": self.tokenizer.serialize_model()} + +class JinaClip2TokenizerWrapper(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer=JinaClip2Tokenizer, name="jina_clip_2") + +# https://huggingface.co/jinaai/jina-embeddings-v3/blob/343dbf534c76fe845f304fa5c2d1fd87e1e78918/config.json +@dataclass +class XLMRobertaConfig: + vocab_size: int = 250002 + type_vocab_size: int = 1 + hidden_size: int = 1024 + num_hidden_layers: int = 24 + num_attention_heads: int = 16 + rotary_emb_base: float = 20000.0 + intermediate_size: int = 4096 + hidden_act: str = "gelu" + hidden_dropout_prob: float = 0.1 + attention_probs_dropout_prob: float = 0.1 + layer_norm_eps: float = 1e-05 + bos_token_id: int = 0 + eos_token_id: int = 2 + pad_token_id: int = 1 + +class XLMRobertaEmbeddings(nn.Module): + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + embed_dim = config.hidden_size + self.word_embeddings = ops.Embedding(config.vocab_size, embed_dim, padding_idx=config.pad_token_id, device=device, dtype=dtype) + self.token_type_embeddings = ops.Embedding(config.type_vocab_size, embed_dim, device=device, dtype=dtype) + + def forward(self, input_ids=None, embeddings=None): + if input_ids is not None and embeddings is None: + embeddings = self.word_embeddings(input_ids) + + if embeddings is not None: + token_type_ids = torch.zeros(embeddings.shape[1], device=embeddings.device, dtype=torch.int32) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + embeddings = embeddings + token_type_embeddings + return embeddings + +class RotaryEmbedding(nn.Module): + def __init__(self, dim, base, device=None): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached = None + self._sin_cached = None + + def _update_cos_sin_cache(self, seqlen, device=None, dtype=None): + if seqlen > self._seq_len_cached or self._cos_cached is None or self._cos_cached.device != device or self._cos_cached.dtype != dtype: + self._seq_len_cached = seqlen + t = torch.arange(seqlen, device=device, dtype=torch.float32) + freqs = torch.outer(t, self.inv_freq.to(device=t.device)) + emb = torch.cat((freqs, freqs), dim=-1) + self._cos_cached = emb.cos().to(dtype) + self._sin_cached = emb.sin().to(dtype) + + def forward(self, q, k): + batch, seqlen, heads, head_dim = q.shape + self._update_cos_sin_cache(seqlen, device=q.device, dtype=q.dtype) + + cos = self._cos_cached[:seqlen].view(1, seqlen, 1, head_dim) + sin = self._sin_cached[:seqlen].view(1, seqlen, 1, head_dim) + + def rotate_half(x): + size = x.shape[-1] // 2 + x1, x2 = x[..., :size], x[..., size:] + return torch.cat((-x2, x1), dim=-1) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + +class MHA(nn.Module): + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = embed_dim // config.num_attention_heads + + self.rotary_emb = RotaryEmbedding(self.head_dim, config.rotary_emb_base, device=device) + self.Wqkv = ops.Linear(embed_dim, 3 * embed_dim, device=device, dtype=dtype) + self.out_proj = ops.Linear(embed_dim, embed_dim, device=device, dtype=dtype) + + def forward(self, x, mask=None, optimized_attention=None): + qkv = self.Wqkv(x) + batch_size, seq_len, _ = qkv.shape + qkv = qkv.view(batch_size, seq_len, 3, self.num_heads, self.head_dim) + q, k, v = qkv.unbind(2) + + q, k = self.rotary_emb(q, k) + + # NHD -> HND + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + out = optimized_attention(q, k, v, heads=self.num_heads, mask=mask, skip_reshape=True) + return self.out_proj(out) + +class MLP(nn.Module): + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + self.fc1 = ops.Linear(config.hidden_size, config.intermediate_size, device=device, dtype=dtype) + self.activation = F.gelu + self.fc2 = ops.Linear(config.intermediate_size, config.hidden_size, device=device, dtype=dtype) + + def forward(self, x): + x = self.fc1(x) + x = self.activation(x) + x = self.fc2(x) + return x + +class Block(nn.Module): + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + self.mixer = MHA(config, device=device, dtype=dtype, ops=ops) + self.dropout1 = nn.Dropout(config.hidden_dropout_prob) + self.norm1 = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype) + self.mlp = MLP(config, device=device, dtype=dtype, ops=ops) + self.dropout2 = nn.Dropout(config.hidden_dropout_prob) + self.norm2 = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype) + + def forward(self, hidden_states, mask=None, optimized_attention=None): + mixer_out = self.mixer(hidden_states, mask=mask, optimized_attention=optimized_attention) + hidden_states = self.norm1(self.dropout1(mixer_out) + hidden_states) + mlp_out = self.mlp(hidden_states) + hidden_states = self.norm2(self.dropout2(mlp_out) + hidden_states) + return hidden_states + +class XLMRobertaEncoder(nn.Module): + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + self.layers = nn.ModuleList([Block(config, device=device, dtype=dtype, ops=ops) for _ in range(config.num_hidden_layers)]) + + def forward(self, hidden_states, attention_mask=None): + optimized_attention = comfy.ldm.modules.attention.optimized_attention_for_device(hidden_states.device, mask=attention_mask is not None, small_input=True) + for layer in self.layers: + hidden_states = layer(hidden_states, mask=attention_mask, optimized_attention=optimized_attention) + return hidden_states + +class XLMRobertaModel_(nn.Module): + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + self.embeddings = XLMRobertaEmbeddings(config, device=device, dtype=dtype, ops=ops) + self.emb_ln = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype) + self.emb_drop = nn.Dropout(config.hidden_dropout_prob) + self.encoder = XLMRobertaEncoder(config, device=device, dtype=dtype, ops=ops) + + def forward(self, input_ids, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]): + x = self.embeddings(input_ids=input_ids, embeddings=embeds) + x = self.emb_ln(x) + x = self.emb_drop(x) + + mask = None + if attention_mask is not None: + mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, 1, attention_mask.shape[-1])) + mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max) + + sequence_output = self.encoder(x, attention_mask=mask) + + # Mean pool, see https://huggingface.co/jinaai/jina-clip-implementation/blob/39e6a55ae971b59bea6e44675d237c99762e7ee2/hf_model.py + pooled_output = None + if attention_mask is None: + pooled_output = sequence_output.mean(dim=1) + else: + attention_mask = attention_mask.to(sequence_output.dtype) + pooled_output = (sequence_output * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(dim=-1, keepdim=True) + + # Intermediate output is not yet implemented, use None for placeholder + return sequence_output, None, pooled_output + +class XLMRobertaModel(nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + self.config = XLMRobertaConfig(**config_dict) + self.model = XLMRobertaModel_(self.config, device=device, dtype=dtype, ops=operations) + self.num_layers = self.config.num_hidden_layers + + def get_input_embeddings(self): + return self.model.embeddings.word_embeddings + + def set_input_embeddings(self, embeddings): + self.model.embeddings.word_embeddings = embeddings + + def forward(self, *args, **kwargs): + return self.model(*args, **kwargs) + +class JinaClip2TextModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, textmodel_json_config={}, model_class=XLMRobertaModel, special_tokens={"start": 0, "end": 2, "pad": 1}, enable_attention_masks=True, return_attention_masks=True, model_options=model_options) + +class JinaClip2TextModelWrapper(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, clip_model=JinaClip2TextModel, name="jina_clip_2", model_options=model_options) diff --git a/comfy/text_encoders/newbie.py b/comfy/text_encoders/newbie.py new file mode 100644 index 000000000..31904462b --- /dev/null +++ b/comfy/text_encoders/newbie.py @@ -0,0 +1,62 @@ +import torch + +import comfy.model_management +import comfy.text_encoders.jina_clip_2 +import comfy.text_encoders.lumina2 + +class NewBieTokenizer: + def __init__(self, embedding_directory=None, tokenizer_data={}): + self.gemma = comfy.text_encoders.lumina2.Gemma3_4BTokenizer(embedding_directory=embedding_directory, tokenizer_data={"spiece_model": tokenizer_data["gemma_spiece_model"]}) + self.jina = comfy.text_encoders.jina_clip_2.JinaClip2Tokenizer(embedding_directory=embedding_directory, tokenizer_data={"spiece_model": tokenizer_data["jina_spiece_model"]}) + + def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): + out = {} + out["gemma"] = self.gemma.tokenize_with_weights(text, return_word_ids, **kwargs) + out["jina"] = self.jina.tokenize_with_weights(text, return_word_ids, **kwargs) + return out + + def untokenize(self, token_weight_pair): + raise NotImplementedError + + def state_dict(self): + return {} + +class NewBieTEModel(torch.nn.Module): + def __init__(self, dtype_gemma=None, device="cpu", dtype=None, model_options={}): + super().__init__() + dtype_gemma = comfy.model_management.pick_weight_dtype(dtype_gemma, dtype, device) + self.gemma = comfy.text_encoders.lumina2.Gemma3_4BModel(device=device, dtype=dtype_gemma, model_options=model_options) + self.jina = comfy.text_encoders.jina_clip_2.JinaClip2TextModel(device=device, dtype=dtype, model_options=model_options) + self.dtypes = {dtype, dtype_gemma} + + def set_clip_options(self, options): + self.gemma.set_clip_options(options) + self.jina.set_clip_options(options) + + def reset_clip_options(self): + self.gemma.reset_clip_options() + self.jina.reset_clip_options() + + def encode_token_weights(self, token_weight_pairs): + token_weight_pairs_gemma = token_weight_pairs["gemma"] + token_weight_pairs_jina = token_weight_pairs["jina"] + + gemma_out, gemma_pooled, gemma_extra = self.gemma.encode_token_weights(token_weight_pairs_gemma) + jina_out, jina_pooled, jina_extra = self.jina.encode_token_weights(token_weight_pairs_jina) + + return gemma_out, jina_pooled, gemma_extra + + def load_sd(self, sd): + if "model.layers.0.self_attn.q_norm.weight" in sd: + return self.gemma.load_sd(sd) + else: + return self.jina.load_sd(sd) + +def te(dtype_llama=None, llama_quantization_metadata=None): + class NewBieTEModel_(NewBieTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + super().__init__(dtype_gemma=dtype_llama, device=device, dtype=dtype, model_options=model_options) + return NewBieTEModel_ diff --git a/nodes.py b/nodes.py index b13ceb578..7d83ecb21 100644 --- a/nodes.py +++ b/nodes.py @@ -970,7 +970,7 @@ class DualCLIPLoader: def INPUT_TYPES(s): return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image"], ), + "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "newbie"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), @@ -980,7 +980,7 @@ class DualCLIPLoader: CATEGORY = "advanced/loaders" - DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama\nhunyuan_image: qwen2.5vl 7b and byt5 small" + DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama\nhunyuan_image: qwen2.5vl 7b and byt5 small\nnewbie: gemma-3-4b-it, jina clip v2" def load_clip(self, clip_name1, clip_name2, type, device="default"): clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)