Fix fp16 autocast RuntimeError on Macs.

This commit is contained in:
shiimizu 2023-09-03 02:16:09 -07:00
parent a74c5dbf37
commit 4b471cb8c2

View File

@ -466,7 +466,8 @@ def vae_dtype():
def get_autocast_device(dev):
if hasattr(dev, 'type'):
return dev.type
return dev.type if not (dev.type == "mps" and should_use_fp16(
dev, prioritize_performance=False)) else "cpu"
return "cuda"