diff --git a/comfy/text_encoders/ideogram4.py b/comfy/text_encoders/ideogram4.py index 151b43c53..8173f677b 100644 --- a/comfy/text_encoders/ideogram4.py +++ b/comfy/text_encoders/ideogram4.py @@ -5,6 +5,8 @@ Qwen3-VL (layers 0,3,...,33,35), giving a 4096*13 = 53248-dim feature per token. """ import os +import torch +import numbers from transformers import Qwen2Tokenizer @@ -32,14 +34,23 @@ class Ideogram4Tokenizer(sd1_clip.SD1Tokenizer): self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" - def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs): + def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, prevent_empty_text=False, **kwargs): + skip_template = False if text.startswith('<|im_start|>'): + skip_template = True + if prevent_empty_text and text == '': + text = ' ' + + if skip_template: llama_text = text - elif llama_template is None: - llama_text = self.llama_template.format(text) else: - llama_text = llama_template.format(text) - return super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs) + if llama_template is None: + llama_text = self.llama_template.format(text) + else: + llama_text = llama_template.format(text) + + tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs) + return tokens # Qwen3-VL-8B = 5e6 (vs plain Qwen3-8B's 1e6) @@ -61,10 +72,54 @@ class Ideogram4TEModel(sd1_clip.SD1ClipModel): def __init__(self, device="cpu", dtype=None, model_options={}): super().__init__(device=device, dtype=dtype, name="qwen3vl_8b", clip_model=Qwen3VL8BModel, model_options=model_options) - def encode_token_weights(self, token_weight_pairs): + def encode_token_weights(self, token_weight_pairs, template_end=0): out, pooled, extra = super().encode_token_weights(token_weight_pairs) b, n, seq, h = out.shape # (B, n_taps=13, seq, 4096) stacked in ascending layer order. out = out.permute(0, 2, 3, 1).reshape(b, seq, h * n) # (B, seq, 4096*13). permute -> (B, seq, H, taps). + + if template_end != 0: + tok_pairs = token_weight_pairs["qwen3vl_8b"][0] + suffix_start = -1 + + if template_end == -1: + for i in range(len(tok_pairs) - 2): + t0 = tok_pairs[i][0] + t1 = tok_pairs[i + 1][0] + t2 = tok_pairs[i + 2][0] + if not torch.is_tensor(t0) and isinstance(t0, numbers.Integral) and \ + not torch.is_tensor(t1) and isinstance(t1, numbers.Integral) and \ + not torch.is_tensor(t2) and isinstance(t2, numbers.Integral): + if t0 == 151644 and t1 == 872 and t2 == 198: + template_end = i + 3 + break + if template_end == -1: + template_end = 0 + + # Scan backward for the <|im_end|> token 151643 to determine suffix_start + for i in range(len(tok_pairs) - 1, -1, -1): + t = tok_pairs[i][0] + if not torch.is_tensor(t) and isinstance(t, numbers.Integral): + if t == 151643: + suffix_start = i + break + + # If template_end resolves to greater than 0: + if template_end > 0: + out = out[:, template_end:] + if "attention_mask" in extra and extra["attention_mask"] is not None: + extra["attention_mask"] = extra["attention_mask"][:, template_end:] + + # If suffix_start is located and we are doing stripping: + if suffix_start >= 0: + suffix_len = len(tok_pairs) - suffix_start + out = out[:, :-suffix_len] + if "attention_mask" in extra and extra["attention_mask"] is not None: + extra["attention_mask"] = extra["attention_mask"][:, :-suffix_len] + + if "attention_mask" in extra and extra["attention_mask"] is not None: + if extra["attention_mask"].sum() == torch.numel(extra["attention_mask"]): + extra.pop("attention_mask") + return out, pooled, extra