diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py
index 824daf5e6..6d0aed827 100644
--- a/comfy/ldm/flux/math.py
+++ b/comfy/ldm/flux/math.py
@@ -16,7 +16,7 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transforme
 
 def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
     assert dim % 2 == 0
-    if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled():
+    if not comfy.model_management.supports_fp64(pos.device):
         device = torch.device("cpu")
     else:
         device = pos.device
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 0eebf1ded..bcf1399c4 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -1732,6 +1732,21 @@ def supports_mxfp8_compute(device=None):
     return True
 
 
+def supports_fp64(device=None):
+    if is_device_mps(device):
+        return False
+
+    if is_intel_xpu():
+        return False
+
+    if is_directml_enabled():
+        return False
+
+    if is_ixuca():
+        return False
+
+    return True
+
 def extended_fp16_support():
     # TODO: check why some models work with fp16 on newer torch versions but not on older
     if torch_version_numeric < (2, 7):
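
For context, `rope()` builds its frequency table in `float64` for numerical precision, so backends that lack fp64 support have to fall back to the CPU for that step. Below is a minimal sketch of the pattern after this change: the device-selection lines mirror the diff, while `rope_sketch` and its abridged body (the `linspace`/`einsum` portion and the final cast back to `pos.device`) are illustrative assumptions, not a verbatim copy of the repository's code:

```python
import torch
import comfy.model_management

def rope_sketch(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
    assert dim % 2 == 0
    # Backends without fp64 (MPS, Intel XPU, DirectML, ixuca) compute the
    # frequency table on the CPU; everything else stays on pos.device.
    if not comfy.model_management.supports_fp64(pos.device):
        device = torch.device("cpu")
    else:
        device = pos.device
    # Frequencies are generated in float64 for precision...
    scale = torch.linspace(0, (dim - 2) / dim, steps=dim // 2, dtype=torch.float64, device=device)
    omega = 1.0 / (theta ** scale)
    out = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float64, device=device), omega)
    # ...then the result is moved back to the original device in float32.
    return out.to(dtype=torch.float32, device=pos.device)
```

Centralizing the backend checks in `supports_fp64()` means a new fp64-less backend (here, ixuca) needs only one line in `model_management.py` rather than an edit at every call site that previously chained `is_device_mps(...) or is_intel_xpu() or is_directml_enabled()`.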