Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-11 06:40:48 +08:00)
bf16 selection for AMD
commit a79ccd625f
parent 35cf996b68
@@ -80,6 +80,7 @@ def create_parser() -> argparse.ArgumentParser:
     fp_group.add_argument("--force-fp32", action="store_true",
                           help="Force fp32 (If this makes your GPU work better please report it).")
     fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
+    fp_group.add_argument("--force-bf16", action="store_true", help="Force bf16.")

     fpunet_group = parser.add_mutually_exclusive_group()
     fpunet_group.add_argument("--bf16-unet", action="store_true",
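Note: the new --force-bf16 switch joins --force-fp32 and --force-fp16 in the same mutually exclusive group, so only one precision can be forced per run. A minimal stand-alone sketch of that behavior (not the repo's full create_parser; only the group and flag names mirror the hunk above):

import argparse

parser = argparse.ArgumentParser()
fp_group = parser.add_mutually_exclusive_group()
fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32.")
fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
fp_group.add_argument("--force-bf16", action="store_true", help="Force bf16.")

args = parser.parse_args(["--force-bf16"])
assert args.force_bf16 and not args.force_fp16 and not args.force_fp32
# Passing two of the flags together (e.g. --force-fp16 --force-bf16) makes
# argparse exit with a "not allowed with argument" error, because they share
# one mutually exclusive group.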
@@ -40,6 +40,7 @@ class Configuration(dict):
         force_upcast_attention (bool): Force upcasting of attention.
         force_fp32 (bool): Force using FP32 precision.
         force_fp16 (bool): Force using FP16 precision.
+        force_bf16 (bool): Force using BF16 precision.
         bf16_unet (bool): Use BF16 precision for UNet.
         fp16_unet (bool): Use FP16 precision for UNet.
         fp8_e4m3fn_unet (bool): Use FP8 precision (e4m3fn variant) for UNet.
@@ -110,6 +111,7 @@ class Configuration(dict):
         self.force_upcast_attention: bool = False
         self.force_fp32: bool = False
         self.force_fp16: bool = False
+        self.force_bf16: bool = False
         self.bf16_unet: bool = False
         self.fp16_unet: bool = False
         self.fp8_e4m3fn_unet: bool = False
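Note: Configuration is a dict subclass whose keys are also exposed as attributes; the two hunks above only add the force_bf16 default and its docstring entry. A simplified stand-in showing the general shape (assumption: the attribute/key syncing here is illustrative, not the class's actual implementation):

class PrecisionConfig(dict):
    """Toy stand-in for Configuration: attribute writes are mirrored into the dict."""

    def __init__(self):
        super().__init__()
        self.force_fp32: bool = False
        self.force_fp16: bool = False
        self.force_bf16: bool = False  # default added by this commit

    def __setattr__(self, name, value):
        self[name] = value             # keep the dict view in sync with attributes
        super().__setattr__(name, value)


cfg = PrecisionConfig()
cfg.force_bf16 = True
print(cfg["force_bf16"])  # True -- the dict view tracks the attribute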
@@ -189,6 +189,12 @@ def is_nvidia():
             return True
     return False

+def is_amd():
+    global cpu_state
+    if cpu_state == CPUState.GPU:
+        if torch.version.hip:
+            return True
+    return False

 ENABLE_PYTORCH_ATTENTION = False
 if args.use_pytorch_cross_attention:
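Note: is_amd() mirrors is_nvidia() but keys off torch.version.hip, which ROCm builds of PyTorch populate (CUDA builds populate torch.version.cuda instead; CPU-only builds typically have both set to None). A small probe along those lines (illustrative only; gpu_backend() is not a function in this repo):

import torch

def gpu_backend() -> str:
    # ROCm wheels report a HIP version string, CUDA wheels a CUDA version string.
    if torch.version.hip:
        return "amd-rocm"
    if torch.version.cuda:
        return "nvidia-cuda"
    return "cpu-or-other"

print(gpu_backend())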
@@ -198,7 +204,7 @@ if args.use_pytorch_cross_attention:
 VAE_DTYPE = torch.float32

 try:
-    if is_nvidia():
+    if is_nvidia() or is_amd():
         torch_version = torch.version.__version__
         if int(torch_version[0]) >= 2:
             if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
@@ -239,6 +245,7 @@ elif args.highvram or args.gpu_only:

 FORCE_FP32 = False
 FORCE_FP16 = False
+FORCE_BF16 = False
 if args.force_fp32:
     logging.info("Forcing FP32, if this improves things please report it.")
     FORCE_FP32 = True
@@ -247,6 +254,10 @@ if args.force_fp16 or cpu_state == CPUState.MPS:
     logging.info("Forcing FP16.")
     FORCE_FP16 = True

+if args.force_bf16:
+    logging.info("Force BF16")
+    FORCE_BF16 = True
+
 if lowvram_available:
     if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
         vram_state = set_vram_to
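Note: --force-bf16 only sets the FORCE_BF16 global; it does not verify that the device can actually run bfloat16 math. A quick sanity check one could pair with the flag (not part of this commit; torch.cuda.is_bf16_supported() is a standard PyTorch call, and on ROCm builds the torch.cuda namespace maps to HIP devices):

import torch

if torch.cuda.is_available() and not torch.cuda.is_bf16_supported():
    # Forcing bf16 on such a device would rely on emulation or fail outright.
    print("warning: --force-bf16 requested but the active GPU lacks bf16 support")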
@@ -835,7 +846,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
     if is_intel_xpu():
         return True

-    if torch.version.hip:
+    if is_amd():
         return True

     props = torch.cuda.get_device_properties("cuda")
@@ -872,6 +883,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):


 def should_use_bf16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
+    if FORCE_BF16:
+        return True
+
     if device is not None:
         if is_device_cpu(device):  # TODO ? bf16 works on CPU but is extremely slow
             return False
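Note: with this commit, should_use_bf16() honors FORCE_BF16 before any device heuristics, the same way FORCE_FP16 short-circuits should_use_fp16(). A condensed sketch of that precedence (assumption: choose_dtype() is illustrative only; the real selection in this module has many more cases):

import torch

def choose_dtype(force_fp32: bool, force_fp16: bool, force_bf16: bool) -> torch.dtype:
    # Forced precisions win over any hardware heuristic.
    if force_fp32:
        return torch.float32
    if force_bf16:   # new early return added by this commit
        return torch.bfloat16
    if force_fp16:
        return torch.float16
    # Otherwise the device heuristics would decide; fp32 is the safe default here.
    return torch.float32

print(choose_dtype(False, False, True))  # torch.bfloat16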