"""HiDream-O1 input-prep helpers: image/resolution math and unified-sequence
|
|
RoPE position-id assembly. The fix_point offset in get_rope_index_fix_point
|
|
lets the target image and patchified ref images share spatial RoPE positions
|
|
despite living at different sequence indices — same 2D image plane.
|
|
"""

import math
from typing import Optional

import torch
from PIL import Image


PREDEFINED_RESOLUTIONS = [
    (2048, 2048),
    (2304, 1728),
    (1728, 2304),
    (2560, 1440),
    (1440, 2560),
    (2496, 1664),
    (1664, 2496),
    (3104, 1312),
    (1312, 3104),
    (2304, 1792),
    (1792, 2304),
]

PATCH_SIZE = 32
CONDITION_IMAGE_SIZE = 384  # ViT-side base size for ref images


def find_closest_resolution(width, height):
    """Closest (W, H) in PREDEFINED_RESOLUTIONS by aspect ratio."""
    img_ratio = width / height
    best = None
    min_diff = float("inf")
    for w, h in PREDEFINED_RESOLUTIONS:
        diff = abs(w / h - img_ratio)
        if diff < min_diff:
            min_diff = diff
            best = (w, h)
    return best
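

# Illustrative sanity check (not part of the original module): a 1920x1080
# input has a 16:9 aspect ratio, which matches the (2560, 1440) bucket exactly,
# and a square input snaps to the square bucket.
#
#     find_closest_resolution(1920, 1080)  # -> (2560, 1440)
#     find_closest_resolution(1024, 1024)  # -> (2048, 2048)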


def resize_pilimage(pil_image, image_size, patch_size=16, resampler=Image.BICUBIC):
    """Resize to fit image_size**2 area, patch-aligned, center-cropped. Pre-halves
    with BOX filter while the image is still very large.
    """
    while min(*pil_image.size) >= 2 * image_size:
        pil_image = pil_image.resize(
            tuple(x // 2 for x in pil_image.size), resample=Image.BOX,
        )

    m = patch_size
    width, height = pil_image.width, pil_image.height
    s_max = image_size * image_size
    scale = math.sqrt(s_max / (width * height))

    candidates = [
        (round(width * scale) // m * m, round(height * scale) // m * m),
        (round(width * scale) // m * m, math.floor(height * scale) // m * m),
        (math.floor(width * scale) // m * m, round(height * scale) // m * m),
        (math.floor(width * scale) // m * m, math.floor(height * scale) // m * m),
    ]
    candidates = sorted(candidates, key=lambda x: x[0] * x[1], reverse=True)
    new_size = candidates[-1]
    for c in candidates:
        if c[0] * c[1] <= s_max:
            new_size = c
            break

    s1 = width / new_size[0]
    s2 = height / new_size[1]
    if s1 < s2:
        pil_image = pil_image.resize([new_size[0], round(height / s1)], resample=resampler)
        top = (round(height / s1) - new_size[1]) // 2
        pil_image = pil_image.crop((0, top, new_size[0], top + new_size[1]))
    else:
        pil_image = pil_image.resize([round(width / s2), new_size[1]], resample=resampler)
        left = (round(width / s2) - new_size[0]) // 2
        pil_image = pil_image.crop((left, 0, left + new_size[0], new_size[1]))
    return pil_image
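

# Illustrative sketch (not part of the original module): a 1024x1024 PIL image
# with image_size=384 is first BOX-halved to 512x512 (since 1024 >= 2 * 384),
# then scaled by sqrt(384**2 / 512**2) = 0.75 and snapped to the patch grid,
# yielding a 384x384 result with the default patch_size=16.
#
#     from PIL import Image
#     img = Image.new("RGB", (1024, 1024))
#     resize_pilimage(img, 384).size  # -> (384, 384)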


def calculate_dimensions(max_size, ratio):
    """(W, H) for an aspect ratio fitting in max_size**2 area, 32-aligned."""
    width = math.sqrt(max_size * max_size * ratio)
    height = width / ratio
    width = int(width / 32) * 32
    height = int(height / 32) * 32
    return width, height
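

# Illustrative sketch (not part of the original module): both sides are floored
# to multiples of 32, so the result never exceeds the max_size**2 area budget.
#
#     calculate_dimensions(1024, 1.0)  # -> (1024, 1024)
#     calculate_dimensions(1024, 2.0)  # -> (1440, 704); 1440 * 704 <= 1024**2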


def ref_max_size(target_max_dim, k):
    """K-dependent ref-image max dim before patchifying."""
    if k == 1:
        return target_max_dim
    if k == 2:
        return target_max_dim * 48 // 64
    if k <= 4:
        return target_max_dim // 2
    if k <= 8:
        return target_max_dim * 24 // 64
    return target_max_dim // 4
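

# Illustrative values (not part of the original module): more reference images
# means a smaller per-image budget.
#
#     [ref_max_size(1024, k) for k in (1, 2, 4, 8, 9)]  # -> [1024, 768, 512, 384, 256]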


def cond_image_size(k):
    """K-dependent ViT-side image size."""
    if k <= 4:
        return CONDITION_IMAGE_SIZE
    if k <= 8:
        return CONDITION_IMAGE_SIZE * 48 // 64
    return CONDITION_IMAGE_SIZE // 2
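

# Illustrative values (not part of the original module): with
# CONDITION_IMAGE_SIZE = 384 this gives 384 for k <= 4, 288 for k <= 8, and
# 192 beyond that.
#
#     [cond_image_size(k) for k in (1, 5, 9)]  # -> [384, 288, 192]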


def get_rope_index_fix_point(
    spatial_merge_size: int,
    image_token_id: int,
    video_token_id: int,
    vision_start_token_id: int,
    input_ids: Optional[torch.LongTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    skip_vision_start_token=None,
    fix_point: int = 4096,
):
    """Build 3-axis (temporal/height/width) RoPE position ids for a mixed
    text + vision token sequence, in the style of Qwen2-VL's get_rope_index.
    Vision blocks flagged in skip_vision_start_token are anchored at the
    absolute offset fix_point instead of the running text index, placing them
    on the same spatial RoPE plane as the target image. Returns position_ids
    of shape (3, batch, seq_len) and mrope_position_deltas of shape (batch, 1).
    """
    if video_grid_thw is not None:
        video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0)
        video_grid_thw[:, 0] = 1

    mrope_position_deltas = []
    if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
        total_input_ids = input_ids
        if attention_mask is None:
            attention_mask = torch.ones_like(total_input_ids)
        position_ids = torch.ones(
            3, input_ids.shape[0], input_ids.shape[1],
            dtype=input_ids.dtype, device=input_ids.device,
        )
        image_index, video_index = 0, 0
        attention_mask = attention_mask.to(total_input_ids.device)
        for i, input_ids_b in enumerate(total_input_ids):
            input_ids_b = input_ids_b[attention_mask[i] == 1]
            vision_start_indices = torch.argwhere(input_ids_b == vision_start_token_id).squeeze(1)
            vision_tokens = input_ids_b[vision_start_indices + 1]
            image_nums = (vision_tokens == image_token_id).sum()
            video_nums = (vision_tokens == video_token_id).sum()
            input_tokens = input_ids_b.tolist()
            llm_pos_ids_list = []
            st = 0
            remain_images, remain_videos = image_nums, video_nums
            for _ in range(image_nums + video_nums):
                if image_token_id in input_tokens and remain_images > 0:
                    ed_image = input_tokens.index(image_token_id, st)
                else:
                    ed_image = len(input_tokens) + 1
                if video_token_id in input_tokens and remain_videos > 0:
                    ed_video = input_tokens.index(video_token_id, st)
                else:
                    ed_video = len(input_tokens) + 1
                if ed_image < ed_video:
                    t = image_grid_thw[image_index][0]
                    h = image_grid_thw[image_index][1]
                    w = image_grid_thw[image_index][2]
                    image_index += 1
                    remain_images -= 1
                    ed = ed_image
                else:
                    t = video_grid_thw[video_index][0]
                    h = video_grid_thw[video_index][1]
                    w = video_grid_thw[video_index][2]
                    video_index += 1
                    remain_videos -= 1
                    ed = ed_video
                llm_grid_t = t.item()
                llm_grid_h = h.item() // spatial_merge_size
                llm_grid_w = w.item() // spatial_merge_size
                text_len = ed - st
                text_len -= skip_vision_start_token[image_index - 1]
                text_len = max(0, text_len)
                st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

                t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
                h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
                w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
                # Flagged blocks are rebased to the absolute offset fix_point
                # (the first one consumes fix_point; later ones continue from
                # the running index); unflagged blocks get the usual
                # text_len + st_idx offset.
                if skip_vision_start_token[image_index - 1]:
                    if fix_point > 0:
                        fix_point = fix_point - st_idx
                    llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + fix_point + st_idx)
                    fix_point = 0
                else:
                    llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
                st = ed + llm_grid_t * llm_grid_h * llm_grid_w

            if st < len(input_tokens):
                st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                text_len = len(input_tokens) - st
                llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

            llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
            position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
            mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
        mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
        return position_ids, mrope_position_deltas

    if attention_mask is not None:
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
        max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
        mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
    else:
        position_ids = (
            torch.arange(input_ids.shape[1], device=input_ids.device)
            .view(1, 1, -1).expand(3, input_ids.shape[0], -1)
        )
        mrope_position_deltas = torch.zeros(
            [input_ids.shape[0], 1], device=input_ids.device, dtype=input_ids.dtype,
        )
    return position_ids, mrope_position_deltas
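

# Illustrative call sketch (not from the original module). The special-token
# ids below are placeholders in the style of Qwen2-VL tokenizers; substitute
# whatever the surrounding pipeline actually uses.
#
#     position_ids, deltas = get_rope_index_fix_point(
#         spatial_merge_size=2,
#         image_token_id=151655,
#         video_token_id=151656,
#         vision_start_token_id=151652,
#         input_ids=input_ids,                         # (batch, seq_len) LongTensor
#         image_grid_thw=torch.tensor([[1, 24, 24]]),  # one ref image, 24x24 patch grid
#         attention_mask=torch.ones_like(input_ids),
#         skip_vision_start_token=[1],                 # anchor this image at fix_point
#         fix_point=4096,
#     )
#
# position_ids has shape (3, batch, seq_len): one row each for the temporal,
# height and width RoPE axes; deltas has shape (batch, 1).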