mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-05-09 08:42:30 +08:00)
Fix forward pass for Qwen25_7BVLI when attention_mask is not None. Needed for LongCat-Image edit model.
commit 7b308e22c1
parent e00ae62907
@@ -1028,12 +1028,19 @@ class Qwen25_7BVLI(BaseLlama, BaseGenerate, torch.nn.Module):
                 grid = e.get("extra", None)
                 start = e.get("index")
                 if position_ids is None:
-                    position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device)
+                    position_ids = torch.ones((3, embeds.shape[1]), device=embeds.device, dtype=torch.long)
                     position_ids[:, :start] = torch.arange(0, start, device=embeds.device)
                 end = e.get("size") + start
                 len_max = int(grid.max()) // 2
                 start_next = len_max + start
-                position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device)
+                if attention_mask is not None:
+                    # Assign compact sequential positions to attended tokens only,
+                    # skipping over padding so post-padding tokens aren't inflated.
+                    after_mask = attention_mask[0, end:]
+                    text_positions = after_mask.cumsum(0) - 1 + start_next + offset
+                    position_ids[:, end:] = torch.where(after_mask.bool(), text_positions, position_ids[0, end:])
+                else:
+                    position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device)
                 position_ids[0, start:end] = start + offset
                 max_d = int(grid[0][1]) // 2
                 position_ids[1, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]
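The masked branch replaces the unconditional arange with a cumsum over the attention mask, so padding tokens no longer push the rotary positions of later real tokens upward. Below is a minimal standalone sketch of that trick; the values of end, start_next, offset, and the mask are invented for illustration, and the 3-row position_ids from the model is flattened to a single row.

import torch

# Hypothetical toy values, not taken from the model code:
# image tokens occupy indices up to `end`, text resumes at position `start_next`.
end, start_next, offset = 4, 6, 0

# 1 = real token, 0 = padding; two padding tokens sit after the image span.
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 0, 0, 1, 1]])

after_mask = attention_mask[0, end:]                             # [1, 1, 0, 0, 1, 1]
text_positions = after_mask.cumsum(0) - 1 + start_next + offset  # [6, 7, 7, 7, 8, 9]

position_ids = torch.ones(attention_mask.shape[1], dtype=torch.long)
position_ids[end:] = torch.where(after_mask.bool(), text_positions, position_ids[end:])

print(position_ids[end:])  # tensor([6, 7, 1, 1, 8, 9])
# The old arange would have produced [6, 7, 8, 9, 10, 11], inflating the
# positions of the two tokens that follow the padding to 10 and 11.

Padding slots simply keep the tensor's initialization value, which is presumably why the diff also switches the init from torch.zeros to torch.ones with an explicit torch.long dtype.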