mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-18 14:32:49 +08:00
Fix ernie on devices that don't support fp64. (#13414)
This commit is contained in:
parent
7ce3f64c78
commit
cb0bbde402
@ -15,7 +15,7 @@ def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
|
|||||||
|
|
||||||
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim
|
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim
|
||||||
omega = 1.0 / (theta**scale)
|
omega = 1.0 / (theta**scale)
|
||||||
out = torch.einsum("...n,d->...nd", pos, omega)
|
out = torch.einsum("...n,d->...nd", pos.to(device), omega)
|
||||||
out = torch.stack([torch.cos(out), torch.sin(out)], dim=0)
|
out = torch.stack([torch.cos(out), torch.sin(out)], dim=0)
|
||||||
return out.to(dtype=torch.float32, device=pos.device)
|
return out.to(dtype=torch.float32, device=pos.device)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user