mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-09 00:32:31 +08:00
Fix forward pass for Qwen25_7BVLI when attention_mask is not None. Needed for LongCat-Image edit model.
This commit is contained in:
parent
e00ae62907
commit
7b308e22c1
@ -1028,12 +1028,19 @@ class Qwen25_7BVLI(BaseLlama, BaseGenerate, torch.nn.Module):
|
||||
grid = e.get("extra", None)
|
||||
start = e.get("index")
|
||||
if position_ids is None:
|
||||
position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device)
|
||||
position_ids = torch.ones((3, embeds.shape[1]), device=embeds.device, dtype=torch.long)
|
||||
position_ids[:, :start] = torch.arange(0, start, device=embeds.device)
|
||||
end = e.get("size") + start
|
||||
len_max = int(grid.max()) // 2
|
||||
start_next = len_max + start
|
||||
position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device)
|
||||
if attention_mask is not None:
|
||||
# Assign compact sequential positions to attended tokens only,
|
||||
# skipping over padding so post-padding tokens aren't inflated.
|
||||
after_mask = attention_mask[0, end:]
|
||||
text_positions = after_mask.cumsum(0) - 1 + start_next + offset
|
||||
position_ids[:, end:] = torch.where(after_mask.bool(), text_positions, position_ids[0, end:])
|
||||
else:
|
||||
position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device)
|
||||
position_ids[0, start:end] = start + offset
|
||||
max_d = int(grid[0][1]) // 2
|
||||
position_ids[1, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user