From 213c7d8914ef37822feda3c2951a90db47a498b8 Mon Sep 17 00:00:00 2001 From: silveroxides Date: Tue, 9 Jun 2026 14:47:18 +0200 Subject: [PATCH] Add future-proof image handling and template stripping options to ideogram4 text encoder. --- comfy/text_encoders/ideogram4.py | 84 +++++++++++++++++++++++++++++--- 1 file changed, 78 insertions(+), 6 deletions(-) diff --git a/comfy/text_encoders/ideogram4.py b/comfy/text_encoders/ideogram4.py index 55e655d67..92d4d7901 100644 --- a/comfy/text_encoders/ideogram4.py +++ b/comfy/text_encoders/ideogram4.py @@ -5,6 +5,8 @@ Qwen3-VL (layers 0,3,...,33,35), giving a 4096*13 = 53248-dim feature per token. """ import os +import torch +import numbers from transformers import Qwen2Tokenizer @@ -30,13 +32,39 @@ class Ideogram4Tokenizer(sd1_clip.SD1Tokenizer): name="qwen3vl_8b", tokenizer=Qwen3VLTokenizer) self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + self.llama_template_images = "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" - def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs): - if llama_template is None: - llama_text = self.llama_template.format(text) + def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, **kwargs): + skip_template = False + if text.startswith('<|im_start|>'): + skip_template = True + if text.startswith('<|start_header_id|>'): + skip_template = True + if prevent_empty_text and text == '': + text = ' ' + + if skip_template: + llama_text = text else: - llama_text = llama_template.format(text) - return super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs) + if llama_template is None: + if len(images) > 0: + llama_text = self.llama_template_images.format(text) + else: + llama_text = self.llama_template.format(text) + else: + llama_text = llama_template.format(text) + + tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs) + key_name = next(iter(tokens)) + embed_count = 0 + qwen_tokens = tokens[key_name] + for r in qwen_tokens: + for i in range(len(r)): + if r[i][0] == 151655: + if len(images) > embed_count: + r[i] = ({"type": "image", "data": images[embed_count], "original_type": "image"},) + r[i][1:] + embed_count += 1 + return tokens # Qwen3-VL-8B = 5e6 (vs plain Qwen3-8B's 1e6) @@ -58,10 +86,54 @@ class Ideogram4TEModel(sd1_clip.SD1ClipModel): def __init__(self, device="cpu", dtype=None, model_options={}): super().__init__(device=device, dtype=dtype, name="qwen3vl_8b", clip_model=Qwen3VL8BModel, model_options=model_options) - def encode_token_weights(self, token_weight_pairs): + def encode_token_weights(self, token_weight_pairs, template_end=0): out, pooled, extra = super().encode_token_weights(token_weight_pairs) b, n, seq, h = out.shape # (B, n_taps=13, seq, 4096) stacked in ascending layer order. out = out.permute(0, 2, 3, 1).reshape(b, seq, h * n) # (B, seq, 4096*13). permute -> (B, seq, H, taps). + + if template_end != 0: + tok_pairs = token_weight_pairs["qwen3vl_8b"][0] + suffix_start = -1 + + if template_end == -1: + for i in range(len(tok_pairs) - 2): + t0 = tok_pairs[i][0] + t1 = tok_pairs[i + 1][0] + t2 = tok_pairs[i + 2][0] + if not torch.is_tensor(t0) and isinstance(t0, numbers.Integral) and \ + not torch.is_tensor(t1) and isinstance(t1, numbers.Integral) and \ + not torch.is_tensor(t2) and isinstance(t2, numbers.Integral): + if t0 == 151644 and t1 == 872 and t2 == 198: + template_end = i + 3 + break + if template_end == -1: + template_end = 0 + + # Scan backward for the <|im_end|> token 151643 to determine suffix_start + for i in range(len(tok_pairs) - 1, -1, -1): + t = tok_pairs[i][0] + if not torch.is_tensor(t) and isinstance(t, numbers.Integral): + if t == 151643: + suffix_start = i + break + + # If template_end resolves to greater than 0: + if template_end > 0: + out = out[:, template_end:] + if "attention_mask" in extra and extra["attention_mask"] is not None: + extra["attention_mask"] = extra["attention_mask"][:, template_end:] + + # If suffix_start is located and we are doing stripping: + if suffix_start >= 0: + suffix_len = len(tok_pairs) - suffix_start + out = out[:, :-suffix_len] + if "attention_mask" in extra and extra["attention_mask"] is not None: + extra["attention_mask"] = extra["attention_mask"][:, :-suffix_len] + + if "attention_mask" in extra and extra["attention_mask"] is not None: + if extra["attention_mask"].sum() == torch.numel(extra["attention_mask"]): + extra.pop("attention_mask") + return out, pooled, extra