import re import numbers import torch from comfy import sd1_clip from comfy.text_encoders.qwen_image import Qwen25_7BVLITokenizer, Qwen25_7BVLIModel import logging logger = logging.getLogger(__name__) QUOTE_PAIRS = [("'", "'"), ('"', '"'), ("\u2018", "\u2019"), ("\u201c", "\u201d")] QUOTE_PATTERN = "|".join( [ re.escape(q1) + r"[^" + re.escape(q1 + q2) + r"]*?" + re.escape(q2) for q1, q2 in QUOTE_PAIRS ] ) WORD_INTERNAL_QUOTE_RE = re.compile(r"[a-zA-Z]+'[a-zA-Z]+") def split_quotation(prompt): matches = WORD_INTERNAL_QUOTE_RE.findall(prompt) mapping = [] for i, word_src in enumerate(set(matches)): word_tgt = "longcat_$##$_longcat" * (i + 1) prompt = prompt.replace(word_src, word_tgt) mapping.append((word_src, word_tgt)) parts = re.split(f"({QUOTE_PATTERN})", prompt) result = [] for part in parts: for word_src, word_tgt in mapping: part = part.replace(word_tgt, word_src) if not part: continue is_quoted = bool(re.match(QUOTE_PATTERN, part)) result.append((part, is_quoted)) return result class LongCatImageBaseTokenizer(Qwen25_7BVLITokenizer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.max_length = 512 def tokenize_with_weights(self, text, return_word_ids=False, **kwargs): parts = split_quotation(text) all_tokens = [] for part_text, is_quoted in parts: if is_quoted: for char in part_text: ids = self.tokenizer(char, add_special_tokens=False)["input_ids"] all_tokens.extend(ids) else: ids = self.tokenizer(part_text, add_special_tokens=False)["input_ids"] all_tokens.extend(ids) if len(all_tokens) > self.max_length: all_tokens = all_tokens[: self.max_length] logger.warning(f"Truncated prompt to {self.max_length} tokens") output = [(t, 1.0) for t in all_tokens] # Pad to max length self.pad_tokens(output, self.max_length - len(output)) return [output] IMAGE_PAD_TOKEN_ID = 151655 class LongCatImageTokenizer(sd1_clip.SD1Tokenizer): T2I_PREFIX = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n" EDIT_PREFIX = "<|im_start|>system\nAs an image editing expert, first analyze the content and attributes of the input image(s). Then, based on the user's editing instructions, clearly and precisely determine how to modify the given image(s), ensuring that only the specified parts are altered and all other aspects remain consistent with the original(s).<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>" SUFFIX = "<|im_end|>\n<|im_start|>assistant\n" def __init__(self, embedding_directory=None, tokenizer_data={}): super().__init__( embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen25_7b", tokenizer=LongCatImageBaseTokenizer, ) def tokenize_with_weights(self, text, return_word_ids=False, images=None, **kwargs): skip_template = False if text.startswith("<|im_start|>"): skip_template = True if text.startswith("<|start_header_id|>"): skip_template = True if text == "": text = " " base_tok = getattr(self, "qwen25_7b") if skip_template: tokens = super().tokenize_with_weights( text, return_word_ids=return_word_ids, disable_weights=True, **kwargs ) else: has_images = images is not None and len(images) > 0 template_prefix = self.EDIT_PREFIX if has_images else self.T2I_PREFIX prefix_ids = base_tok.tokenizer( template_prefix, add_special_tokens=False )["input_ids"] suffix_ids = base_tok.tokenizer( self.SUFFIX, add_special_tokens=False )["input_ids"] prompt_tokens = base_tok.tokenize_with_weights( text, return_word_ids=return_word_ids, **kwargs ) prompt_pairs = prompt_tokens[0] prefix_pairs = [(t, 1.0) for t in prefix_ids] suffix_pairs = [(t, 1.0) for t in suffix_ids] combined = prefix_pairs + prompt_pairs + suffix_pairs if has_images: embed_count = 0 for i in range(len(combined)): if combined[i][0] == IMAGE_PAD_TOKEN_ID and embed_count < len(images): combined[i] = ({"type": "image", "data": images[embed_count], "original_type": "image"}, combined[i][1]) embed_count += 1 tokens = {"qwen25_7b": [combined]} return tokens class LongCatImageTEModel(sd1_clip.SD1ClipModel): def __init__(self, device="cpu", dtype=None, model_options={}): super().__init__( device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options, ) def encode_token_weights(self, token_weight_pairs, template_end=-1): out, pooled, extra = super().encode_token_weights(token_weight_pairs) tok_pairs = token_weight_pairs["qwen25_7b"][0] count_im_start = 0 if template_end == -1: for i, v in enumerate(tok_pairs): elem = v[0] if not torch.is_tensor(elem): if isinstance(elem, numbers.Integral): if elem == 151644 and count_im_start < 2: template_end = i count_im_start += 1 if out.shape[1] > (template_end + 3): if tok_pairs[template_end + 1][0] == 872: if tok_pairs[template_end + 2][0] == 198: template_end += 3 if template_end == -1: template_end = 0 suffix_start = None for i in range(len(tok_pairs) - 1, -1, -1): elem = tok_pairs[i][0] if not torch.is_tensor(elem) and isinstance(elem, numbers.Integral): if elem == 151645: suffix_start = i break out = out[:, template_end:] if "attention_mask" in extra: extra["attention_mask"] = extra["attention_mask"][:, template_end:] if extra["attention_mask"].sum() == torch.numel(extra["attention_mask"]): extra.pop("attention_mask") if suffix_start is not None: suffix_len = len(tok_pairs) - suffix_start if suffix_len > 0 and out.shape[1] > suffix_len: out = out[:, :-suffix_len] if "attention_mask" in extra: extra["attention_mask"] = extra["attention_mask"][:, :-suffix_len] if extra["attention_mask"].sum() == torch.numel( extra["attention_mask"] ): extra.pop("attention_mask") return out, pooled, extra def te(dtype_llama=None, llama_quantization_metadata=None): class LongCatImageTEModel_(LongCatImageTEModel): def __init__(self, device="cpu", dtype=None, model_options={}): if llama_quantization_metadata is not None: model_options = model_options.copy() model_options["quantization_metadata"] = llama_quantization_metadata if dtype_llama is not None: dtype = dtype_llama super().__init__(device=device, dtype=dtype, model_options=model_options) return LongCatImageTEModel_