mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-25 09:19:46 +08:00
Add future-proof image handling and template stripping options to ideogram4 text encoder.
This commit is contained in:
parent
f89999289a
commit
213c7d8914
@ -5,6 +5,8 @@ Qwen3-VL (layers 0,3,...,33,35), giving a 4096*13 = 53248-dim feature per token.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import torch
|
||||||
|
import numbers
|
||||||
|
|
||||||
from transformers import Qwen2Tokenizer
|
from transformers import Qwen2Tokenizer
|
||||||
|
|
||||||
@ -30,13 +32,39 @@ class Ideogram4Tokenizer(sd1_clip.SD1Tokenizer):
|
|||||||
name="qwen3vl_8b", tokenizer=Qwen3VLTokenizer)
|
name="qwen3vl_8b", tokenizer=Qwen3VLTokenizer)
|
||||||
|
|
||||||
self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||||
|
self.llama_template_images = "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
|
||||||
|
|
||||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
|
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, **kwargs):
|
||||||
if llama_template is None:
|
skip_template = False
|
||||||
llama_text = self.llama_template.format(text)
|
if text.startswith('<|im_start|>'):
|
||||||
|
skip_template = True
|
||||||
|
if text.startswith('<|start_header_id|>'):
|
||||||
|
skip_template = True
|
||||||
|
if prevent_empty_text and text == '':
|
||||||
|
text = ' '
|
||||||
|
|
||||||
|
if skip_template:
|
||||||
|
llama_text = text
|
||||||
else:
|
else:
|
||||||
llama_text = llama_template.format(text)
|
if llama_template is None:
|
||||||
return super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
|
if len(images) > 0:
|
||||||
|
llama_text = self.llama_template_images.format(text)
|
||||||
|
else:
|
||||||
|
llama_text = self.llama_template.format(text)
|
||||||
|
else:
|
||||||
|
llama_text = llama_template.format(text)
|
||||||
|
|
||||||
|
tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
|
||||||
|
key_name = next(iter(tokens))
|
||||||
|
embed_count = 0
|
||||||
|
qwen_tokens = tokens[key_name]
|
||||||
|
for r in qwen_tokens:
|
||||||
|
for i in range(len(r)):
|
||||||
|
if r[i][0] == 151655:
|
||||||
|
if len(images) > embed_count:
|
||||||
|
r[i] = ({"type": "image", "data": images[embed_count], "original_type": "image"},) + r[i][1:]
|
||||||
|
embed_count += 1
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
|
||||||
# Qwen3-VL-8B = 5e6 (vs plain Qwen3-8B's 1e6)
|
# Qwen3-VL-8B = 5e6 (vs plain Qwen3-8B's 1e6)
|
||||||
@ -58,10 +86,54 @@ class Ideogram4TEModel(sd1_clip.SD1ClipModel):
|
|||||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
super().__init__(device=device, dtype=dtype, name="qwen3vl_8b", clip_model=Qwen3VL8BModel, model_options=model_options)
|
super().__init__(device=device, dtype=dtype, name="qwen3vl_8b", clip_model=Qwen3VL8BModel, model_options=model_options)
|
||||||
|
|
||||||
def encode_token_weights(self, token_weight_pairs):
|
def encode_token_weights(self, token_weight_pairs, template_end=0):
|
||||||
out, pooled, extra = super().encode_token_weights(token_weight_pairs)
|
out, pooled, extra = super().encode_token_weights(token_weight_pairs)
|
||||||
b, n, seq, h = out.shape # (B, n_taps=13, seq, 4096) stacked in ascending layer order.
|
b, n, seq, h = out.shape # (B, n_taps=13, seq, 4096) stacked in ascending layer order.
|
||||||
out = out.permute(0, 2, 3, 1).reshape(b, seq, h * n) # (B, seq, 4096*13). permute -> (B, seq, H, taps).
|
out = out.permute(0, 2, 3, 1).reshape(b, seq, h * n) # (B, seq, 4096*13). permute -> (B, seq, H, taps).
|
||||||
|
|
||||||
|
if template_end != 0:
|
||||||
|
tok_pairs = token_weight_pairs["qwen3vl_8b"][0]
|
||||||
|
suffix_start = -1
|
||||||
|
|
||||||
|
if template_end == -1:
|
||||||
|
for i in range(len(tok_pairs) - 2):
|
||||||
|
t0 = tok_pairs[i][0]
|
||||||
|
t1 = tok_pairs[i + 1][0]
|
||||||
|
t2 = tok_pairs[i + 2][0]
|
||||||
|
if not torch.is_tensor(t0) and isinstance(t0, numbers.Integral) and \
|
||||||
|
not torch.is_tensor(t1) and isinstance(t1, numbers.Integral) and \
|
||||||
|
not torch.is_tensor(t2) and isinstance(t2, numbers.Integral):
|
||||||
|
if t0 == 151644 and t1 == 872 and t2 == 198:
|
||||||
|
template_end = i + 3
|
||||||
|
break
|
||||||
|
if template_end == -1:
|
||||||
|
template_end = 0
|
||||||
|
|
||||||
|
# Scan backward for the <|im_end|> token 151643 to determine suffix_start
|
||||||
|
for i in range(len(tok_pairs) - 1, -1, -1):
|
||||||
|
t = tok_pairs[i][0]
|
||||||
|
if not torch.is_tensor(t) and isinstance(t, numbers.Integral):
|
||||||
|
if t == 151643:
|
||||||
|
suffix_start = i
|
||||||
|
break
|
||||||
|
|
||||||
|
# If template_end resolves to greater than 0:
|
||||||
|
if template_end > 0:
|
||||||
|
out = out[:, template_end:]
|
||||||
|
if "attention_mask" in extra and extra["attention_mask"] is not None:
|
||||||
|
extra["attention_mask"] = extra["attention_mask"][:, template_end:]
|
||||||
|
|
||||||
|
# If suffix_start is located and we are doing stripping:
|
||||||
|
if suffix_start >= 0:
|
||||||
|
suffix_len = len(tok_pairs) - suffix_start
|
||||||
|
out = out[:, :-suffix_len]
|
||||||
|
if "attention_mask" in extra and extra["attention_mask"] is not None:
|
||||||
|
extra["attention_mask"] = extra["attention_mask"][:, :-suffix_len]
|
||||||
|
|
||||||
|
if "attention_mask" in extra and extra["attention_mask"] is not None:
|
||||||
|
if extra["attention_mask"].sum() == torch.numel(extra["attention_mask"]):
|
||||||
|
extra.pop("attention_mask")
|
||||||
|
|
||||||
return out, pooled, extra
|
return out, pooled, extra
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user