Mirror of https://github.com/comfyanonymous/ComfyUI.git
bf16 selection for AMD
commit a79ccd625f
parent 35cf996b68
@@ -80,6 +80,7 @@ def create_parser() -> argparse.ArgumentParser:
     fp_group.add_argument("--force-fp32", action="store_true",
                           help="Force fp32 (If this makes your GPU work better please report it).")
     fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
+    fp_group.add_argument("--force-bf16", action="store_true", help="Force bf16.")
 
     fpunet_group = parser.add_mutually_exclusive_group()
     fpunet_group.add_argument("--bf16-unet", action="store_true",
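For readers unfamiliar with the flag group above, here is a minimal standalone sketch of how the new --force-bf16 switch behaves once parsed. The flag and group names mirror the hunk; the parser itself and the sample invocation are illustrative, not the project's actual CLI setup.

import argparse

# Illustrative parser mirroring the precision flags from the hunk above;
# not the repository's actual create_parser().
parser = argparse.ArgumentParser()
fp_group = parser.add_mutually_exclusive_group()
fp_group.add_argument("--force-fp32", action="store_true",
                      help="Force fp32 (If this makes your GPU work better please report it).")
fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
fp_group.add_argument("--force-bf16", action="store_true", help="Force bf16.")

args = parser.parse_args(["--force-bf16"])
print(args.force_bf16)  # True
print(args.force_fp16)  # False
# Passing --force-fp16 and --force-bf16 together exits with a usage error,
# since the flags share a mutually exclusive group.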
@@ -40,6 +40,7 @@ class Configuration(dict):
         force_upcast_attention (bool): Force upcasting of attention.
         force_fp32 (bool): Force using FP32 precision.
         force_fp16 (bool): Force using FP16 precision.
+        force_bf16 (bool): Force using BF16 precision.
         bf16_unet (bool): Use BF16 precision for UNet.
         fp16_unet (bool): Use FP16 precision for UNet.
         fp8_e4m3fn_unet (bool): Use FP8 precision (e4m3fn variant) for UNet.
@@ -110,6 +111,7 @@ class Configuration(dict):
         self.force_upcast_attention: bool = False
         self.force_fp32: bool = False
         self.force_fp16: bool = False
+        self.force_bf16: bool = False
         self.bf16_unet: bool = False
         self.fp16_unet: bool = False
         self.fp8_e4m3fn_unet: bool = False
@@ -189,6 +189,12 @@ def is_nvidia():
             return True
     return False
 
+def is_amd():
+    global cpu_state
+    if cpu_state == CPUState.GPU:
+        if torch.version.hip:
+            return True
+    return False
 
 ENABLE_PYTORCH_ATTENTION = False
 if args.use_pytorch_cross_attention:
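As context for the new helper, a standalone snippet (not repository code; the helper name below is made up for illustration) showing the signal is_amd() relies on: ROCm builds of PyTorch expose a version string in torch.version.hip, while CUDA and CPU-only builds expose None.

import torch

def looks_like_amd_build():
    # Hypothetical helper for illustration: torch.version.hip is a version
    # string on ROCm (AMD) builds of PyTorch and None otherwise.
    return torch.version.hip is not None

print("ROCm (AMD) build:", looks_like_amd_build())
print("CUDA build:", torch.version.cuda is not None)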
@@ -198,7 +204,7 @@ if args.use_pytorch_cross_attention:
 VAE_DTYPE = torch.float32
 
 try:
-    if is_nvidia():
+    if is_nvidia() or is_amd():
         torch_version = torch.version.__version__
         if int(torch_version[0]) >= 2:
             if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
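The attention path being enabled for AMD above ultimately gates ComfyUI's use of PyTorch's built-in scaled dot-product attention. A standalone sketch of that call on dummy tensors (PyTorch 2.x assumed; none of this is repository code):

import torch
import torch.nn.functional as F

# Random (batch, heads, tokens, head_dim) tensors; shapes are arbitrary.
q = torch.randn(1, 8, 128, 64)
k = torch.randn(1, 8, 128, 64)
v = torch.randn(1, 8, 128, 64)

# Available since PyTorch 2.0; dispatches to fused kernels where supported.
out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([1, 8, 128, 64])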
@@ -239,6 +245,7 @@ elif args.highvram or args.gpu_only:
 
 FORCE_FP32 = False
 FORCE_FP16 = False
+FORCE_BF16 = False
 if args.force_fp32:
     logging.info("Forcing FP32, if this improves things please report it.")
     FORCE_FP32 = True
@@ -247,6 +254,10 @@ if args.force_fp16 or cpu_state == CPUState.MPS:
     logging.info("Forcing FP16.")
     FORCE_FP16 = True
 
+if args.force_bf16:
+    logging.info("Force BF16")
+    FORCE_BF16 = True
+
 if lowvram_available:
     if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
         vram_state = set_vram_to
@@ -835,7 +846,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
     if is_intel_xpu():
         return True
 
-    if torch.version.hip:
+    if is_amd():
         return True
 
     props = torch.cuda.get_device_properties("cuda")
@@ -872,6 +883,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
 
 
 def should_use_bf16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
+    if FORCE_BF16:
+        return True
+
     if device is not None:
         if is_device_cpu(device): # TODO ? bf16 works on CPU but is extremely slow
             return False
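Finally, a short standalone illustration of what a forced-bf16 path amounts to at inference time, using autocast on a toy module; this is not how ComfyUI threads the dtype through its model loading, just the underlying idea.

import torch

model = torch.nn.Linear(16, 16)
x = torch.randn(4, 16)

if torch.cuda.is_available():  # True on both CUDA and ROCm builds
    model, x = model.cuda(), x.cuda()
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        y = model(x)
else:
    # bf16 also runs on CPU but, per the TODO in the hunk above, is very slow.
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        y = model(x)

print(y.dtype)  # torch.bfloat16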