Gæros 2025-10-22 08:41:41 -03:00 committed by GitHub
commit 7524da5528
4 changed files with 14 additions and 13 deletions

comfy/model_base.py

@@ -134,7 +134,7 @@ class BaseModel(torch.nn.Module):
         if not unet_config.get("disable_unet_model_creation", False):
             if model_config.custom_operations is None:
                 fp8 = model_config.optimizations.get("fp8", False)
-                operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
+                operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, scaled_fp8_activation=model_config.scaled_fp8_activation)
             else:
                 operations = model_config.custom_operations
             self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)

comfy/model_detection.py

@@ -697,9 +697,9 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
         if model_config.scaled_fp8 == torch.float32:
             model_config.scaled_fp8 = torch.float8_e4m3fn
         if scaled_fp8_weight.nelement() == 2:
-            model_config.optimizations["fp8"] = False
+            model_config.scaled_fp8_activation = False
         else:
-            model_config.optimizations["fp8"] = True
+            model_config.scaled_fp8_activation = True

     return model_config
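The detection itself is unchanged: a 2-element scaled_fp8 marker tensor in the checkpoint signals weight-only scaled fp8, anything else enables activation scaling as well. The commit only redirects the result into the dedicated scaled_fp8_activation field instead of overloading optimizations["fp8"]. A standalone paraphrase follows; the helper name is hypothetical.

```python
# Paraphrase of the branch above; has_fp8_activation is a hypothetical
# helper, not a function in the codebase.
import torch

def has_fp8_activation(scaled_fp8_weight: torch.Tensor) -> bool:
    # A 2-element marker tensor means weight-only scaled fp8,
    # i.e. no fp8 activation scaling.
    return scaled_fp8_weight.nelement() != 2

print(has_fp8_activation(torch.zeros(2)))    # False -> scaled_fp8_activation = False
print(has_fp8_activation(torch.zeros(())))   # True  -> scaled_fp8_activation = True
```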

comfy/ops.py

@@ -423,8 +423,7 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
                 if not scale_input:
                     self.scale_input = None
-                if not hasattr(self, 'scale_input'):
+                elif not hasattr(self, 'scale_input'):
                     self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
                 return None
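With the plain `if`, the second check could never fire after the first branch ran (assigning None still leaves the attribute present, so hasattr returns True); the `elif` makes that mutual exclusion explicit and skips the redundant hasattr call. A paraphrase of the resulting logic, pulled out into a standalone function for readability (in the diff it lives inside reset_parameters of the scaled-fp8 Linear):

```python
# Illustrative paraphrase, not the module's exact code.
import torch

def init_scale_input(layer: torch.nn.Linear, scale_input: bool) -> None:
    if not scale_input:
        # Input scaling disabled: record that explicitly as None.
        layer.scale_input = None
    elif not hasattr(layer, 'scale_input'):
        # Enabled and not yet initialized: a frozen scalar scale of 1.0.
        layer.scale_input = torch.nn.parameter.Parameter(
            data=torch.ones((), device=layer.weight.device, dtype=torch.float32),
            requires_grad=False,
        )
```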
@@ -478,16 +477,17 @@ if CUBLAS_IS_AVAILABLE:
            def forward(self, *args, **kwargs):
                return super().forward(*args, **kwargs)

-def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
+def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, scaled_fp8_activation=False):
     fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
-    if scaled_fp8 is not None:
-        return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)
-    if (
+    fp8_optimizations = (
         fp8_compute and
-        (fp8_optimizations or PerformanceFeature.Fp8MatrixMultiplication in args.fast) and
-        not disable_fast_fp8
-    ):
+        (fp8_optimizations or PerformanceFeature.Fp8MatrixMultiplication in args.fast)
+        and not disable_fast_fp8
+    )
+
+    if scaled_fp8 is not None:
+        return scaled_fp8_ops(fp8_matrix_mult=fp8_optimizations, scale_input=scaled_fp8_activation, override_dtype=scaled_fp8)
+
+    if fp8_optimizations:
         return fp8_ops

     if (
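Net effect: the fast-fp8 decision is now resolved once, up front, and scaled-fp8 checkpoints get fp8 matrix multiplication under the same conditions as everything else, while input scaling is governed by the new per-model flag rather than reusing fp8_optimizations for both. A condensed restatement of the selection order, as a hedged sketch that returns labels instead of the real op classes:

```python
# Condensed restatement of the new pick_operations control flow.
# Labels stand in for the real op classes; fast_flags stands in for the
# PerformanceFeature checks against args.fast.
def pick_operations_sketch(fp8_compute, fp8_optimizations, disable_fast_fp8,
                           scaled_fp8, scaled_fp8_activation, fast_flags):
    # 1. Resolve the fast-fp8 decision once.
    use_fast_fp8 = (
        fp8_compute
        and (fp8_optimizations or "fp8_matrix_mult" in fast_flags)
        and not disable_fast_fp8
    )
    # 2. Scaled-fp8 checkpoints: matrix mult follows the shared decision,
    #    input scaling follows the per-model scaled_fp8_activation flag.
    if scaled_fp8 is not None:
        return ("scaled_fp8_ops", use_fast_fp8, scaled_fp8_activation)
    # 3. Plain fast-fp8 path, then the remaining fallbacks.
    if use_fast_fp8:
        return ("fp8_ops",)
    return ("fallback",)
```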

comfy/supported_models_base.py

@@ -50,6 +50,7 @@ class BASE:
     manual_cast_dtype = None
     custom_operations = None
     scaled_fp8 = None
+    scaled_fp8_activation = None
     optimizations = {"fp8": False}

     @classmethod
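One caveat worth noting: the class-level default is None rather than False, so configs that never pass through the scaled-fp8 detection read as "unset". Since pick_operations only consults the flag when scaled_fp8 is set, and None is falsy wherever scale_input is tested, the two defaults behave identically. A tiny illustration with a hypothetical stand-in class:

```python
# ConfigSketch is a hypothetical stand-in for a BASE subclass; it is not
# part of the codebase.
class ConfigSketch:
    scaled_fp8 = None
    scaled_fp8_activation = None  # default added by this commit

cfg = ConfigSketch()
# Until model_config_from_unet sets it, the flag is falsy, matching the
# scaled_fp8_activation=False default on pick_operations.
print(bool(cfg.scaled_fp8_activation))  # False
```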