Add future-proof image handling and template stripping options to ideogram4 text encoder.

This commit is contained in:
silveroxides 2026-06-09 14:47:18 +02:00
parent f89999289a
commit 213c7d8914

View File

@ -5,6 +5,8 @@ Qwen3-VL (layers 0,3,...,33,35), giving a 4096*13 = 53248-dim feature per token.
"""
import os
import torch
import numbers
from transformers import Qwen2Tokenizer
@ -30,13 +32,39 @@ class Ideogram4Tokenizer(sd1_clip.SD1Tokenizer):
name="qwen3vl_8b", tokenizer=Qwen3VLTokenizer)
self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
self.llama_template_images = "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
if llama_template is None:
llama_text = self.llama_template.format(text)
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, **kwargs):
skip_template = False
if text.startswith('<|im_start|>'):
skip_template = True
if text.startswith('<|start_header_id|>'):
skip_template = True
if prevent_empty_text and text == '':
text = ' '
if skip_template:
llama_text = text
else:
llama_text = llama_template.format(text)
return super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
if llama_template is None:
if len(images) > 0:
llama_text = self.llama_template_images.format(text)
else:
llama_text = self.llama_template.format(text)
else:
llama_text = llama_template.format(text)
tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
key_name = next(iter(tokens))
embed_count = 0
qwen_tokens = tokens[key_name]
for r in qwen_tokens:
for i in range(len(r)):
if r[i][0] == 151655:
if len(images) > embed_count:
r[i] = ({"type": "image", "data": images[embed_count], "original_type": "image"},) + r[i][1:]
embed_count += 1
return tokens
# Qwen3-VL-8B = 5e6 (vs plain Qwen3-8B's 1e6)
@ -58,10 +86,54 @@ class Ideogram4TEModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="qwen3vl_8b", clip_model=Qwen3VL8BModel, model_options=model_options)
def encode_token_weights(self, token_weight_pairs):
def encode_token_weights(self, token_weight_pairs, template_end=0):
out, pooled, extra = super().encode_token_weights(token_weight_pairs)
b, n, seq, h = out.shape # (B, n_taps=13, seq, 4096) stacked in ascending layer order.
out = out.permute(0, 2, 3, 1).reshape(b, seq, h * n) # (B, seq, 4096*13). permute -> (B, seq, H, taps).
if template_end != 0:
tok_pairs = token_weight_pairs["qwen3vl_8b"][0]
suffix_start = -1
if template_end == -1:
for i in range(len(tok_pairs) - 2):
t0 = tok_pairs[i][0]
t1 = tok_pairs[i + 1][0]
t2 = tok_pairs[i + 2][0]
if not torch.is_tensor(t0) and isinstance(t0, numbers.Integral) and \
not torch.is_tensor(t1) and isinstance(t1, numbers.Integral) and \
not torch.is_tensor(t2) and isinstance(t2, numbers.Integral):
if t0 == 151644 and t1 == 872 and t2 == 198:
template_end = i + 3
break
if template_end == -1:
template_end = 0
# Scan backward for the <|im_end|> token 151643 to determine suffix_start
for i in range(len(tok_pairs) - 1, -1, -1):
t = tok_pairs[i][0]
if not torch.is_tensor(t) and isinstance(t, numbers.Integral):
if t == 151643:
suffix_start = i
break
# If template_end resolves to greater than 0:
if template_end > 0:
out = out[:, template_end:]
if "attention_mask" in extra and extra["attention_mask"] is not None:
extra["attention_mask"] = extra["attention_mask"][:, template_end:]
# If suffix_start is located and we are doing stripping:
if suffix_start >= 0:
suffix_len = len(tok_pairs) - suffix_start
out = out[:, :-suffix_len]
if "attention_mask" in extra and extra["attention_mask"] is not None:
extra["attention_mask"] = extra["attention_mask"][:, :-suffix_len]
if "attention_mask" in extra and extra["attention_mask"] is not None:
if extra["attention_mask"].sum() == torch.numel(extra["attention_mask"]):
extra.pop("attention_mask")
return out, pooled, extra