mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-24 00:39:30 +08:00
Add future-proof image handling and template stripping options to ideogram4 text encoder.
This commit is contained in:
parent
f89999289a
commit
213c7d8914
@ -5,6 +5,8 @@ Qwen3-VL (layers 0,3,...,33,35), giving a 4096*13 = 53248-dim feature per token.
|
||||
"""
|
||||
|
||||
import os
|
||||
import torch
|
||||
import numbers
|
||||
|
||||
from transformers import Qwen2Tokenizer
|
||||
|
||||
@ -30,13 +32,39 @@ class Ideogram4Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
name="qwen3vl_8b", tokenizer=Qwen3VLTokenizer)
|
||||
|
||||
self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
self.llama_template_images = "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
|
||||
if llama_template is None:
|
||||
llama_text = self.llama_template.format(text)
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, **kwargs):
|
||||
skip_template = False
|
||||
if text.startswith('<|im_start|>'):
|
||||
skip_template = True
|
||||
if text.startswith('<|start_header_id|>'):
|
||||
skip_template = True
|
||||
if prevent_empty_text and text == '':
|
||||
text = ' '
|
||||
|
||||
if skip_template:
|
||||
llama_text = text
|
||||
else:
|
||||
llama_text = llama_template.format(text)
|
||||
return super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
|
||||
if llama_template is None:
|
||||
if len(images) > 0:
|
||||
llama_text = self.llama_template_images.format(text)
|
||||
else:
|
||||
llama_text = self.llama_template.format(text)
|
||||
else:
|
||||
llama_text = llama_template.format(text)
|
||||
|
||||
tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
|
||||
key_name = next(iter(tokens))
|
||||
embed_count = 0
|
||||
qwen_tokens = tokens[key_name]
|
||||
for r in qwen_tokens:
|
||||
for i in range(len(r)):
|
||||
if r[i][0] == 151655:
|
||||
if len(images) > embed_count:
|
||||
r[i] = ({"type": "image", "data": images[embed_count], "original_type": "image"},) + r[i][1:]
|
||||
embed_count += 1
|
||||
return tokens
|
||||
|
||||
|
||||
# Qwen3-VL-8B = 5e6 (vs plain Qwen3-8B's 1e6)
|
||||
@ -58,10 +86,54 @@ class Ideogram4TEModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(device=device, dtype=dtype, name="qwen3vl_8b", clip_model=Qwen3VL8BModel, model_options=model_options)
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
def encode_token_weights(self, token_weight_pairs, template_end=0):
|
||||
out, pooled, extra = super().encode_token_weights(token_weight_pairs)
|
||||
b, n, seq, h = out.shape # (B, n_taps=13, seq, 4096) stacked in ascending layer order.
|
||||
out = out.permute(0, 2, 3, 1).reshape(b, seq, h * n) # (B, seq, 4096*13). permute -> (B, seq, H, taps).
|
||||
|
||||
if template_end != 0:
|
||||
tok_pairs = token_weight_pairs["qwen3vl_8b"][0]
|
||||
suffix_start = -1
|
||||
|
||||
if template_end == -1:
|
||||
for i in range(len(tok_pairs) - 2):
|
||||
t0 = tok_pairs[i][0]
|
||||
t1 = tok_pairs[i + 1][0]
|
||||
t2 = tok_pairs[i + 2][0]
|
||||
if not torch.is_tensor(t0) and isinstance(t0, numbers.Integral) and \
|
||||
not torch.is_tensor(t1) and isinstance(t1, numbers.Integral) and \
|
||||
not torch.is_tensor(t2) and isinstance(t2, numbers.Integral):
|
||||
if t0 == 151644 and t1 == 872 and t2 == 198:
|
||||
template_end = i + 3
|
||||
break
|
||||
if template_end == -1:
|
||||
template_end = 0
|
||||
|
||||
# Scan backward for the <|im_end|> token 151643 to determine suffix_start
|
||||
for i in range(len(tok_pairs) - 1, -1, -1):
|
||||
t = tok_pairs[i][0]
|
||||
if not torch.is_tensor(t) and isinstance(t, numbers.Integral):
|
||||
if t == 151643:
|
||||
suffix_start = i
|
||||
break
|
||||
|
||||
# If template_end resolves to greater than 0:
|
||||
if template_end > 0:
|
||||
out = out[:, template_end:]
|
||||
if "attention_mask" in extra and extra["attention_mask"] is not None:
|
||||
extra["attention_mask"] = extra["attention_mask"][:, template_end:]
|
||||
|
||||
# If suffix_start is located and we are doing stripping:
|
||||
if suffix_start >= 0:
|
||||
suffix_len = len(tok_pairs) - suffix_start
|
||||
out = out[:, :-suffix_len]
|
||||
if "attention_mask" in extra and extra["attention_mask"] is not None:
|
||||
extra["attention_mask"] = extra["attention_mask"][:, :-suffix_len]
|
||||
|
||||
if "attention_mask" in extra and extra["attention_mask"] is not None:
|
||||
if extra["attention_mask"].sum() == torch.numel(extra["attention_mask"]):
|
||||
extra.pop("attention_mask")
|
||||
|
||||
return out, pooled, extra
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user