From 9519e2d49d8208a1b49fe60d238bad53d13cc556 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kacper=20Michaj=C5=82ow?=
Date: Mon, 26 May 2025 14:13:08 +0200
Subject: [PATCH 1/2] Revert "Disable pytorch attention in VAE for AMD."

The VAE crashes for big sizes even without pytorch attention, and for
reasonable sizes pytorch attention is significantly faster.

This reverts commit 1cd6cd608086a8ff8789b747b8d4f8b9273e576e.
---
 comfy/ldm/modules/diffusionmodules/model.py | 2 +-
 comfy/model_management.py                   | 5 -----
 2 files changed, 1 insertion(+), 6 deletions(-)

diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py
index 4245eedca..a2f984b00 100644
--- a/comfy/ldm/modules/diffusionmodules/model.py
+++ b/comfy/ldm/modules/diffusionmodules/model.py
@@ -297,7 +297,7 @@ def vae_attention():
     if model_management.xformers_enabled_vae():
         logging.info("Using xformers attention in VAE")
         return xformers_attention
-    elif model_management.pytorch_attention_enabled_vae():
+    elif model_management.pytorch_attention_enabled():
         logging.info("Using pytorch attention in VAE")
         return pytorch_attention
     else:

diff --git a/comfy/model_management.py b/comfy/model_management.py
index a2c318ec3..c6b2c3ac8 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -1114,11 +1114,6 @@ def pytorch_attention_enabled():
     global ENABLE_PYTORCH_ATTENTION
     return ENABLE_PYTORCH_ATTENTION
 
-def pytorch_attention_enabled_vae():
-    if is_amd():
-        return False # enabling pytorch attention on AMD currently causes crash when doing high res
-    return pytorch_attention_enabled()
-
 def pytorch_attention_flash_attention():
     global ENABLE_PYTORCH_ATTENTION
     if ENABLE_PYTORCH_ATTENTION:

From 02d15cc85f5f34322938786788ac1ac2b6cf5961 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kacper=20Michaj=C5=82ow?=
Date: Mon, 26 May 2025 14:16:38 +0200
Subject: [PATCH 2/2] Enable pytorch attention by default on AMD gfx1200

---
 comfy/model_management.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index c6b2c3ac8..7e285aa36 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -347,8 +347,8 @@ try:
             if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
                 ENABLE_PYTORCH_ATTENTION = True
             if rocm_version >= (7, 0):
-                if any((a in arch) for a in ["gfx1201"]):
-                    ENABLE_PYTORCH_ATTENTION = True
+                if any((a in arch) for a in ["gfx1200", "gfx1201"]):
+                    ENABLE_PYTORCH_ATTENTION = True
         if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
             if any((a in arch) for a in ["gfx1200", "gfx1201", "gfx950"]): # TODO: more arches, "gfx942" gives error on pytorch nightly 2.10 1013 rocm7.0
                 SUPPORT_FP8_OPS = True