Improve AMD performance. (#10302)

I honestly have no idea why this improves things but it does.
comfyanonymous 2025-10-11 21:28:01 -07:00 committed by GitHub
parent 84e9ce32c6
commit a125cd84b0

@@ -332,6 +332,7 @@ except:
 SUPPORT_FP8_OPS = args.supports_fp8_compute
 try:
     if is_amd():
+        torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
         try:
             rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2]))
         except:
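For reference, a minimal timing sketch (not part of this commit) for checking what the flag does on a given machine: on ROCm builds of PyTorch, torch.backends.cudnn gates the MIOpen backend, so disabling it makes convolutions fall back to PyTorch's native kernels. The tensor sizes and iteration counts below are arbitrary.

    # Compare a convolution with the cuDNN/MIOpen backend enabled vs. disabled.
    import time
    import torch

    def time_conv(enabled, iters=20):
        torch.backends.cudnn.enabled = enabled
        conv = torch.nn.Conv2d(128, 128, 3, padding=1).to("cuda", torch.float16)
        x = torch.randn(1, 128, 512, 512, device="cuda", dtype=torch.float16)
        for _ in range(3):            # warmup so lazy init/autotuning is excluded
            conv(x)
        torch.cuda.synchronize()
        start = time.time()
        for _ in range(iters):
            conv(x)
        torch.cuda.synchronize()
        return (time.time() - start) / iters

    if torch.cuda.is_available():     # ROCm builds also report CUDA here
        print("backend enabled:  %.4fs/iter" % time_conv(True))
        print("backend disabled: %.4fs/iter" % time_conv(False))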
@@ -925,11 +926,7 @@ def vae_dtype(device=None, allowed_dtypes=[]):
         if d == torch.float16 and should_use_fp16(device):
             return d
-        # NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32
-        # slowness still a problem on pytorch nightly 2.9.0.dev20250720+rocm6.4 tested on RDNA3
-        # also a problem on RDNA4 except fp32 is also slow there.
-        # This is due to large bf16 convolutions being extremely slow.
-        if d == torch.bfloat16 and ((not is_amd()) or amd_min_version(device, min_rdna_version=4)) and should_use_bf16(device):
+        if d == torch.bfloat16 and should_use_bf16(device):
             return d
     return torch.float32
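A standalone sketch of the simplified fallthrough in vae_dtype after this change, presumably safe to drop the AMD guard now that the backend change above addresses the slow bf16 convolutions. The should_use_fp16/should_use_bf16 stubs below are placeholders for the real capability checks in comfy/model_management.py, hard-coded so the control flow runs on its own.

    import torch

    def should_use_fp16(device=None):
        return False  # stub: the real check inspects the GPU's fp16 support

    def should_use_bf16(device=None):
        return True   # stub: the real check inspects the GPU's bf16 support

    def vae_dtype(device=None, allowed_dtypes=[]):
        for d in allowed_dtypes:
            if d == torch.float16 and should_use_fp16(device):
                return d
            # The AMD-specific RDNA-version guard was dropped; bf16 is now
            # taken whenever the device supports it.
            if d == torch.bfloat16 and should_use_bf16(device):
                return d
        return torch.float32

    print(vae_dtype(allowed_dtypes=[torch.float16, torch.bfloat16]))  # torch.bfloat16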