From e1474150de36b5b6477ce42c2a2801577ad42fff Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Fri, 7 Mar 2025 04:37:58 -0500
Subject: [PATCH 1/5] Support fp8_scaled diffusion models that don't use fp8 matrix mult.

---
 comfy/model_base.py      | 2 +-
 comfy/model_detection.py | 4 ++++
 comfy/ops.py             | 4 +++-
 3 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/comfy/model_base.py b/comfy/model_base.py
index a304c58bd..2fa1ee911 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -108,7 +108,7 @@ class BaseModel(torch.nn.Module):
 
         if not unet_config.get("disable_unet_model_creation", False):
             if model_config.custom_operations is None:
-                fp8 = model_config.optimizations.get("fp8", model_config.scaled_fp8 is not None)
+                fp8 = model_config.optimizations.get("fp8", False)
                 operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
             else:
                 operations = model_config.custom_operations

diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index 1aef549f4..403da5855 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -471,6 +471,10 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
             model_config.scaled_fp8 = scaled_fp8_weight.dtype
             if model_config.scaled_fp8 == torch.float32:
                 model_config.scaled_fp8 = torch.float8_e4m3fn
+            if scaled_fp8_weight.nelement() == 2:
+                model_config.optimizations["fp8"] = False
+            else:
+                model_config.optimizations["fp8"] = True
 
     return model_config
 
diff --git a/comfy/ops.py b/comfy/ops.py
index 358c6ec60..3303c6fcd 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -17,6 +17,7 @@
 """
 
 import torch
+import logging
 import comfy.model_management
 from comfy.cli_args import args, PerformanceFeature
 import comfy.float
@@ -308,6 +309,7 @@ class fp8_ops(manual_cast):
         return torch.nn.functional.linear(input, weight, bias)
 
 def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
+    logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
     class scaled_fp8_op(manual_cast):
         class Linear(manual_cast.Linear):
             def __init__(self, *args, **kwargs):
@@ -358,7 +360,7 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
 def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
     fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
     if scaled_fp8 is not None:
-        return scaled_fp8_ops(fp8_matrix_mult=fp8_compute, scale_input=True, override_dtype=scaled_fp8)
+        return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=True, override_dtype=scaled_fp8)
 
     if (
         fp8_compute and

From 70e15fd743e85554f907cef164703fce1715cd7d Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Fri, 7 Mar 2025 04:49:20 -0500
Subject: [PATCH 2/5] No need for scale_input when fp8 matrix mult is disabled.
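
Taken together with the previous patch, a scaled-fp8 checkpoint now carries a marker tensor whose element count decides whether the fast fp8 paths are used at all: a two-element `scaled_fp8` weight disables both fp8 matrix mult and input scaling, while any other size leaves them enabled (hardware permitting). A minimal runnable sketch of the combined gating; `supports_fp8_compute` here is an illustrative stand-in for `comfy.model_management.supports_fp8_compute`:

```python
import torch

def wants_fp8_optimizations(scaled_fp8_weight: torch.Tensor) -> bool:
    # Detection rule from patch 1: a two-element marker tensor means
    # "weights are fp8-scaled, but don't use fp8 matrix mult".
    return scaled_fp8_weight.nelement() != 2

def scaled_fp8_settings(scaled_fp8_weight: torch.Tensor, supports_fp8_compute: bool) -> dict:
    # Gating from patches 1-2: matrix mult additionally requires fp8
    # compute support; input scaling follows the same optimization flag.
    fp8_optimizations = wants_fp8_optimizations(scaled_fp8_weight)
    return {
        "fp8_matrix_mult": supports_fp8_compute and fp8_optimizations,
        "scale_input": fp8_optimizations,
    }

print(scaled_fp8_settings(torch.zeros(2), supports_fp8_compute=True))    # both False
print(scaled_fp8_settings(torch.zeros(256), supports_fp8_compute=True))  # both True
```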
---
 comfy/ops.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/comfy/ops.py b/comfy/ops.py
index 3303c6fcd..ced461011 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -360,7 +360,7 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
 def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
     fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
     if scaled_fp8 is not None:
-        return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=True, override_dtype=scaled_fp8)
+        return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)
 
     if (
         fp8_compute and

From 11b1f27cb17938bbb2f723f8d71ac78bb9f2e40f Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Fri, 7 Mar 2025 04:52:36 -0500
Subject: [PATCH 3/5] Set WAN default compute dtype to fp16.

---
 comfy/supported_models.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index 7157a15f2..b4d7bfe20 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -931,7 +931,7 @@ class WAN21_T2V(supported_models_base.BASE):
 
     memory_usage_factor = 1.0
 
-    supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
+    supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
 
     vae_key_prefix = ["vae."]
     text_encoder_key_prefix = ["text_encoders."]

From 4ab1875283ce985e77be7ffb4b499db11d937f73 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Fri, 7 Mar 2025 07:45:40 -0500
Subject: [PATCH 4/5] Add .bat file to nightly package to run with fp16 accumulation.

---
 .../run_nvidia_gpu_fast_fp16_accumulation.bat | 2 ++
 1 file changed, 2 insertions(+)
 create mode 100644 .ci/windows_nightly_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat

diff --git a/.ci/windows_nightly_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat b/.ci/windows_nightly_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat
new file mode 100644
index 000000000..38f06ecb2
--- /dev/null
+++ b/.ci/windows_nightly_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat
@@ -0,0 +1,2 @@
+.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation
+pause

From 5dbd25096513838785143c493b94e6c518e71c0b Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Fri, 7 Mar 2025 07:57:59 -0500
Subject: [PATCH 5/5] Update nightly instructions in readme.
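
Both the workflow defaults and the README below move from cu126 to cu128, the CUDA version whose PyTorch nightlies add Blackwell (50xx series) support. A quick, hedged way to confirm that an installed build actually targets those GPUs, assuming the standard `torch.cuda` introspection API:

```python
import torch

# Report the build's CUDA toolkit version and the GPU architectures it
# was compiled for; a cu128 nightly with Blackwell support is expected
# to list 'sm_120' (compute capability 12.0 covers the 50xx series).
print(torch.__version__, torch.version.cuda)
print(torch.cuda.get_arch_list())
```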
---
 .github/workflows/windows_release_nightly_pytorch.yml | 4 ++--
 README.md                                             | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/.github/workflows/windows_release_nightly_pytorch.yml b/.github/workflows/windows_release_nightly_pytorch.yml
index f90488705..cea9aae17 100644
--- a/.github/workflows/windows_release_nightly_pytorch.yml
+++ b/.github/workflows/windows_release_nightly_pytorch.yml
@@ -7,7 +7,7 @@ on:
         description: 'cuda version'
         required: true
         type: string
-        default: "126"
+        default: "128"
 
       python_minor:
         description: 'python minor version'
@@ -19,7 +19,7 @@ on:
         description: 'python patch version'
         required: true
         type: string
-        default: "1"
+        default: "2"
 #  push:
 #    branches:
 #      - master

diff --git a/README.md b/README.md
index 9190dd493..a807ea9d6 100644
--- a/README.md
+++ b/README.md
@@ -215,9 +215,9 @@ Nvidia users should install stable pytorch using this command:
 
 ```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu126```
 
-This is the command to install pytorch nightly instead which might have performance improvements:
+This is the command to install pytorch nightly instead, which supports the new Blackwell 50xx series GPUs and might have performance improvements:
 
-```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126```
+```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128```
 
 #### Troubleshooting
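
For context on the `--fast fp16_accumulation` switch that the new .bat file from patch 4 passes to main.py: it presumably maps onto PyTorch's fp16-accumulation matmul toggle, which only exists in sufficiently new nightly builds; that would explain why the launcher ships with the nightly package. A sketch under that assumption (the attribute name is probed rather than taken for granted):

```python
import torch

# Assumed effect of --fast fp16_accumulation: allow fp16 instead of
# fp32 accumulation in half-precision matmuls for extra throughput,
# at some cost in numerical accuracy. Probe for the flag first since
# older PyTorch releases don't expose it.
matmul = torch.backends.cuda.matmul
if hasattr(matmul, "allow_fp16_accumulation"):
    matmul.allow_fp16_accumulation = True
else:
    print("this PyTorch build has no fp16 accumulation toggle")
```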