This commit is contained in:
Yousef Rafat 2025-09-27 14:11:37 +03:00
parent 1d24e63b64
commit c951e8f513

View File

@ -1,9 +1,11 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
from dataclasses import dataclass
from typing import Optional, Any from typing import Optional, Any
import math import math
from comfy.ldm.modules.attention import optimized_attention_for_device from comfy.ldm.modules.attention import optimized_attention_for_device
import comfy.model_management
import comfy.ldm.common_dit import comfy.ldm.common_dit
import comfy.model_management import comfy.model_management
@ -401,12 +403,18 @@ class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
for e in embeds_info: for e in embeds_info:
if e.get("type") == "image": if e.get("type") == "image":
grid = e.get("extra", None) grid = e.get("extra", None)
position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device)
start = e.get("index") start = e.get("index")
position_ids[:, :start] = torch.arange(0, start, device=embeds.device)
end = e.get("size") + start end = e.get("size") + start
len_max = int(grid.max()) // 2 len_max = int(grid.max()) // 2
start_next = len_max + start start_next = len_max + start
position_ids[:, end:] = torch.arange(start_next, start_next + (embeds.shape[1] - end), device=embeds.device)
position_ids[0, start:end] = start
max_d = int(grid[0][1]) // 2 max_d = int(grid[0][1]) // 2
position_ids[1, start:end] = torch.arange(start, start + max_d, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]
max_d = int(grid[0][2]) // 2 max_d = int(grid[0][2]) // 2
position_ids[2, start:end] = torch.arange(start, start + max_d, device=embeds.device).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start]
if grid is None: if grid is None:
position_ids = None position_ids = None