mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-17 00:43:48 +08:00
fp8: model_config flag for scaled activation
This commit is contained in:
parent
5a87757ef9
commit
1a1383e4cd
@ -116,7 +116,7 @@ class BaseModel(torch.nn.Module):
|
|||||||
if not unet_config.get("disable_unet_model_creation", False):
|
if not unet_config.get("disable_unet_model_creation", False):
|
||||||
if model_config.custom_operations is None:
|
if model_config.custom_operations is None:
|
||||||
fp8 = model_config.optimizations.get("fp8", False)
|
fp8 = model_config.optimizations.get("fp8", False)
|
||||||
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
|
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, scaled_fp8_activation=model_config.scaled_fp8_activation)
|
||||||
else:
|
else:
|
||||||
operations = model_config.custom_operations
|
operations = model_config.custom_operations
|
||||||
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
|
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
|
||||||
|
|||||||
@ -556,9 +556,9 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
|
|||||||
if model_config.scaled_fp8 == torch.float32:
|
if model_config.scaled_fp8 == torch.float32:
|
||||||
model_config.scaled_fp8 = torch.float8_e4m3fn
|
model_config.scaled_fp8 = torch.float8_e4m3fn
|
||||||
if scaled_fp8_weight.nelement() == 2:
|
if scaled_fp8_weight.nelement() == 2:
|
||||||
model_config.optimizations["fp8"] = False
|
model_config.scaled_fp8_activation = False
|
||||||
else:
|
else:
|
||||||
model_config.optimizations["fp8"] = True
|
model_config.scaled_fp8_activation = True
|
||||||
|
|
||||||
return model_config
|
return model_config
|
||||||
|
|
||||||
|
|||||||
20
comfy/ops.py
20
comfy/ops.py
@ -358,8 +358,7 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
|
|||||||
|
|
||||||
if not scale_input:
|
if not scale_input:
|
||||||
self.scale_input = None
|
self.scale_input = None
|
||||||
|
elif not hasattr(self, 'scale_input'):
|
||||||
if not hasattr(self, 'scale_input'):
|
|
||||||
self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
|
self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -411,16 +410,17 @@ if CUBLAS_IS_AVAILABLE:
|
|||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
return super().forward(*args, **kwargs)
|
return super().forward(*args, **kwargs)
|
||||||
|
|
||||||
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
|
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, scaled_fp8_activation=False):
|
||||||
fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
|
fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
|
||||||
if scaled_fp8 is not None:
|
fp8_optimizations = (
|
||||||
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)
|
|
||||||
|
|
||||||
if (
|
|
||||||
fp8_compute and
|
fp8_compute and
|
||||||
(fp8_optimizations or PerformanceFeature.Fp8MatrixMultiplication in args.fast) and
|
(fp8_optimizations or PerformanceFeature.Fp8MatrixMultiplication in args.fast)
|
||||||
not disable_fast_fp8
|
and not disable_fast_fp8
|
||||||
):
|
)
|
||||||
|
if scaled_fp8 is not None:
|
||||||
|
return scaled_fp8_ops(fp8_matrix_mult=fp8_optimizations, scale_input=scaled_fp8_activation, override_dtype=scaled_fp8)
|
||||||
|
|
||||||
|
if fp8_optimizations:
|
||||||
return fp8_ops
|
return fp8_ops
|
||||||
|
|
||||||
if (
|
if (
|
||||||
|
|||||||
@ -50,6 +50,7 @@ class BASE:
|
|||||||
manual_cast_dtype = None
|
manual_cast_dtype = None
|
||||||
custom_operations = None
|
custom_operations = None
|
||||||
scaled_fp8 = None
|
scaled_fp8 = None
|
||||||
|
scaled_fp8_activation = None
|
||||||
optimizations = {"fp8": False}
|
optimizations = {"fp8": False}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user