From d49420b3c7daf86cae1d7419e37848a974e1b7be Mon Sep 17 00:00:00 2001 From: Talmaj Date: Sun, 22 Mar 2026 04:51:05 +0100 Subject: [PATCH] LongCat-Image edit (#13003) --- comfy/ldm/flux/model.py | 2 +- comfy/model_base.py | 5 +++-- comfy/text_encoders/llama.py | 11 +++++++++-- comfy/text_encoders/longcat_image.py | 25 ++++++++++++++++++++----- comfy/text_encoders/qwen_vl.py | 3 +++ 5 files changed, 36 insertions(+), 10 deletions(-) diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index 8e7912e6d..2020326c2 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -386,7 +386,7 @@ class Flux(nn.Module): h = max(h, ref.shape[-2] + h_offset) w = max(w, ref.shape[-1] + w_offset) - kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset) + kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset, transformer_options=transformer_options) img = torch.cat([img, kontext], dim=1) img_ids = torch.cat([img_ids, kontext_ids], dim=1) ref_num_tokens.append(kontext.shape[1]) diff --git a/comfy/model_base.py b/comfy/model_base.py index 43ec93324..bfffe2402 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -937,9 +937,10 @@ class LongCatImage(Flux): transformer_options = transformer_options.copy() rope_opts = transformer_options.get("rope_options", {}) rope_opts = dict(rope_opts) + pe_len = float(c_crossattn.shape[1]) if c_crossattn is not None else 512.0 rope_opts.setdefault("shift_t", 1.0) - rope_opts.setdefault("shift_y", 512.0) - rope_opts.setdefault("shift_x", 512.0) + rope_opts.setdefault("shift_y", pe_len) + rope_opts.setdefault("shift_x", pe_len) transformer_options["rope_options"] = rope_opts return super()._apply_model(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs) diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index ccc200b7a..9fdea999c 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -1028,12 +1028,19 @@ class Qwen25_7BVLI(BaseLlama, BaseGenerate, torch.nn.Module): grid = e.get("extra", None) start = e.get("index") if position_ids is None: - position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device) + position_ids = torch.ones((3, embeds.shape[1]), device=embeds.device, dtype=torch.long) position_ids[:, :start] = torch.arange(0, start, device=embeds.device) end = e.get("size") + start len_max = int(grid.max()) // 2 start_next = len_max + start - position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device) + if attention_mask is not None: + # Assign compact sequential positions to attended tokens only, + # skipping over padding so post-padding tokens aren't inflated. + after_mask = attention_mask[0, end:] + text_positions = after_mask.cumsum(0) - 1 + start_next + offset + position_ids[:, end:] = torch.where(after_mask.bool(), text_positions, position_ids[0, end:]) + else: + position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device) position_ids[0, start:end] = start + offset max_d = int(grid[0][1]) // 2 position_ids[1, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start] diff --git a/comfy/text_encoders/longcat_image.py b/comfy/text_encoders/longcat_image.py index 882d80901..0962779e3 100644 --- a/comfy/text_encoders/longcat_image.py +++ b/comfy/text_encoders/longcat_image.py @@ -64,7 +64,13 @@ class LongCatImageBaseTokenizer(Qwen25_7BVLITokenizer): return [output] +IMAGE_PAD_TOKEN_ID = 151655 + class LongCatImageTokenizer(sd1_clip.SD1Tokenizer): + T2I_PREFIX = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n" + EDIT_PREFIX = "<|im_start|>system\nAs an image editing expert, first analyze the content and attributes of the input image(s). Then, based on the user's editing instructions, clearly and precisely determine how to modify the given image(s), ensuring that only the specified parts are altered and all other aspects remain consistent with the original(s).<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>" + SUFFIX = "<|im_end|>\n<|im_start|>assistant\n" + def __init__(self, embedding_directory=None, tokenizer_data={}): super().__init__( embedding_directory=embedding_directory, @@ -72,10 +78,8 @@ class LongCatImageTokenizer(sd1_clip.SD1Tokenizer): name="qwen25_7b", tokenizer=LongCatImageBaseTokenizer, ) - self.longcat_template_prefix = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n" - self.longcat_template_suffix = "<|im_end|>\n<|im_start|>assistant\n" - def tokenize_with_weights(self, text, return_word_ids=False, **kwargs): + def tokenize_with_weights(self, text, return_word_ids=False, images=None, **kwargs): skip_template = False if text.startswith("<|im_start|>"): skip_template = True @@ -90,11 +94,14 @@ class LongCatImageTokenizer(sd1_clip.SD1Tokenizer): text, return_word_ids=return_word_ids, disable_weights=True, **kwargs ) else: + has_images = images is not None and len(images) > 0 + template_prefix = self.EDIT_PREFIX if has_images else self.T2I_PREFIX + prefix_ids = base_tok.tokenizer( - self.longcat_template_prefix, add_special_tokens=False + template_prefix, add_special_tokens=False )["input_ids"] suffix_ids = base_tok.tokenizer( - self.longcat_template_suffix, add_special_tokens=False + self.SUFFIX, add_special_tokens=False )["input_ids"] prompt_tokens = base_tok.tokenize_with_weights( @@ -106,6 +113,14 @@ class LongCatImageTokenizer(sd1_clip.SD1Tokenizer): suffix_pairs = [(t, 1.0) for t in suffix_ids] combined = prefix_pairs + prompt_pairs + suffix_pairs + + if has_images: + embed_count = 0 + for i in range(len(combined)): + if combined[i][0] == IMAGE_PAD_TOKEN_ID and embed_count < len(images): + combined[i] = ({"type": "image", "data": images[embed_count], "original_type": "image"}, combined[i][1]) + embed_count += 1 + tokens = {"qwen25_7b": [combined]} return tokens diff --git a/comfy/text_encoders/qwen_vl.py b/comfy/text_encoders/qwen_vl.py index 3b18ce730..98c350a12 100644 --- a/comfy/text_encoders/qwen_vl.py +++ b/comfy/text_encoders/qwen_vl.py @@ -425,4 +425,7 @@ class Qwen2VLVisionTransformer(nn.Module): hidden_states = block(hidden_states, position_embeddings, cu_seqlens_now, optimized_attention=optimized_attention) hidden_states = self.merger(hidden_states) + # Potentially important for spatially precise edits. This is present in the HF implementation. + reverse_indices = torch.argsort(window_index) + hidden_states = hidden_states[reverse_indices, :] return hidden_states