From 1a1383e4cdd868e47a5f6d89a5d01a8d6428c8d9 Mon Sep 17 00:00:00 2001
From: Gaeros <163143799+elias-gaeros@users.noreply.github.com>
Date: Sat, 24 May 2025 23:11:30 +0200
Subject: [PATCH] fp8: model_config flag for scaled activation

---
 comfy/model_base.py            |  2 +-
 comfy/model_detection.py       |  4 ++--
 comfy/ops.py                   | 20 ++++++++++----------
 comfy/supported_models_base.py |  1 +
 4 files changed, 14 insertions(+), 13 deletions(-)

diff --git a/comfy/model_base.py b/comfy/model_base.py
index fb4724690..b7304b14b 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -116,7 +116,7 @@ class BaseModel(torch.nn.Module):
         if not unet_config.get("disable_unet_model_creation", False):
             if model_config.custom_operations is None:
                 fp8 = model_config.optimizations.get("fp8", False)
-                operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
+                operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, scaled_fp8_activation=model_config.scaled_fp8_activation)
             else:
                 operations = model_config.custom_operations
             self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index 20f287df9..7533793c6 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -556,9 +556,9 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
         if model_config.scaled_fp8 == torch.float32:
             model_config.scaled_fp8 = torch.float8_e4m3fn
         if scaled_fp8_weight.nelement() == 2:
-            model_config.optimizations["fp8"] = False
+            model_config.scaled_fp8_activation = False
         else:
-            model_config.optimizations["fp8"] = True
+            model_config.scaled_fp8_activation = True

     return model_config

diff --git a/comfy/ops.py b/comfy/ops.py
index 431c8f89d..c1c23f023 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -358,8 +358,7 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None

             if not scale_input:
                 self.scale_input = None
-
-            if not hasattr(self, 'scale_input'):
+            elif not hasattr(self, 'scale_input'):
                 self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
             return None

@@ -411,16 +410,17 @@ if CUBLAS_IS_AVAILABLE:
         def forward(self, *args, **kwargs):
             return super().forward(*args, **kwargs)

-def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
+def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, scaled_fp8_activation=False):
     fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
-    if scaled_fp8 is not None:
-        return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)
-
-    if (
+    fp8_optimizations = (
         fp8_compute and
-        (fp8_optimizations or PerformanceFeature.Fp8MatrixMultiplication in args.fast) and
-        not disable_fast_fp8
-    ):
+        (fp8_optimizations or PerformanceFeature.Fp8MatrixMultiplication in args.fast)
+        and not disable_fast_fp8
+    )
+    if scaled_fp8 is not None:
+        return scaled_fp8_ops(fp8_matrix_mult=fp8_optimizations, scale_input=scaled_fp8_activation, override_dtype=scaled_fp8)
+
+    if fp8_optimizations:
         return fp8_ops

     if (
diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py
index 54573abb1..d4b2e09c2 100644
--- a/comfy/supported_models_base.py
+++ b/comfy/supported_models_base.py
@@ -50,6 +50,7 @@ class BASE:
     manual_cast_dtype = None
     custom_operations = None
     scaled_fp8 = None
+    scaled_fp8_activation = None
    optimizations = {"fp8": False}

     @classmethod
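
For reviewers, a minimal standalone sketch (not part of the patch) of the
selection logic that pick_operations() implements after this change. The
function name pick_operations_sketch and the string return values are
placeholders standing in for the real comfy.ops operation classes; only the
branching mirrors the patched code, with the args.fast check folded into the
fp8_optimizations input for brevity:

    def pick_operations_sketch(fp8_compute, fp8_optimizations,
                               disable_fast_fp8=False,
                               scaled_fp8=None, scaled_fp8_activation=False):
        # fp8 matrix multiplication needs device support, an fp8 optimization
        # request (model config or --fast), and fast fp8 not disabled.
        fp8_optimizations = (fp8_compute and fp8_optimizations
                             and not disable_fast_fp8)
        if scaled_fp8 is not None:
            # Input (activation) scaling is now driven by the per-model
            # scaled_fp8_activation flag, detected in model_detection.py from
            # the checkpoint's scaled_fp8 marker tensor, rather than by the
            # generic optimizations["fp8"] flag.
            return ("scaled_fp8_ops", fp8_optimizations, scaled_fp8_activation)
        if fp8_optimizations:
            return ("fp8_ops",)
        return ("default_ops",)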