# Jina CLIP v2 and Jina Embeddings v3 both use Jina's modified XLM-RoBERTa architecture. Reference implementations:
# Jina CLIP v2 (both text and vision): https://huggingface.co/jinaai/jina-clip-implementation/blob/39e6a55ae971b59bea6e44675d237c99762e7ee2/modeling_clip.py
# Jina XLM-RoBERTa (text only): https://huggingface.co/jinaai/xlm-roberta-flash-implementation/blob/2b6bc3f30750b3a9648fe9b63448c09920efe9be/modeling_xlm_roberta.py

from dataclasses import dataclass

import torch
from torch import nn as nn
from torch.nn import functional as F

import comfy.ldm.modules.attention
import comfy.model_management
import comfy.ops
from comfy import sd1_clip
from .spiece_tokenizer import SPieceTokenizer


class JinaClip2Tokenizer(sd1_clip.SDTokenizer):
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        tokenizer = tokenizer_data.get("spiece_model", None)
        # The official NewBie uses max_length=8000, but Jina Embeddings v3 actually supports 8192
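        # Special token ids follow the XLM-RoBERTa SentencePiece vocabulary: <s> = 0 (start), <pad> = 1, </s> = 2.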
        super().__init__(tokenizer, pad_with_end=False, embedding_size=1024, embedding_key='jina_clip_2', tokenizer_class=SPieceTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=False, max_length=8192, min_length=1, pad_token=1, end_token=2, tokenizer_args={"add_bos": True, "add_eos": True}, tokenizer_data=tokenizer_data)

    def state_dict(self):
        return {"spiece_model": self.tokenizer.serialize_model()}


class JinaClip2TokenizerWrapper(sd1_clip.SD1Tokenizer):
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer=JinaClip2Tokenizer, name="jina_clip_2")


# https://huggingface.co/jinaai/jina-embeddings-v3/blob/343dbf534c76fe845f304fa5c2d1fd87e1e78918/config.json
@dataclass
class XLMRobertaConfig:
    vocab_size: int = 250002
    type_vocab_size: int = 1
    hidden_size: int = 1024
    num_hidden_layers: int = 24
    num_attention_heads: int = 16
    rotary_emb_base: float = 20000.0
    intermediate_size: int = 4096
    hidden_act: str = "gelu"
    hidden_dropout_prob: float = 0.1
    attention_probs_dropout_prob: float = 0.1
    layer_norm_eps: float = 1e-05
    bos_token_id: int = 0
    eos_token_id: int = 2
    pad_token_id: int = 1
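

# Note: unlike stock XLM-RoBERTa, this variant has no learned absolute position embeddings; token positions
# are injected inside each attention layer through rotary embeddings (see RotaryEmbedding below).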


class XLMRobertaEmbeddings(nn.Module):
    def __init__(self, config, device=None, dtype=None, ops=None):
        super().__init__()
        embed_dim = config.hidden_size
        self.word_embeddings = ops.Embedding(config.vocab_size, embed_dim, padding_idx=config.pad_token_id, device=device, dtype=dtype)
        self.token_type_embeddings = ops.Embedding(config.type_vocab_size, embed_dim, device=device, dtype=dtype)

    def forward(self, input_ids=None, embeddings=None):
        if input_ids is not None and embeddings is None:
            embeddings = self.word_embeddings(input_ids)

        if embeddings is not None:
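            # type_vocab_size is 1, so the token type ids are always zero; this adds a single learned
            # bias vector to every position.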
            token_type_ids = torch.zeros(embeddings.shape[1], device=embeddings.device, dtype=torch.int32)
            token_type_embeddings = self.token_type_embeddings(token_type_ids)
            embeddings = embeddings + token_type_embeddings
        return embeddings


class RotaryEmbedding(nn.Module):
    def __init__(self, dim, base, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        if seqlen > self._seq_len_cached or self._cos_cached is None or self._cos_cached.device != device or self._cos_cached.dtype != dtype:
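            # Rebuild the cached cos/sin tables only when a longer sequence is requested or the
            # device/dtype of the incoming tensors changes.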
            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=device, dtype=torch.float32)
            freqs = torch.outer(t, self.inv_freq.to(device=t.device))
            emb = torch.cat((freqs, freqs), dim=-1)
            self._cos_cached = emb.cos().to(dtype)
            self._sin_cached = emb.sin().to(dtype)

    def forward(self, q, k):
        batch, seqlen, heads, head_dim = q.shape
        self._update_cos_sin_cache(seqlen, device=q.device, dtype=q.dtype)

        cos = self._cos_cached[:seqlen].view(1, seqlen, 1, head_dim)
        sin = self._sin_cached[:seqlen].view(1, seqlen, 1, head_dim)

        def rotate_half(x):
            size = x.shape[-1] // 2
            x1, x2 = x[..., :size], x[..., size:]
            return torch.cat((-x2, x1), dim=-1)
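
        # Apply rotary position embedding in (batch, seqlen, heads, head_dim) layout:
        # q_embed = q * cos + rotate_half(q) * sin, with cos/sin broadcast over the batch and head dimensions.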
        q_embed = (q * cos) + (rotate_half(q) * sin)
        k_embed = (k * cos) + (rotate_half(k) * sin)
        return q_embed, k_embed


class MHA(nn.Module):
    def __init__(self, config, device=None, dtype=None, ops=None):
        super().__init__()
        embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = embed_dim // config.num_attention_heads

        self.rotary_emb = RotaryEmbedding(self.head_dim, config.rotary_emb_base, device=device)
        self.Wqkv = ops.Linear(embed_dim, 3 * embed_dim, device=device, dtype=dtype)
        self.out_proj = ops.Linear(embed_dim, embed_dim, device=device, dtype=dtype)

    def forward(self, x, mask=None, optimized_attention=None):
        qkv = self.Wqkv(x)
        batch_size, seq_len, _ = qkv.shape
        qkv = qkv.view(batch_size, seq_len, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(2)

        q, k = self.rotary_emb(q, k)

        # NHD -> HND
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
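
        # q, k, v are now (batch, num_heads, seq_len, head_dim); skip_reshape=True tells the attention
        # helper that the heads are already split, and it returns (batch, seq_len, num_heads * head_dim).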
        out = optimized_attention(q, k, v, heads=self.num_heads, mask=mask, skip_reshape=True)
        return self.out_proj(out)


class MLP(nn.Module):
    def __init__(self, config, device=None, dtype=None, ops=None):
        super().__init__()
        self.fc1 = ops.Linear(config.hidden_size, config.intermediate_size, device=device, dtype=dtype)
        self.activation = F.gelu
        self.fc2 = ops.Linear(config.intermediate_size, config.hidden_size, device=device, dtype=dtype)

    def forward(self, x):
        x = self.fc1(x)
        x = self.activation(x)
        x = self.fc2(x)
        return x


class Block(nn.Module):
    def __init__(self, config, device=None, dtype=None, ops=None):
        super().__init__()
        self.mixer = MHA(config, device=device, dtype=dtype, ops=ops)
        self.dropout1 = nn.Dropout(config.hidden_dropout_prob)
        self.norm1 = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype)
        self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
        self.dropout2 = nn.Dropout(config.hidden_dropout_prob)
        self.norm2 = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype)

    def forward(self, hidden_states, mask=None, optimized_attention=None):
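        # Post-LayerNorm residual layout, as in BERT/XLM-RoBERTa: LayerNorm is applied after each
        # residual addition rather than before the sublayer.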
        mixer_out = self.mixer(hidden_states, mask=mask, optimized_attention=optimized_attention)
        hidden_states = self.norm1(self.dropout1(mixer_out) + hidden_states)
        mlp_out = self.mlp(hidden_states)
        hidden_states = self.norm2(self.dropout2(mlp_out) + hidden_states)
        return hidden_states


class XLMRobertaEncoder(nn.Module):
    def __init__(self, config, device=None, dtype=None, ops=None):
        super().__init__()
        self.layers = nn.ModuleList([Block(config, device=device, dtype=dtype, ops=ops) for _ in range(config.num_hidden_layers)])

    def forward(self, hidden_states, attention_mask=None):
        optimized_attention = comfy.ldm.modules.attention.optimized_attention_for_device(hidden_states.device, mask=attention_mask is not None, small_input=True)
        for layer in self.layers:
            hidden_states = layer(hidden_states, mask=attention_mask, optimized_attention=optimized_attention)
        return hidden_states


class XLMRobertaModel_(nn.Module):
    def __init__(self, config, device=None, dtype=None, ops=None):
        super().__init__()
        self.embeddings = XLMRobertaEmbeddings(config, device=device, dtype=dtype, ops=ops)
        self.emb_ln = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype)
        self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
        self.encoder = XLMRobertaEncoder(config, device=device, dtype=dtype, ops=ops)

    def forward(self, input_ids, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]):
        x = self.embeddings(input_ids=input_ids, embeddings=embeds)
        x = self.emb_ln(x)
        x = self.emb_drop(x)

        mask = None
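        # Convert the 0/1 attention_mask into an additive bias of shape (batch, 1, 1, seq_len):
        # kept positions add 0, padded positions add -finfo(dtype).max before the attention softmax.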
        if attention_mask is not None:
            mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, 1, attention_mask.shape[-1]))
            mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)

        sequence_output = self.encoder(x, attention_mask=mask)

        # Mean pool, see https://huggingface.co/jinaai/jina-clip-implementation/blob/39e6a55ae971b59bea6e44675d237c99762e7ee2/hf_model.py
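        # pooled = sum_t(h_t * m_t) / sum_t(m_t): the average of the hidden states over non-padded tokens.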
        pooled_output = None
        if attention_mask is None:
            pooled_output = sequence_output.mean(dim=1)
        else:
            attention_mask = attention_mask.to(sequence_output.dtype)
            pooled_output = (sequence_output * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(dim=-1, keepdim=True)

        # Intermediate output is not implemented yet, so None is returned as a placeholder
        return sequence_output, None, pooled_output


class XLMRobertaModel(nn.Module):
    def __init__(self, config_dict, dtype, device, operations):
        super().__init__()
        self.config = XLMRobertaConfig(**config_dict)
        self.model = XLMRobertaModel_(self.config, device=device, dtype=dtype, ops=operations)
        self.num_layers = self.config.num_hidden_layers

    def get_input_embeddings(self):
        return self.model.embeddings.word_embeddings

    def set_input_embeddings(self, embeddings):
        self.model.embeddings.word_embeddings = embeddings

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)


class JinaClip2TextModel(sd1_clip.SDClipModel):
    def __init__(self, device="cpu", dtype=None, model_options={}):
        super().__init__(device=device, dtype=dtype, textmodel_json_config={}, model_class=XLMRobertaModel, special_tokens={"start": 0, "end": 2, "pad": 1}, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)


class JinaClip2TextModelWrapper(sd1_clip.SD1ClipModel):
    def __init__(self, device="cpu", dtype=None, model_options={}):
        super().__init__(device=device, dtype=dtype, clip_model=JinaClip2TextModel, name="jina_clip_2", model_options=model_options)
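

# Minimal usage sketch (illustrative only; assumes the default jina-embeddings-v3 config above,
# ComfyUI's comfy.ops.manual_cast operations, and placeholder token ids):
#
#   model = XLMRobertaModel({}, dtype=torch.float32, device="cpu", operations=comfy.ops.manual_cast)
#   tokens = torch.tensor([[0, 7, 8, 2]])          # <s> ... </s>
#   attention_mask = torch.ones_like(tokens)
#   sequence_output, _, pooled_output = model(tokens, attention_mask=attention_mask)
#   # sequence_output: (1, 4, 1024), pooled_output: (1, 1024)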