LongCat-Image edit (#13003)

Author: Talmaj (committed via GitHub)
Date: 2026-03-22 04:51:05 +01:00
parent ebf6b52e32 · commit d49420b3c7
5 changed files with 36 additions and 10 deletions

@@ -386,7 +386,7 @@ class Flux(nn.Module):
                 h = max(h, ref.shape[-2] + h_offset)
                 w = max(w, ref.shape[-1] + w_offset)
-                kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
+                kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset, transformer_options=transformer_options)
                 img = torch.cat([img, kontext], dim=1)
                 img_ids = torch.cat([img_ids, kontext_ids], dim=1)
                 ref_num_tokens.append(kontext.shape[1])
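
The only change here is that process_img now receives transformer_options, so rope_options set for the sampling pass also apply when position IDs are built for the reference-image latents, not just for the main image. A minimal sketch of the idea, using a hypothetical helper and an assumed "scale" option name (not ComfyUI's actual process_img):

import torch

def process_img_sketch(ref, h_offset=0, w_offset=0, transformer_options={}):
    # Position IDs for the reference latents should honor the same RoPE
    # options as the main image, hence the threaded-through kwarg.
    scale = transformer_options.get("rope_options", {}).get("scale", 1.0)
    h, w = ref.shape[-2] // 2, ref.shape[-1] // 2  # 2x2 latent patches
    ids = torch.zeros((h, w, 3))
    ids[:, :, 1] = torch.arange(h).unsqueeze(1) * scale + h_offset
    ids[:, :, 2] = torch.arange(w).unsqueeze(0) * scale + w_offset
    return ids.reshape(1, -1, 3)

print(process_img_sketch(torch.zeros(1, 16, 8, 8), h_offset=4).shape)  # torch.Size([1, 16, 3])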

@@ -937,9 +937,10 @@ class LongCatImage(Flux):
         transformer_options = transformer_options.copy()
         rope_opts = transformer_options.get("rope_options", {})
         rope_opts = dict(rope_opts)
+        pe_len = float(c_crossattn.shape[1]) if c_crossattn is not None else 512.0
         rope_opts.setdefault("shift_t", 1.0)
-        rope_opts.setdefault("shift_y", 512.0)
-        rope_opts.setdefault("shift_x", 512.0)
+        rope_opts.setdefault("shift_y", pe_len)
+        rope_opts.setdefault("shift_x", pe_len)
         transformer_options["rope_options"] = rope_opts
         return super()._apply_model(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)
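
LongCatImage shifts the spatial RoPE axes so image tokens sit after the text tokens; this hunk derives the shift from the actual conditioning length instead of a hard-coded 512, falling back to 512.0 only when no text conditioning is present. Because setdefault is used, user-supplied rope_options still win. A small self-contained illustration (the 398-token prompt is invented for the example):

import torch

c_crossattn = torch.zeros(1, 398, 3584)  # stand-in for a 398-token text embedding
pe_len = float(c_crossattn.shape[1]) if c_crossattn is not None else 512.0

rope_opts = {"shift_x": 640.0}           # pretend the user already set shift_x
rope_opts.setdefault("shift_y", pe_len)  # filled in: 398.0
rope_opts.setdefault("shift_x", pe_len)  # left alone: 640.0
print(rope_opts)                         # {'shift_x': 640.0, 'shift_y': 398.0}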

@@ -1028,12 +1028,19 @@ class Qwen25_7BVLI(BaseLlama, BaseGenerate, torch.nn.Module):
                 grid = e.get("extra", None)
                 start = e.get("index")
                 if position_ids is None:
-                    position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device)
+                    position_ids = torch.ones((3, embeds.shape[1]), device=embeds.device, dtype=torch.long)
                     position_ids[:, :start] = torch.arange(0, start, device=embeds.device)
                 end = e.get("size") + start
                 len_max = int(grid.max()) // 2
                 start_next = len_max + start
-                position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device)
+                if attention_mask is not None:
+                    # Assign compact sequential positions to attended tokens only,
+                    # skipping over padding so post-padding tokens aren't inflated.
+                    after_mask = attention_mask[0, end:]
+                    text_positions = after_mask.cumsum(0) - 1 + start_next + offset
+                    position_ids[:, end:] = torch.where(after_mask.bool(), text_positions, position_ids[0, end:])
+                else:
+                    position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device)
                 position_ids[0, start:end] = start + offset
                 max_d = int(grid[0][1]) // 2
                 position_ids[1, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]
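
The zeros-to-ones change gives padded slots a benign default position (and a proper integer dtype), while the new branch keeps positions compact when an attention mask marks padding: a cumsum over the mask numbers only the attended tokens, so tokens after a padding run are no longer pushed to inflated positions. A toy, self-contained illustration of the cumsum trick (all values invented):

import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0, 1, 1]])  # two padded slots mid-sequence
start_next, offset, end = 10, 0, 0                      # toy values

after_mask = attention_mask[0, end:]
text_positions = after_mask.cumsum(0) - 1 + start_next + offset
positions = torch.where(after_mask.bool(), text_positions, torch.ones_like(text_positions))
print(positions)  # tensor([10, 11, 12,  1,  1, 13, 14]) -- padding keeps the placeholder 1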

@@ -64,7 +64,13 @@ class LongCatImageBaseTokenizer(Qwen25_7BVLITokenizer):
         return [output]
 
+IMAGE_PAD_TOKEN_ID = 151655
+
 class LongCatImageTokenizer(sd1_clip.SD1Tokenizer):
+    T2I_PREFIX = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n"
+    EDIT_PREFIX = "<|im_start|>system\nAs an image editing expert, first analyze the content and attributes of the input image(s). Then, based on the user's editing instructions, clearly and precisely determine how to modify the given image(s), ensuring that only the specified parts are altered and all other aspects remain consistent with the original(s).<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
+    SUFFIX = "<|im_end|>\n<|im_start|>assistant\n"
+
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         super().__init__(
             embedding_directory=embedding_directory,
@@ -72,10 +78,8 @@ class LongCatImageTokenizer(sd1_clip.SD1Tokenizer):
             name="qwen25_7b",
             tokenizer=LongCatImageBaseTokenizer,
         )
-        self.longcat_template_prefix = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n"
-        self.longcat_template_suffix = "<|im_end|>\n<|im_start|>assistant\n"
 
-    def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
+    def tokenize_with_weights(self, text, return_word_ids=False, images=None, **kwargs):
         skip_template = False
         if text.startswith("<|im_start|>"):
             skip_template = True
@@ -90,11 +94,14 @@ class LongCatImageTokenizer(sd1_clip.SD1Tokenizer):
                 text, return_word_ids=return_word_ids, disable_weights=True, **kwargs
             )
         else:
+            has_images = images is not None and len(images) > 0
+            template_prefix = self.EDIT_PREFIX if has_images else self.T2I_PREFIX
             prefix_ids = base_tok.tokenizer(
-                self.longcat_template_prefix, add_special_tokens=False
+                template_prefix, add_special_tokens=False
             )["input_ids"]
             suffix_ids = base_tok.tokenizer(
-                self.longcat_template_suffix, add_special_tokens=False
+                self.SUFFIX, add_special_tokens=False
             )["input_ids"]
             prompt_tokens = base_tok.tokenize_with_weights(
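
With the templates promoted to class constants, the prefix is now picked per call: the edit template (whose <|vision_start|><|image_pad|><|vision_end|> span reserves slots for image embeddings) when reference images are passed, the captioning template otherwise. A trivial dispatch sketch with abbreviated stand-in strings:

class TemplateDispatchSketch:
    T2I_PREFIX = "<t2i system prompt>"   # abbreviated stand-ins, not the real templates
    EDIT_PREFIX = "<edit system prompt>"

    def pick_prefix(self, images=None):
        has_images = images is not None and len(images) > 0
        return self.EDIT_PREFIX if has_images else self.T2I_PREFIX

t = TemplateDispatchSketch()
assert t.pick_prefix() == TemplateDispatchSketch.T2I_PREFIX
assert t.pick_prefix(images=["img"]) == TemplateDispatchSketch.EDIT_PREFIX
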
@@ -106,6 +113,14 @@ class LongCatImageTokenizer(sd1_clip.SD1Tokenizer):
             suffix_pairs = [(t, 1.0) for t in suffix_ids]
             combined = prefix_pairs + prompt_pairs + suffix_pairs
 
+            if has_images:
+                embed_count = 0
+                for i in range(len(combined)):
+                    if combined[i][0] == IMAGE_PAD_TOKEN_ID and embed_count < len(images):
+                        combined[i] = ({"type": "image", "data": images[embed_count], "original_type": "image"}, combined[i][1])
+                        embed_count += 1
+
         tokens = {"qwen25_7b": [combined]}
         return tokens
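
After prefix, prompt, and suffix are concatenated, each <|image_pad|> token id (151655 in the Qwen2.5-VL vocabulary) is replaced in place with an image payload dict, one per supplied image, in order of appearance. A self-contained rerun of that loop on dummy data:

IMAGE_PAD_TOKEN_ID = 151655
images = ["image_tensor_a"]  # stand-in for an actual image tensor

combined = [(151644, 1.0), (IMAGE_PAD_TOKEN_ID, 1.0), (9906, 1.0)]  # (token, weight) pairs
embed_count = 0
for i in range(len(combined)):
    if combined[i][0] == IMAGE_PAD_TOKEN_ID and embed_count < len(images):
        combined[i] = ({"type": "image", "data": images[embed_count], "original_type": "image"}, combined[i][1])
        embed_count += 1

print(combined[1][0]["type"])  # image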

@@ -425,4 +425,7 @@ class Qwen2VLVisionTransformer(nn.Module):
             hidden_states = block(hidden_states, position_embeddings, cu_seqlens_now, optimized_attention=optimized_attention)
 
         hidden_states = self.merger(hidden_states)
+        # Potentially important for spatially precise edits. This is present in the HF implementation.
+        reverse_indices = torch.argsort(window_index)
+        hidden_states = hidden_states[reverse_indices, :]
         return hidden_states
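
The vision transformer processes patch tokens in windowed order (permuted by window_index); since argsort of a permutation yields its inverse, the added lines restore the merged features to the original spatial order before returning, matching the Hugging Face implementation. A minimal demonstration that argsort inverts the permutation:

import torch

window_index = torch.tensor([2, 0, 3, 1])  # toy windowed ordering
tokens = torch.arange(4.0).unsqueeze(1)    # features in original spatial order
windowed = tokens[window_index]            # shuffled into window order

reverse_indices = torch.argsort(window_index)
restored = windowed[reverse_indices, :]
assert torch.equal(restored, tokens)       # original order recovered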