Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-02-08 20:42:32 +08:00)
fix: use float32 for Qwen image rotary embeddings
Fixes the error when using torch.compile with bfloat16 inference:
'No backend can handle apply_rope1: eager: freqs_cis: dtype torch.bfloat16 not in {torch.float32}'
The apply_rope1 function uses addcmul_ in-place operations that fail
under torch.compile when freqs_cis is bfloat16. This restores the
behavior from commit 4cd88186, which was inadvertently reverted in
commit c4a6b389.
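
To make the dtype issue concrete, here is a minimal, self-contained sketch of an apply_rope-style rotation. It is not ComfyUI's actual apply_rope1; the function names, table layout, and shapes are illustrative assumptions. It mixes bfloat16 activations with a float32 rotary table using the in-place addcmul_ pattern the message describes:

# Sketch only (not ComfyUI's apply_rope1; rope_table and apply_rope_sketch are hypothetical):
# rotate bfloat16 activations against a float32 cos/sin table with in-place addcmul_.
import torch

def rope_table(seq_len: int, dim: int, theta: float = 10000.0) -> torch.Tensor:
    # (seq_len, dim // 2, 2) table of cos/sin terms, built in float32,
    # mirroring the commit's cast of the embedder output to torch.float32.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    angles = torch.outer(torch.arange(seq_len, dtype=torch.float32), freqs)
    return torch.stack((angles.cos(), angles.sin()), dim=-1)

def apply_rope_sketch(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # x: (batch, seq, dim) activations, bfloat16 is fine.
    # freqs_cis: (seq, dim // 2, 2) rotary table, expected in float32.
    x_ = x.float().reshape(*x.shape[:-1], -1, 2)
    cos, sin = freqs_cis[..., 0], freqs_cis[..., 1]
    real = x_[..., 0] * cos
    real.addcmul_(x_[..., 1], sin, value=-1.0)  # in-place: real -= x_imag * sin
    imag = x_[..., 1] * cos
    imag.addcmul_(x_[..., 0], sin)              # in-place: imag += x_real * sin
    return torch.stack((real, imag), dim=-1).reshape(*x.shape).type_as(x)

x = torch.randn(1, 8, 64, dtype=torch.bfloat16)
emb = rope_table(8, 64)                  # stays float32, as after this commit
print(apply_rope_sketch(x, emb).dtype)   # torch.bfloat16: output matches x, math done in float32

The diff below applies the same idea at the call site: pe_embedder(ids) is cast to torch.float32 instead of x.dtype, so the rope path always receives a float32 freqs_cis regardless of the inference dtype.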
parent: edee33f55e
commit: 68df63096b
@@ -473,7 +473,7 @@ class QwenImageTransformer2DModel(nn.Module):
         txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
         txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
         ids = torch.cat((txt_ids, img_ids), dim=1)
-        image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
+        image_rotary_emb = self.pe_embedder(ids).to(torch.float32).contiguous()
         del ids, txt_ids, img_ids

         hidden_states = self.img_in(hidden_states)