mirror of https://github.com/comfyanonymous/ComfyUI.git
Skip fp4 matrix mult on devices that don't support it. (#11677)
This commit is contained in:
parent 6e9ee55cdd
commit 2c03884f5f
comfy/model_management.py
@@ -1504,6 +1504,16 @@ def supports_fp8_compute(device=None):
     return True
 
 
+def supports_nvfp4_compute(device=None):
+    if not is_nvidia():
+        return False
+
+    props = torch.cuda.get_device_properties(device)
+    if props.major < 10:
+        return False
+
+    return True
+
 def extended_fp16_support():
     # TODO: check why some models work with fp16 on newer torch versions but not on older
     if torch_version_numeric < (2, 7):
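For context, a minimal sketch of how the new helper can be called from user code; picking the device via comfy.model_management.get_torch_device() is an assumption here, not part of the diff:

import comfy.model_management as mm

device = mm.get_torch_device()
if mm.supports_nvfp4_compute(device):
    print("nvfp4 matrix mult is available on", device)
else:
    # Non-NVIDIA devices and NVIDIA GPUs below compute capability 10 (pre-Blackwell)
    # land here; nvfp4-quantized layers then fall back to full-precision matmul.
    print("nvfp4 matrix mult not supported; falling back to full precision")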
comfy/ops.py
@@ -493,11 +493,12 @@ from .quant_ops import (
 )
 
 
-def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
+def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]):
     class MixedPrecisionOps(manual_cast):
         _quant_config = quant_config
         _compute_dtype = compute_dtype
         _full_precision_mm = full_precision_mm
+        _disabled = disabled
 
         class Linear(torch.nn.Module, CastWeightBiasOp):
             def __init__(
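A sketch of how the new disabled argument is meant to be used (quant_config below is a placeholder for the per-layer quantization config loaded from the checkpoint): formats listed in disabled are forced onto the full-precision matmul path per layer, which presumably keeps the weights stored in their quantized form but dequantizes them for the actual multiplication.

import torch

# Hypothetical call: the device lacks fp4 kernels, so nvfp4 layers should not
# attempt quantized matrix multiplication.
ops = mixed_precision_ops(
    quant_config,                  # placeholder: checkpoint quantization config
    compute_dtype=torch.bfloat16,
    disabled={"nvfp4"},
)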
@@ -522,6 +523,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
 
                 self.tensor_class = None
                 self._full_precision_mm = MixedPrecisionOps._full_precision_mm
+                self._full_precision_mm_config = False
 
             def reset_parameters(self):
                 return None
@@ -556,8 +558,12 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
                     self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
                 else:
                     self.quant_format = layer_conf.get("format", None)
+                    self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False)
                     if not self._full_precision_mm:
-                        self._full_precision_mm = layer_conf.get("full_precision_matrix_mult", False)
+                        self._full_precision_mm = self._full_precision_mm_config
+
+                    if self.quant_format in MixedPrecisionOps._disabled:
+                        self._full_precision_mm = True
 
                     if self.quant_format is None:
                         raise ValueError(f"Unknown quantization format for layer {layer_name}")
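The per-layer decision this hunk implements can be summarized as below; the function name is made up for illustration and is not part of the diff:

def wants_full_precision_mm(quant_format, layer_conf, global_full_precision_mm, disabled):
    # Flag as configured in the checkpoint metadata for this layer.
    configured = layer_conf.get("full_precision_matrix_mult", False)
    # Keep a global override if already set, otherwise use the per-layer config.
    full_precision = global_full_precision_mm or configured
    # New in this commit: formats the device cannot execute are forced to full precision.
    if quant_format in disabled:
        full_precision = True
    return full_precision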
@@ -630,7 +636,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
                     sd["{}weight_scale".format(prefix)] = self.weight._params.block_scale
 
                 quant_conf = {"format": self.quant_format}
-                if self._full_precision_mm:
+                if self._full_precision_mm_config:
                     quant_conf["full_precision_matrix_mult"] = True
                 sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
                 return sd
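Note on this serialization change: the saved comfy_quant metadata now reflects the configured flag (_full_precision_mm_config) rather than the runtime-effective _full_precision_mm, presumably so that a checkpoint re-saved on a GPU without fp4/fp8 support does not bake the fallback into the file. A minimal sketch of the payload being written; the helper name is hypothetical:

import json
import torch

def build_comfy_quant(quant_format, full_precision_mm_config):
    quant_conf = {"format": quant_format}
    if full_precision_mm_config:
        quant_conf["full_precision_matrix_mult"] = True
    # Serialized as a uint8 tensor holding the UTF-8 encoded JSON string.
    return torch.tensor(list(json.dumps(quant_conf).encode("utf-8")), dtype=torch.uint8)

# build_comfy_quant("nvfp4", False) -> tensor encoding b'{"format": "nvfp4"}'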
@@ -711,10 +717,17 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
 
 def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
     fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
+    nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device)
 
     if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
         logging.info("Using mixed precision operations")
-        return mixed_precision_ops(model_config.quant_config, compute_dtype, full_precision_mm=not fp8_compute)
+        disabled = set()
+        if not nvfp4_compute:
+            disabled.add("nvfp4")
+        if not fp8_compute:
+            disabled.add("float8_e4m3fn")
+            disabled.add("float8_e5m2")
+        return mixed_precision_ops(model_config.quant_config, compute_dtype, disabled=disabled)
 
     if (
         fp8_compute and
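Taken together, pick_operations now maps device capabilities to a set of disabled quantization formats before building the mixed-precision ops. A standalone sketch of that mapping; the device examples in the comment are assumptions, not part of the diff:

def disabled_quant_formats(fp8_compute, nvfp4_compute):
    disabled = set()
    if not nvfp4_compute:
        disabled.add("nvfp4")
    if not fp8_compute:
        disabled.add("float8_e4m3fn")
        disabled.add("float8_e5m2")
    return disabled

# Typical outcomes (assumed): Blackwell-class GPUs -> set(); fp8-capable but
# pre-Blackwell GPUs -> {"nvfp4"}; older GPUs without fp8 compute -> all three.
print(disabled_quant_formats(fp8_compute=True, nvfp4_compute=False))  # {'nvfp4'}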