This commit is contained in:
Yousef Rafat 2025-09-27 14:11:37 +03:00
parent 1d24e63b64
commit c951e8f513

View File

@ -1,9 +1,11 @@
import torch
import torch.nn as nn
from dataclasses import dataclass
from typing import Optional, Any
import math
from comfy.ldm.modules.attention import optimized_attention_for_device
import comfy.model_management
import comfy.ldm.common_dit
import comfy.model_management
@ -401,12 +403,18 @@ class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
for e in embeds_info:
if e.get("type") == "image":
grid = e.get("extra", None)
position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device)
start = e.get("index")
position_ids[:, :start] = torch.arange(0, start, device=embeds.device)
end = e.get("size") + start
len_max = int(grid.max()) // 2
start_next = len_max + start
position_ids[:, end:] = torch.arange(start_next, start_next + (embeds.shape[1] - end), device=embeds.device)
position_ids[0, start:end] = start
max_d = int(grid[0][1]) // 2
position_ids[1, start:end] = torch.arange(start, start + max_d, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]
max_d = int(grid[0][2]) // 2
position_ids[2, start:end] = torch.arange(start, start + max_d, device=embeds.device).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start]
if grid is None:
position_ids = None