mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-17 01:52:59 +08:00
ScaleROPE now works on Lumina models. (#10578)
This commit is contained in:
parent
27d1bd8829
commit
7f374e42c8
@ -522,7 +522,7 @@ class NextDiT(nn.Module):
|
|||||||
max_cap_len = max(l_effective_cap_len)
|
max_cap_len = max(l_effective_cap_len)
|
||||||
max_img_len = max(l_effective_img_len)
|
max_img_len = max(l_effective_img_len)
|
||||||
|
|
||||||
position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device)
|
position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.float32, device=device)
|
||||||
|
|
||||||
for i in range(bsz):
|
for i in range(bsz):
|
||||||
cap_len = l_effective_cap_len[i]
|
cap_len = l_effective_cap_len[i]
|
||||||
@ -531,10 +531,22 @@ class NextDiT(nn.Module):
|
|||||||
H_tokens, W_tokens = H // pH, W // pW
|
H_tokens, W_tokens = H // pH, W // pW
|
||||||
assert H_tokens * W_tokens == img_len
|
assert H_tokens * W_tokens == img_len
|
||||||
|
|
||||||
position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
|
rope_options = transformer_options.get("rope_options", None)
|
||||||
|
h_scale = 1.0
|
||||||
|
w_scale = 1.0
|
||||||
|
h_start = 0
|
||||||
|
w_start = 0
|
||||||
|
if rope_options is not None:
|
||||||
|
h_scale = rope_options.get("scale_y", 1.0)
|
||||||
|
w_scale = rope_options.get("scale_x", 1.0)
|
||||||
|
|
||||||
|
h_start = rope_options.get("shift_y", 0.0)
|
||||||
|
w_start = rope_options.get("shift_x", 0.0)
|
||||||
|
|
||||||
|
position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.float32, device=device)
|
||||||
position_ids[i, cap_len:cap_len+img_len, 0] = cap_len
|
position_ids[i, cap_len:cap_len+img_len, 0] = cap_len
|
||||||
row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
|
row_ids = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
|
||||||
col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
|
col_ids = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
|
||||||
position_ids[i, cap_len:cap_len+img_len, 1] = row_ids
|
position_ids[i, cap_len:cap_len+img_len, 1] = row_ids
|
||||||
position_ids[i, cap_len:cap_len+img_len, 2] = col_ids
|
position_ids[i, cap_len:cap_len+img_len, 2] = col_ids
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user