diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index ef57a50c7..912393da8 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -80,6 +80,7 @@ def create_parser() -> argparse.ArgumentParser:
     fp_group.add_argument("--force-fp32", action="store_true",
                           help="Force fp32 (If this makes your GPU work better please report it).")
     fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
+    fp_group.add_argument("--force-bf16", action="store_true", help="Force bf16.")
 
     fpunet_group = parser.add_mutually_exclusive_group()
     fpunet_group.add_argument("--bf16-unet", action="store_true",
diff --git a/comfy/cli_args_types.py b/comfy/cli_args_types.py
index 974663f06..860a3c800 100644
--- a/comfy/cli_args_types.py
+++ b/comfy/cli_args_types.py
@@ -40,6 +40,7 @@ class Configuration(dict):
         force_upcast_attention (bool): Force upcasting of attention.
         force_fp32 (bool): Force using FP32 precision.
         force_fp16 (bool): Force using FP16 precision.
+        force_bf16 (bool): Force using BF16 precision.
         bf16_unet (bool): Use BF16 precision for UNet.
         fp16_unet (bool): Use FP16 precision for UNet.
         fp8_e4m3fn_unet (bool): Use FP8 precision (e4m3fn variant) for UNet.
@@ -110,6 +111,7 @@ class Configuration(dict):
         self.force_upcast_attention: bool = False
         self.force_fp32: bool = False
         self.force_fp16: bool = False
+        self.force_bf16: bool = False
         self.bf16_unet: bool = False
         self.fp16_unet: bool = False
         self.fp8_e4m3fn_unet: bool = False
diff --git a/comfy/model_management.py b/comfy/model_management.py
index d57d4efe2..5b68175f8 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -189,6 +189,12 @@ def is_nvidia():
             return True
     return False
 
+def is_amd():
+    global cpu_state
+    if cpu_state == CPUState.GPU:
+        if torch.version.hip:
+            return True
+    return False
 
 ENABLE_PYTORCH_ATTENTION = False
 if args.use_pytorch_cross_attention:
@@ -198,7 +204,7 @@ if args.use_pytorch_cross_attention:
 VAE_DTYPE = torch.float32
 
 try:
-    if is_nvidia():
+    if is_nvidia() or is_amd():
         torch_version = torch.version.__version__
         if int(torch_version[0]) >= 2:
             if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
@@ -239,6 +245,7 @@ elif args.highvram or args.gpu_only:
 
 FORCE_FP32 = False
 FORCE_FP16 = False
+FORCE_BF16 = False
 if args.force_fp32:
     logging.info("Forcing FP32, if this improves things please report it.")
     FORCE_FP32 = True
@@ -247,6 +254,10 @@ if args.force_fp16 or cpu_state == CPUState.MPS:
     logging.info("Forcing FP16.")
     FORCE_FP16 = True
 
+if args.force_bf16:
+    logging.info("Force BF16")
+    FORCE_BF16 = True
+
 if lowvram_available:
     if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
         vram_state = set_vram_to
@@ -835,7 +846,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
     if is_intel_xpu():
         return True
 
-    if torch.version.hip:
+    if is_amd():
         return True
 
     props = torch.cuda.get_device_properties("cuda")
@@ -872,6 +883,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
 
 
 def should_use_bf16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
+    if FORCE_BF16:
+        return True
+
     if device is not None:
         if is_device_cpu(device): # TODO ? bf16 works on CPU but is extremely slow
             return False
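
For reviewers, a quick way to exercise the new flag without launching the app. This is a minimal sketch, not part of the PR: it assumes `comfy.cli_args.create_parser()` can be imported on its own, and that `fp_group` is the same mutually exclusive group that already holds `--force-fp32`/`--force-fp16` (as in upstream ComfyUI).

```python
# Hypothetical smoke test for --force-bf16; not included in this PR.
from comfy.cli_args import create_parser

parser = create_parser()

# The new flag should parse and land on the force_bf16 attribute.
args = parser.parse_args(["--force-bf16"])
assert args.force_bf16 is True

# Assuming fp_group is a mutually exclusive group, combining the new flag
# with --force-fp16 should be rejected by argparse (it exits via SystemExit).
try:
    parser.parse_args(["--force-bf16", "--force-fp16"])
except SystemExit:
    print("mutually exclusive with --force-fp16, as expected")
```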