diff --git a/comfy/cli_args.py b/comfy/cli_args.py index a906ff1c0..10c142e67 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -130,7 +130,7 @@ parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.") parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.") -parser.add_argument("--fast", action="store_true", help="Enable some untested and potentially quality deteriorating optimizations.") +parser.add_argument("--fast", metavar="number", type=int, const=99, default=0, nargs="?", help="Enable some untested and potentially quality deteriorating optimizations. You can pass a number from 0 to 10 for a bigger speed vs quality tradeoff. Using --fast with no number means maximum speed. 2 or larger enables fp16 accumulation, 5 or larger enables fp8 matrix multiplication.") parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.") parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.") diff --git a/comfy/controlnet.py b/comfy/controlnet.py index ee29251b9..ceb24c852 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -418,10 +418,7 @@ def controlnet_config(sd, model_options={}): weight_dtype = comfy.utils.weight_dtype(sd) supported_inference_dtypes = list(model_config.supported_inference_dtypes) - if weight_dtype is not None: - supported_inference_dtypes.append(weight_dtype) - - unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes) + unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes, weight_dtype=weight_dtype) load_device = comfy.model_management.get_torch_device() manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device) @@ -689,10 +686,7 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}): if supported_inference_dtypes is None: supported_inference_dtypes = [comfy.model_management.unet_dtype()] - if weight_dtype is not None: - supported_inference_dtypes.append(weight_dtype) - - unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes) + unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes, weight_dtype=weight_dtype) load_device = comfy.model_management.get_torch_device() diff --git a/comfy/model_management.py b/comfy/model_management.py index a9e10bb46..19a204cb2 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -281,9 +281,10 @@ except: PRIORITIZE_FP16 = False # TODO: remove and replace with something that shows exactly which dtype is faster than the other try: - if is_nvidia() and args.fast: + if is_nvidia() and args.fast >= 2: torch.backends.cuda.matmul.allow_fp16_accumulation = True PRIORITIZE_FP16 = True # TODO: limit to cards where it actually boosts performance + logging.info("Enabled fp16 accumulation.") except: pass @@ -675,7 +676,7 @@ def unet_inital_load_device(parameters, dtype): def maximum_vram_for_weights(device=None): return (get_total_memory(device) * 0.88 - minimum_inference_memory()) -def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]): +def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32], weight_dtype=None): if model_params < 0: model_params = 1000000000000000000000 if args.fp32_unet: @@ -693,10 +694,8 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor fp8_dtype = None try: - for dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: - if dtype in supported_dtypes: - fp8_dtype = dtype - break + if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: + fp8_dtype = weight_dtype except: pass @@ -708,7 +707,7 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor if model_params * 2 > free_model_memory: return fp8_dtype - if PRIORITIZE_FP16: + if PRIORITIZE_FP16 or weight_dtype == torch.float16: if torch.float16 in supported_dtypes and should_use_fp16(device=device, model_params=model_params): return torch.float16 @@ -744,6 +743,9 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.flo return None fp16_supported = should_use_fp16(inference_device, prioritize_performance=True) + if PRIORITIZE_FP16 and fp16_supported and torch.float16 in supported_dtypes: + return torch.float16 + for dt in supported_dtypes: if dt == torch.float16 and fp16_supported: return torch.float16 diff --git a/comfy/ops.py b/comfy/ops.py index 30014477e..905ea90f6 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -360,7 +360,7 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_ if scaled_fp8 is not None: return scaled_fp8_ops(fp8_matrix_mult=fp8_compute, scale_input=True, override_dtype=scaled_fp8) - if fp8_compute and (fp8_optimizations or args.fast) and not disable_fast_fp8: + if fp8_compute and (fp8_optimizations or args.fast >= 5) and not disable_fast_fp8: return fp8_ops if compute_dtype is None or weight_dtype == compute_dtype: diff --git a/comfy/sd.py b/comfy/sd.py index 640253b09..21913cf3e 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -896,14 +896,14 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c return None unet_weight_dtype = list(model_config.supported_inference_dtypes) - if weight_dtype is not None and model_config.scaled_fp8 is None: - unet_weight_dtype.append(weight_dtype) + if model_config.scaled_fp8 is not None: + weight_dtype = None model_config.custom_operations = model_options.get("custom_operations", None) unet_dtype = model_options.get("dtype", model_options.get("weight_dtype", None)) if unet_dtype is None: - unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype) + unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype, weight_dtype=weight_dtype) manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) @@ -994,11 +994,11 @@ def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffuse offload_device = model_management.unet_offload_device() unet_weight_dtype = list(model_config.supported_inference_dtypes) - if weight_dtype is not None and model_config.scaled_fp8 is None: - unet_weight_dtype.append(weight_dtype) + if model_config.scaled_fp8 is not None: + weight_dtype = None if dtype is None: - unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype) + unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype, weight_dtype=weight_dtype) else: unet_dtype = dtype