Fix forward pass for Qwen25_7BVLI when attention_mask is not None. Needed for LongCat-Image edit model.

This commit is contained in:
Talmaj Marinc 2026-03-13 21:53:49 +01:00 committed by Talmaj Marinc
parent e00ae62907
commit 7b308e22c1

View File

@ -1028,12 +1028,19 @@ class Qwen25_7BVLI(BaseLlama, BaseGenerate, torch.nn.Module):
grid = e.get("extra", None) grid = e.get("extra", None)
start = e.get("index") start = e.get("index")
if position_ids is None: if position_ids is None:
position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device) position_ids = torch.ones((3, embeds.shape[1]), device=embeds.device, dtype=torch.long)
position_ids[:, :start] = torch.arange(0, start, device=embeds.device) position_ids[:, :start] = torch.arange(0, start, device=embeds.device)
end = e.get("size") + start end = e.get("size") + start
len_max = int(grid.max()) // 2 len_max = int(grid.max()) // 2
start_next = len_max + start start_next = len_max + start
position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device) if attention_mask is not None:
# Assign compact sequential positions to attended tokens only,
# skipping over padding so post-padding tokens aren't inflated.
after_mask = attention_mask[0, end:]
text_positions = after_mask.cumsum(0) - 1 + start_next + offset
position_ids[:, end:] = torch.where(after_mask.bool(), text_positions, position_ids[0, end:])
else:
position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device)
position_ids[0, start:end] = start + offset position_ids[0, start:end] = start + offset
max_d = int(grid[0][1]) // 2 max_d = int(grid[0][1]) // 2
position_ids[1, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start] position_ids[1, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]