diff --git a/comfy/model_management.py b/comfy/model_management.py index aca8af999..06e2b5934 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -466,7 +466,8 @@ def vae_dtype(): def get_autocast_device(dev): if hasattr(dev, 'type'): - return dev.type + return dev.type if not (dev.type == "mps" and should_use_fp16( + dev, prioritize_performance=False)) else "cpu" return "cuda"