bf16 selection for AMD

This commit is contained in:
doctorpangloss 2024-05-22 22:45:15 -07:00
parent 35cf996b68
commit a79ccd625f
3 changed files with 19 additions and 2 deletions

View File

@ -80,6 +80,7 @@ def create_parser() -> argparse.ArgumentParser:
fp_group.add_argument("--force-fp32", action="store_true",
help="Force fp32 (If this makes your GPU work better please report it).")
fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
fp_group.add_argument("--force-bf16", action="store_true", help="Force bf16.")
fpunet_group = parser.add_mutually_exclusive_group()
fpunet_group.add_argument("--bf16-unet", action="store_true",

View File

@ -40,6 +40,7 @@ class Configuration(dict):
force_upcast_attention (bool): Force upcasting of attention.
force_fp32 (bool): Force using FP32 precision.
force_fp16 (bool): Force using FP16 precision.
force_bf16 (bool): Force using BF16 precision.
bf16_unet (bool): Use BF16 precision for UNet.
fp16_unet (bool): Use FP16 precision for UNet.
fp8_e4m3fn_unet (bool): Use FP8 precision (e4m3fn variant) for UNet.
@ -110,6 +111,7 @@ class Configuration(dict):
self.force_upcast_attention: bool = False
self.force_fp32: bool = False
self.force_fp16: bool = False
self.force_bf16: bool = False
self.bf16_unet: bool = False
self.fp16_unet: bool = False
self.fp8_e4m3fn_unet: bool = False

View File

@ -189,6 +189,12 @@ def is_nvidia():
return True
return False
def is_amd():
global cpu_state
if cpu_state == CPUState.GPU:
if torch.version.hip:
return True
return False
ENABLE_PYTORCH_ATTENTION = False
if args.use_pytorch_cross_attention:
@ -198,7 +204,7 @@ if args.use_pytorch_cross_attention:
VAE_DTYPE = torch.float32
try:
if is_nvidia():
if is_nvidia() or is_amd():
torch_version = torch.version.__version__
if int(torch_version[0]) >= 2:
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
@ -239,6 +245,7 @@ elif args.highvram or args.gpu_only:
FORCE_FP32 = False
FORCE_FP16 = False
FORCE_BF16 = False
if args.force_fp32:
logging.info("Forcing FP32, if this improves things please report it.")
FORCE_FP32 = True
@ -247,6 +254,10 @@ if args.force_fp16 or cpu_state == CPUState.MPS:
logging.info("Forcing FP16.")
FORCE_FP16 = True
if args.force_bf16:
logging.info("Force BF16")
FORCE_BF16 = True
if lowvram_available:
if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
vram_state = set_vram_to
@ -835,7 +846,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
if is_intel_xpu():
return True
if torch.version.hip:
if is_amd():
return True
props = torch.cuda.get_device_properties("cuda")
@ -872,6 +883,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
def should_use_bf16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
if FORCE_BF16:
return True
if device is not None:
if is_device_cpu(device): # TODO ? bf16 works on CPU but is extremely slow
return False