fix: use float32 for Qwen image rotary embeddings

Fixes the error when using torch.compile with bfloat16 inference:
'No backend can handle apply_rope1: eager: freqs_cis: dtype torch.bfloat16 not in {torch.float32}'

The apply_rope1 function uses addcmul_ in-place operations that fail
under torch.compile when freqs_cis is bfloat16. This restores the
behavior from commit 4cd88186, which was inadvertently reverted in
commit c4a6b389.
Author: Hunter Senft-Grupp
Date:   2026-01-06 22:42:55 -05:00
Parent: edee33f55e
Commit: 68df63096b


@@ -473,7 +473,7 @@ class QwenImageTransformer2DModel(nn.Module):
         txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
         txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
         ids = torch.cat((txt_ids, img_ids), dim=1)
-        image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
+        image_rotary_emb = self.pe_embedder(ids).to(torch.float32).contiguous()
         del ids, txt_ids, img_ids
         hidden_states = self.img_in(hidden_states)
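
Note: for context, a minimal sketch of the dtype pattern this fix relies on, assuming a Flux-style pairwise rotation. This is not the actual apply_rope1 kernel (which uses in-place addcmul_ and a registered backend that requires float32 freqs_cis), but the dtype handling is the same: the rotary table stays in float32, the bfloat16 activations are upcast for the rotation, and the result is cast back to the activation dtype.

    import torch

    def apply_rope_sketch(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
        # Hypothetical helper, not ComfyUI's apply_rope1.
        # freqs_cis: float32 rotary table, shape (..., head_dim // 2, 2, 2)
        # x: query/key activations (e.g. bfloat16), shape (..., head_dim)
        orig_dtype = x.dtype
        # Upcast and view the last dim as (head_dim // 2, 1, 2) real/imag pairs.
        x_ = x.to(torch.float32).reshape(*x.shape[:-1], -1, 1, 2)
        # Rotate each pair with the float32 rotation matrices.
        out = freqs_cis[..., 0] * x_[..., 0] + freqs_cis[..., 1] * x_[..., 1]
        # Cast back so downstream layers keep running in the inference dtype.
        return out.reshape(*x.shape).to(orig_dtype)

    # Example: bfloat16 activations combined with a float32 rotary table.
    q = torch.randn(1, 8, 128, 64, dtype=torch.bfloat16)
    freqs = torch.randn(1, 1, 128, 32, 2, 2, dtype=torch.float32)
    assert apply_rope_sketch(q, freqs).dtype == torch.bfloat16

Keeping only the rotary table in float32 adds negligible memory while satisfying the dtype constraint the compiled op enforces, which is why the one-line cast above is sufficient.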