From 7f374e42c833c69c71605507b90f79cc26d14a71 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 31 Oct 2025 12:41:40 -0700 Subject: [PATCH] ScaleROPE now works on Lumina models. (#10578) --- comfy/ldm/lumina/model.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index f87d98ac0..b4494a51d 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -522,7 +522,7 @@ class NextDiT(nn.Module): max_cap_len = max(l_effective_cap_len) max_img_len = max(l_effective_img_len) - position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device) + position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.float32, device=device) for i in range(bsz): cap_len = l_effective_cap_len[i] @@ -531,10 +531,22 @@ class NextDiT(nn.Module): H_tokens, W_tokens = H // pH, W // pW assert H_tokens * W_tokens == img_len - position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device) + rope_options = transformer_options.get("rope_options", None) + h_scale = 1.0 + w_scale = 1.0 + h_start = 0 + w_start = 0 + if rope_options is not None: + h_scale = rope_options.get("scale_y", 1.0) + w_scale = rope_options.get("scale_x", 1.0) + + h_start = rope_options.get("shift_y", 0.0) + w_start = rope_options.get("shift_x", 0.0) + + position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.float32, device=device) position_ids[i, cap_len:cap_len+img_len, 0] = cap_len - row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten() - col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten() + row_ids = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten() + col_ids = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten() position_ids[i, cap_len:cap_len+img_len, 1] = row_ids position_ids[i, cap_len:cap_len+img_len, 2] = col_ids