Implement Jina CLIP v2 and NewBie dual CLIP (#11415)

* Implement Jina CLIP v2
* Support quantized Gemma in NewBie dual CLIP

parent 31e961736a · commit 4c432c11ed
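
A minimal sketch (not part of this commit) of exercising the new path through comfy.sd.load_clip, assuming a standard ComfyUI checkout; the checkpoint filenames are hypothetical placeholders:

    import comfy.sd

    # Order of the two files does not matter: load_text_encoder_state_dicts
    # (below) tells the Gemma and Jina state dicts apart by their weight names.
    clip = comfy.sd.load_clip(
        ckpt_paths=["models/text_encoders/gemma_3_4b_it.safetensors",
                    "models/text_encoders/jina_clip_v2.safetensors"],
        clip_type=comfy.sd.CLIPType.NEWBIE,
    )
    tokens = clip.tokenize("a prompt for a NewBie checkpoint")
    cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)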

comfy/model_base.py

@@ -1110,7 +1110,7 @@ class Lumina2(BaseModel):
        if 'num_tokens' not in out:
            out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1])

-        clip_text_pooled = kwargs["pooled_output"] # Newbie
+        clip_text_pooled = kwargs.get("pooled_output", None) # NewBie
        if clip_text_pooled is not None:
            out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled)

comfy/model_detection.py

@@ -430,8 +430,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
        dit_config["rope_theta"] = 10000.0
        dit_config["ffn_dim_multiplier"] = 4.0
        ctd_weight = state_dict.get('{}clip_text_pooled_proj.0.weight'.format(key_prefix), None)
-        if ctd_weight is not None:
+        if ctd_weight is not None: # NewBie
            dit_config["clip_text_dim"] = ctd_weight.shape[0]
+            # NewBie also sets axes_lens = [1024, 512, 512] but it's not used in ComfyUI
        elif dit_config["dim"] == 3840: # Z image
            dit_config["n_heads"] = 30
            dit_config["n_kv_heads"] = 30

20  comfy/sd.py

@@ -55,6 +55,8 @@ import comfy.text_encoders.hunyuan_image
import comfy.text_encoders.z_image
import comfy.text_encoders.ovis
import comfy.text_encoders.kandinsky5
+import comfy.text_encoders.jina_clip_2
+import comfy.text_encoders.newbie

import comfy.model_patcher
import comfy.lora

@@ -1008,6 +1010,7 @@ class CLIPType(Enum):
    OVIS = 21
    KANDINSKY5 = 22
    KANDINSKY5_IMAGE = 23
+    NEWBIE = 24


def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):

@@ -1038,6 +1041,7 @@ class TEModel(Enum):
    MISTRAL3_24B_PRUNED_FLUX2 = 15
    QWEN3_4B = 16
    QWEN3_2B = 17
+    JINA_CLIP_2 = 18


def detect_te_model(sd):

@@ -1047,6 +1051,8 @@ def detect_te_model(sd):
        return TEModel.CLIP_H
    if "text_model.encoder.layers.0.mlp.fc1.weight" in sd:
        return TEModel.CLIP_L
+    if "model.encoder.layers.0.mixer.Wqkv.weight" in sd:
+        return TEModel.JINA_CLIP_2
    if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd:
        weight = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
        if weight.shape[-1] == 4096:

@@ -1207,6 +1213,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
        elif te_model == TEModel.QWEN3_2B:
            clip_target.clip = comfy.text_encoders.ovis.te(**llama_detect(clip_data))
            clip_target.tokenizer = comfy.text_encoders.ovis.OvisTokenizer
+        elif te_model == TEModel.JINA_CLIP_2:
+            clip_target.clip = comfy.text_encoders.jina_clip_2.JinaClip2TextModelWrapper
+            clip_target.tokenizer = comfy.text_encoders.jina_clip_2.JinaClip2TokenizerWrapper
        else:
            # clip_l
            if clip_type == CLIPType.SD3:

@@ -1262,6 +1271,17 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
        elif clip_type == CLIPType.KANDINSKY5_IMAGE:
            clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
            clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage
+        elif clip_type == CLIPType.NEWBIE:
+            clip_target.clip = comfy.text_encoders.newbie.te(**llama_detect(clip_data))
+            clip_target.tokenizer = comfy.text_encoders.newbie.NewBieTokenizer
+            if "model.layers.0.self_attn.q_norm.weight" in clip_data[0]:
+                clip_data_gemma = clip_data[0]
+                clip_data_jina = clip_data[1]
+            else:
+                clip_data_gemma = clip_data[1]
+                clip_data_jina = clip_data[0]
+            tokenizer_data["gemma_spiece_model"] = clip_data_gemma.get("spiece_model", None)
+            tokenizer_data["jina_spiece_model"] = clip_data_jina.get("spiece_model", None)
        else:
            clip_target.clip = sdxl_clip.SDXLClipModel
            clip_target.tokenizer = sdxl_clip.SDXLTokenizer
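
A standalone sketch (illustrative, not the commit's code) of the two key tests added above: detect_te_model recognizes Jina CLIP v2 by its fused-QKV "Wqkv" weight name, and the CLIPType.NEWBIE branch orders the two state dicts by Gemma's q_norm key:

    def which_te(sd):
        # Only Jina's XLM-RoBERTa uses "mixer.Wqkv" among the supported TEs.
        if "model.encoder.layers.0.mixer.Wqkv.weight" in sd:
            return "JINA_CLIP_2"
        return "other"

    def split_newbie_pair(clip_data):
        # Gemma 3 checkpoints carry per-head q_norm weights; Jina's do not.
        if "model.layers.0.self_attn.q_norm.weight" in clip_data[0]:
            return clip_data[0], clip_data[1]
        return clip_data[1], clip_data[0]

    jina_sd = {"model.encoder.layers.0.mixer.Wqkv.weight": 0}
    gemma_sd = {"model.layers.0.self_attn.q_norm.weight": 0}
    assert which_te(jina_sd) == "JINA_CLIP_2"
    gemma_first, _ = split_newbie_pair([jina_sd, gemma_sd])
    assert "model.layers.0.self_attn.q_norm.weight" in gemma_first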

219  comfy/text_encoders/jina_clip_2.py (new file)

@@ -0,0 +1,219 @@
# Jina CLIP v2 and Jina Embeddings v3 both use their modified XLM-RoBERTa architecture. Reference implementation:
# Jina CLIP v2 (both text and vision): https://huggingface.co/jinaai/jina-clip-implementation/blob/39e6a55ae971b59bea6e44675d237c99762e7ee2/modeling_clip.py
# Jina XLM-RoBERTa (text only): http://huggingface.co/jinaai/xlm-roberta-flash-implementation/blob/2b6bc3f30750b3a9648fe9b63448c09920efe9be/modeling_xlm_roberta.py

from dataclasses import dataclass

import torch
from torch import nn as nn
from torch.nn import functional as F

import comfy.ldm.modules.attention  # used by XLMRobertaEncoder.forward
import comfy.model_management
import comfy.ops
from comfy import sd1_clip
from .spiece_tokenizer import SPieceTokenizer

class JinaClip2Tokenizer(sd1_clip.SDTokenizer):
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        tokenizer = tokenizer_data.get("spiece_model", None)
        # The official NewBie uses max_length=8000, but Jina Embeddings v3 actually supports 8192
        super().__init__(tokenizer, pad_with_end=False, embedding_size=1024, embedding_key='jina_clip_2', tokenizer_class=SPieceTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=False, max_length=8192, min_length=1, pad_token=1, end_token=2, tokenizer_args={"add_bos": True, "add_eos": True}, tokenizer_data=tokenizer_data)

    def state_dict(self):
        return {"spiece_model": self.tokenizer.serialize_model()}

class JinaClip2TokenizerWrapper(sd1_clip.SD1Tokenizer):
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer=JinaClip2Tokenizer, name="jina_clip_2")

# https://huggingface.co/jinaai/jina-embeddings-v3/blob/343dbf534c76fe845f304fa5c2d1fd87e1e78918/config.json
@dataclass
class XLMRobertaConfig:
    vocab_size: int = 250002
    type_vocab_size: int = 1
    hidden_size: int = 1024
    num_hidden_layers: int = 24
    num_attention_heads: int = 16
    rotary_emb_base: float = 20000.0
    intermediate_size: int = 4096
    hidden_act: str = "gelu"
    hidden_dropout_prob: float = 0.1
    attention_probs_dropout_prob: float = 0.1
    layer_norm_eps: float = 1e-05
    bos_token_id: int = 0
    eos_token_id: int = 2
    pad_token_id: int = 1

class XLMRobertaEmbeddings(nn.Module):
    def __init__(self, config, device=None, dtype=None, ops=None):
        super().__init__()
        embed_dim = config.hidden_size
        self.word_embeddings = ops.Embedding(config.vocab_size, embed_dim, padding_idx=config.pad_token_id, device=device, dtype=dtype)
        self.token_type_embeddings = ops.Embedding(config.type_vocab_size, embed_dim, device=device, dtype=dtype)

    def forward(self, input_ids=None, embeddings=None):
        if input_ids is not None and embeddings is None:
            embeddings = self.word_embeddings(input_ids)

        if embeddings is not None:
            token_type_ids = torch.zeros(embeddings.shape[1], device=embeddings.device, dtype=torch.int32)
            token_type_embeddings = self.token_type_embeddings(token_type_ids)
            embeddings = embeddings + token_type_embeddings
        return embeddings

class RotaryEmbedding(nn.Module):
    def __init__(self, dim, base, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        if seqlen > self._seq_len_cached or self._cos_cached is None or self._cos_cached.device != device or self._cos_cached.dtype != dtype:
            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=device, dtype=torch.float32)
            freqs = torch.outer(t, self.inv_freq.to(device=t.device))
            emb = torch.cat((freqs, freqs), dim=-1)
            self._cos_cached = emb.cos().to(dtype)
            self._sin_cached = emb.sin().to(dtype)

    def forward(self, q, k):
        batch, seqlen, heads, head_dim = q.shape
        self._update_cos_sin_cache(seqlen, device=q.device, dtype=q.dtype)

        cos = self._cos_cached[:seqlen].view(1, seqlen, 1, head_dim)
        sin = self._sin_cached[:seqlen].view(1, seqlen, 1, head_dim)

        def rotate_half(x):
            size = x.shape[-1] // 2
            x1, x2 = x[..., :size], x[..., size:]
            return torch.cat((-x2, x1), dim=-1)

        q_embed = (q * cos) + (rotate_half(q) * sin)
        k_embed = (k * cos) + (rotate_half(k) * sin)
        return q_embed, k_embed

class MHA(nn.Module):
    def __init__(self, config, device=None, dtype=None, ops=None):
        super().__init__()
        embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = embed_dim // config.num_attention_heads

        self.rotary_emb = RotaryEmbedding(self.head_dim, config.rotary_emb_base, device=device)
        self.Wqkv = ops.Linear(embed_dim, 3 * embed_dim, device=device, dtype=dtype)
        self.out_proj = ops.Linear(embed_dim, embed_dim, device=device, dtype=dtype)

    def forward(self, x, mask=None, optimized_attention=None):
        qkv = self.Wqkv(x)
        batch_size, seq_len, _ = qkv.shape
        qkv = qkv.view(batch_size, seq_len, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(2)

        q, k = self.rotary_emb(q, k)

        # NHD -> HND
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        out = optimized_attention(q, k, v, heads=self.num_heads, mask=mask, skip_reshape=True)
        return self.out_proj(out)

class MLP(nn.Module):
    def __init__(self, config, device=None, dtype=None, ops=None):
        super().__init__()
        self.fc1 = ops.Linear(config.hidden_size, config.intermediate_size, device=device, dtype=dtype)
        self.activation = F.gelu
        self.fc2 = ops.Linear(config.intermediate_size, config.hidden_size, device=device, dtype=dtype)

    def forward(self, x):
        x = self.fc1(x)
        x = self.activation(x)
        x = self.fc2(x)
        return x

class Block(nn.Module):
    def __init__(self, config, device=None, dtype=None, ops=None):
        super().__init__()
        self.mixer = MHA(config, device=device, dtype=dtype, ops=ops)
        self.dropout1 = nn.Dropout(config.hidden_dropout_prob)
        self.norm1 = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype)
        self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
        self.dropout2 = nn.Dropout(config.hidden_dropout_prob)
        self.norm2 = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype)

    def forward(self, hidden_states, mask=None, optimized_attention=None):
        mixer_out = self.mixer(hidden_states, mask=mask, optimized_attention=optimized_attention)
        hidden_states = self.norm1(self.dropout1(mixer_out) + hidden_states)
        mlp_out = self.mlp(hidden_states)
        hidden_states = self.norm2(self.dropout2(mlp_out) + hidden_states)
        return hidden_states

class XLMRobertaEncoder(nn.Module):
    def __init__(self, config, device=None, dtype=None, ops=None):
        super().__init__()
        self.layers = nn.ModuleList([Block(config, device=device, dtype=dtype, ops=ops) for _ in range(config.num_hidden_layers)])

    def forward(self, hidden_states, attention_mask=None):
        optimized_attention = comfy.ldm.modules.attention.optimized_attention_for_device(hidden_states.device, mask=attention_mask is not None, small_input=True)
        for layer in self.layers:
            hidden_states = layer(hidden_states, mask=attention_mask, optimized_attention=optimized_attention)
        return hidden_states

class XLMRobertaModel_(nn.Module):
    def __init__(self, config, device=None, dtype=None, ops=None):
        super().__init__()
        self.embeddings = XLMRobertaEmbeddings(config, device=device, dtype=dtype, ops=ops)
        self.emb_ln = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype)
        self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
        self.encoder = XLMRobertaEncoder(config, device=device, dtype=dtype, ops=ops)

    def forward(self, input_ids, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]):
        x = self.embeddings(input_ids=input_ids, embeddings=embeds)
        x = self.emb_ln(x)
        x = self.emb_drop(x)

        mask = None
        if attention_mask is not None:
            mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, 1, attention_mask.shape[-1]))
            mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)

        sequence_output = self.encoder(x, attention_mask=mask)

        # Mean pool, see https://huggingface.co/jinaai/jina-clip-implementation/blob/39e6a55ae971b59bea6e44675d237c99762e7ee2/hf_model.py
        pooled_output = None
        if attention_mask is None:
            pooled_output = sequence_output.mean(dim=1)
        else:
            attention_mask = attention_mask.to(sequence_output.dtype)
            pooled_output = (sequence_output * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(dim=-1, keepdim=True)

        # Intermediate output is not yet implemented, use None for placeholder
        return sequence_output, None, pooled_output

class XLMRobertaModel(nn.Module):
    def __init__(self, config_dict, dtype, device, operations):
        super().__init__()
        self.config = XLMRobertaConfig(**config_dict)
        self.model = XLMRobertaModel_(self.config, device=device, dtype=dtype, ops=operations)
        self.num_layers = self.config.num_hidden_layers

    def get_input_embeddings(self):
        return self.model.embeddings.word_embeddings

    def set_input_embeddings(self, embeddings):
        self.model.embeddings.word_embeddings = embeddings

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)

class JinaClip2TextModel(sd1_clip.SDClipModel):
    def __init__(self, device="cpu", dtype=None, model_options={}):
        super().__init__(device=device, dtype=dtype, textmodel_json_config={}, model_class=XLMRobertaModel, special_tokens={"start": 0, "end": 2, "pad": 1}, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)

class JinaClip2TextModelWrapper(sd1_clip.SD1ClipModel):
    def __init__(self, device="cpu", dtype=None, model_options={}):
        super().__init__(device=device, dtype=dtype, clip_model=JinaClip2TextModel, name="jina_clip_2", model_options=model_options)
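
A standalone check (illustrative shapes) of the masked mean pooling used for pooled_output above — padded positions are excluded from the average:

    import torch

    seq = torch.arange(12, dtype=torch.float32).reshape(1, 4, 3)  # (batch, seq, dim)
    mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])                   # last token is padding
    pooled = (seq * mask.unsqueeze(-1)).sum(dim=1) / mask.sum(dim=-1, keepdim=True)
    assert torch.allclose(pooled, seq[:, :3].mean(dim=1))  # equals the unpadded mean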

62  comfy/text_encoders/newbie.py (new file)

@@ -0,0 +1,62 @@
import torch

import comfy.model_management
import comfy.text_encoders.jina_clip_2
import comfy.text_encoders.lumina2

class NewBieTokenizer:
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        self.gemma = comfy.text_encoders.lumina2.Gemma3_4BTokenizer(embedding_directory=embedding_directory, tokenizer_data={"spiece_model": tokenizer_data["gemma_spiece_model"]})
        self.jina = comfy.text_encoders.jina_clip_2.JinaClip2Tokenizer(embedding_directory=embedding_directory, tokenizer_data={"spiece_model": tokenizer_data["jina_spiece_model"]})

    def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs):
        out = {}
        out["gemma"] = self.gemma.tokenize_with_weights(text, return_word_ids, **kwargs)
        out["jina"] = self.jina.tokenize_with_weights(text, return_word_ids, **kwargs)
        return out

    def untokenize(self, token_weight_pair):
        raise NotImplementedError

    def state_dict(self):
        return {}

class NewBieTEModel(torch.nn.Module):
    def __init__(self, dtype_gemma=None, device="cpu", dtype=None, model_options={}):
        super().__init__()
        dtype_gemma = comfy.model_management.pick_weight_dtype(dtype_gemma, dtype, device)
        self.gemma = comfy.text_encoders.lumina2.Gemma3_4BModel(device=device, dtype=dtype_gemma, model_options=model_options)
        self.jina = comfy.text_encoders.jina_clip_2.JinaClip2TextModel(device=device, dtype=dtype, model_options=model_options)
        self.dtypes = {dtype, dtype_gemma}

    def set_clip_options(self, options):
        self.gemma.set_clip_options(options)
        self.jina.set_clip_options(options)

    def reset_clip_options(self):
        self.gemma.reset_clip_options()
        self.jina.reset_clip_options()

    def encode_token_weights(self, token_weight_pairs):
        token_weight_pairs_gemma = token_weight_pairs["gemma"]
        token_weight_pairs_jina = token_weight_pairs["jina"]

        gemma_out, gemma_pooled, gemma_extra = self.gemma.encode_token_weights(token_weight_pairs_gemma)
        jina_out, jina_pooled, jina_extra = self.jina.encode_token_weights(token_weight_pairs_jina)

        return gemma_out, jina_pooled, gemma_extra

    def load_sd(self, sd):
        if "model.layers.0.self_attn.q_norm.weight" in sd:
            return self.gemma.load_sd(sd)
        else:
            return self.jina.load_sd(sd)

def te(dtype_llama=None, llama_quantization_metadata=None):
    class NewBieTEModel_(NewBieTEModel):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            if llama_quantization_metadata is not None:
                model_options = model_options.copy()
                model_options["quantization_metadata"] = llama_quantization_metadata
            super().__init__(dtype_gemma=dtype_llama, device=device, dtype=dtype, model_options=model_options)
    return NewBieTEModel_
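
The design point worth noting above: encode_token_weights returns Gemma's sequence features for cross-attention but Jina's pooled vector, which Lumina2.extra_conds picks up as "clip_text_pooled". A toy sketch with mock encoders (the 2560/1024 widths are illustrative assumptions):

    import torch

    class MockTE:
        def __init__(self, dim):
            self.dim = dim
        def encode_token_weights(self, pairs):
            seq = torch.zeros(1, 4, self.dim)      # (batch, tokens, width)
            return seq, seq.mean(dim=1), {}

    gemma, jina = MockTE(2560), MockTE(1024)
    gemma_out, _, gemma_extra = gemma.encode_token_weights(None)
    _, jina_pooled, _ = jina.encode_token_weights(None)
    # Same combination as NewBieTEModel.encode_token_weights:
    cond, pooled, extra = gemma_out, jina_pooled, gemma_extra
    assert cond.shape == (1, 4, 2560) and pooled.shape == (1, 1024)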

4  nodes.py

@@ -970,7 +970,7 @@ class DualCLIPLoader:
    def INPUT_TYPES(s):
        return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
                              "clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
-                              "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image"], ),
+                              "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "newbie"], ),
                              },
                "optional": {
                      "device": (["default", "cpu"], {"advanced": True}),

@@ -980,7 +980,7 @@ class DualCLIPLoader:

    CATEGORY = "advanced/loaders"

-    DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama\nhunyuan_image: qwen2.5vl 7b and byt5 small"
+    DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama\nhunyuan_image: qwen2.5vl 7b and byt5 small\nnewbie: gemma-3-4b-it, jina clip v2"

    def load_clip(self, clip_name1, clip_name2, type, device="default"):
        clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
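
For completeness, a hedged sketch of invoking the updated node outside the graph executor (filenames hypothetical; the files must exist under the text_encoders model folder):

    from nodes import DualCLIPLoader

    loader = DualCLIPLoader()
    (clip,) = loader.load_clip("gemma_3_4b_it.safetensors", "jina_clip_v2.safetensors", "newbie")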