mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-16 00:12:33 +08:00
Merge 1a1383e4cd into f13cff0be6
This commit is contained in:
commit
7524da5528
@ -134,7 +134,7 @@ class BaseModel(torch.nn.Module):
|
||||
if not unet_config.get("disable_unet_model_creation", False):
|
||||
if model_config.custom_operations is None:
|
||||
fp8 = model_config.optimizations.get("fp8", False)
|
||||
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
|
||||
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, scaled_fp8_activation=model_config.scaled_fp8_activation)
|
||||
else:
|
||||
operations = model_config.custom_operations
|
||||
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
|
||||
|
||||
@ -697,9 +697,9 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
|
||||
if model_config.scaled_fp8 == torch.float32:
|
||||
model_config.scaled_fp8 = torch.float8_e4m3fn
|
||||
if scaled_fp8_weight.nelement() == 2:
|
||||
model_config.optimizations["fp8"] = False
|
||||
model_config.scaled_fp8_activation = False
|
||||
else:
|
||||
model_config.optimizations["fp8"] = True
|
||||
model_config.scaled_fp8_activation = True
|
||||
|
||||
return model_config
|
||||
|
||||
|
||||
20
comfy/ops.py
20
comfy/ops.py
@ -423,8 +423,7 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
|
||||
|
||||
if not scale_input:
|
||||
self.scale_input = None
|
||||
|
||||
if not hasattr(self, 'scale_input'):
|
||||
elif not hasattr(self, 'scale_input'):
|
||||
self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
|
||||
return None
|
||||
|
||||
@ -478,16 +477,17 @@ if CUBLAS_IS_AVAILABLE:
|
||||
def forward(self, *args, **kwargs):
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
|
||||
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, scaled_fp8_activation=False):
|
||||
fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
|
||||
if scaled_fp8 is not None:
|
||||
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)
|
||||
|
||||
if (
|
||||
fp8_optimizations = (
|
||||
fp8_compute and
|
||||
(fp8_optimizations or PerformanceFeature.Fp8MatrixMultiplication in args.fast) and
|
||||
not disable_fast_fp8
|
||||
):
|
||||
(fp8_optimizations or PerformanceFeature.Fp8MatrixMultiplication in args.fast)
|
||||
and not disable_fast_fp8
|
||||
)
|
||||
if scaled_fp8 is not None:
|
||||
return scaled_fp8_ops(fp8_matrix_mult=fp8_optimizations, scale_input=scaled_fp8_activation, override_dtype=scaled_fp8)
|
||||
|
||||
if fp8_optimizations:
|
||||
return fp8_ops
|
||||
|
||||
if (
|
||||
|
||||
@ -50,6 +50,7 @@ class BASE:
|
||||
manual_cast_dtype = None
|
||||
custom_operations = None
|
||||
scaled_fp8 = None
|
||||
scaled_fp8_activation = None
|
||||
optimizations = {"fp8": False}
|
||||
|
||||
@classmethod
|
||||
|
||||
Loading…
Reference in New Issue
Block a user