diff --git a/comfy/model_management.py b/comfy/model_management.py index c5b817b62..146c00925 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -332,6 +332,7 @@ except: SUPPORT_FP8_OPS = args.supports_fp8_compute try: if is_amd(): + torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD try: rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2])) except: @@ -925,11 +926,7 @@ def vae_dtype(device=None, allowed_dtypes=[]): if d == torch.float16 and should_use_fp16(device): return d - # NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32 - # slowness still a problem on pytorch nightly 2.9.0.dev20250720+rocm6.4 tested on RDNA3 - # also a problem on RDNA4 except fp32 is also slow there. - # This is due to large bf16 convolutions being extremely slow. - if d == torch.bfloat16 and ((not is_amd()) or amd_min_version(device, min_rdna_version=4)) and should_use_bf16(device): + if d == torch.bfloat16 and should_use_bf16(device): return d return torch.float32