mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-16 06:37:41 +08:00
feat: Support mxfp8 (#12907)
This commit is contained in:
parent
e0982a7174
commit
1c5db7397d
@ -209,3 +209,39 @@ def stochastic_round_quantize_nvfp4_by_block(x, per_tensor_scale, pad_16x, seed=
|
|||||||
output_block[i:i + slice_size].copy_(block)
|
output_block[i:i + slice_size].copy_(block)
|
||||||
|
|
||||||
return output_fp4, to_blocked(output_block, flatten=False)
|
return output_fp4, to_blocked(output_block, flatten=False)
|
||||||
|
|
||||||
|
|
||||||
|
def stochastic_round_quantize_mxfp8_by_block(x, pad_32x, seed=0):
|
||||||
|
def roundup(x_val, multiple):
|
||||||
|
return ((x_val + multiple - 1) // multiple) * multiple
|
||||||
|
|
||||||
|
if pad_32x:
|
||||||
|
rows, cols = x.shape
|
||||||
|
padded_rows = roundup(rows, 32)
|
||||||
|
padded_cols = roundup(cols, 32)
|
||||||
|
if padded_rows != rows or padded_cols != cols:
|
||||||
|
x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows))
|
||||||
|
|
||||||
|
F8_E4M3_MAX = 448.0
|
||||||
|
E8M0_BIAS = 127
|
||||||
|
BLOCK_SIZE = 32
|
||||||
|
|
||||||
|
rows, cols = x.shape
|
||||||
|
x_blocked = x.reshape(rows, -1, BLOCK_SIZE)
|
||||||
|
max_abs = torch.amax(torch.abs(x_blocked), dim=-1)
|
||||||
|
|
||||||
|
# E8M0 block scales (power-of-2 exponents)
|
||||||
|
scale_needed = torch.clamp(max_abs.float() / F8_E4M3_MAX, min=2**(-127))
|
||||||
|
exp_biased = torch.clamp(torch.ceil(torch.log2(scale_needed)).to(torch.int32) + E8M0_BIAS, 0, 254)
|
||||||
|
block_scales_e8m0 = exp_biased.to(torch.uint8)
|
||||||
|
|
||||||
|
zero_mask = (max_abs == 0)
|
||||||
|
block_scales_f32 = (block_scales_e8m0.to(torch.int32) << 23).view(torch.float32)
|
||||||
|
block_scales_f32 = torch.where(zero_mask, torch.ones_like(block_scales_f32), block_scales_f32)
|
||||||
|
|
||||||
|
# Scale per-block then stochastic round
|
||||||
|
data_scaled = (x_blocked.float() / block_scales_f32.unsqueeze(-1)).reshape(rows, cols)
|
||||||
|
output_fp8 = stochastic_rounding(data_scaled, torch.float8_e4m3fn, seed=seed)
|
||||||
|
|
||||||
|
block_scales_e8m0 = torch.where(zero_mask, torch.zeros_like(block_scales_e8m0), block_scales_e8m0)
|
||||||
|
return output_fp8, to_blocked(block_scales_e8m0, flatten=False).view(torch.float8_e8m0fnu)
|
||||||
|
|||||||
@ -1712,6 +1712,19 @@ def supports_nvfp4_compute(device=None):
|
|||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def supports_mxfp8_compute(device=None):
|
||||||
|
if not is_nvidia():
|
||||||
|
return False
|
||||||
|
|
||||||
|
if torch_version_numeric < (2, 10):
|
||||||
|
return False
|
||||||
|
|
||||||
|
props = torch.cuda.get_device_properties(device)
|
||||||
|
if props.major < 10:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
def extended_fp16_support():
|
def extended_fp16_support():
|
||||||
# TODO: check why some models work with fp16 on newer torch versions but not on older
|
# TODO: check why some models work with fp16 on newer torch versions but not on older
|
||||||
if torch_version_numeric < (2, 7):
|
if torch_version_numeric < (2, 7):
|
||||||
|
|||||||
19
comfy/ops.py
19
comfy/ops.py
@ -857,6 +857,22 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
orig_shape=(self.out_features, self.in_features),
|
orig_shape=(self.out_features, self.in_features),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif self.quant_format == "mxfp8":
|
||||||
|
# MXFP8: E8M0 block scales stored as uint8 in safetensors
|
||||||
|
block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys,
|
||||||
|
dtype=torch.uint8)
|
||||||
|
|
||||||
|
if block_scale is None:
|
||||||
|
raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}")
|
||||||
|
|
||||||
|
block_scale = block_scale.view(torch.float8_e8m0fnu)
|
||||||
|
|
||||||
|
params = layout_cls.Params(
|
||||||
|
scale=block_scale,
|
||||||
|
orig_dtype=MixedPrecisionOps._compute_dtype,
|
||||||
|
orig_shape=(self.out_features, self.in_features),
|
||||||
|
)
|
||||||
|
|
||||||
elif self.quant_format == "nvfp4":
|
elif self.quant_format == "nvfp4":
|
||||||
# NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale)
|
# NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale)
|
||||||
tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
|
tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
|
||||||
@ -1006,12 +1022,15 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
|
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
|
||||||
fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
|
fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
|
||||||
nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device)
|
nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device)
|
||||||
|
mxfp8_compute = comfy.model_management.supports_mxfp8_compute(load_device)
|
||||||
|
|
||||||
if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
|
if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
|
||||||
logging.info("Using mixed precision operations")
|
logging.info("Using mixed precision operations")
|
||||||
disabled = set()
|
disabled = set()
|
||||||
if not nvfp4_compute:
|
if not nvfp4_compute:
|
||||||
disabled.add("nvfp4")
|
disabled.add("nvfp4")
|
||||||
|
if not mxfp8_compute:
|
||||||
|
disabled.add("mxfp8")
|
||||||
if not fp8_compute:
|
if not fp8_compute:
|
||||||
disabled.add("float8_e4m3fn")
|
disabled.add("float8_e4m3fn")
|
||||||
disabled.add("float8_e5m2")
|
disabled.add("float8_e5m2")
|
||||||
|
|||||||
@ -43,6 +43,18 @@ except ImportError as e:
|
|||||||
def get_layout_class(name):
|
def get_layout_class(name):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
_CK_MXFP8_AVAILABLE = False
|
||||||
|
if _CK_AVAILABLE:
|
||||||
|
try:
|
||||||
|
from comfy_kitchen.tensor import TensorCoreMXFP8Layout as _CKMxfp8Layout
|
||||||
|
_CK_MXFP8_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
logging.warning("comfy_kitchen does not support MXFP8, please update comfy_kitchen.")
|
||||||
|
|
||||||
|
if not _CK_MXFP8_AVAILABLE:
|
||||||
|
class _CKMxfp8Layout:
|
||||||
|
pass
|
||||||
|
|
||||||
import comfy.float
|
import comfy.float
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
@ -84,6 +96,31 @@ class _TensorCoreFP8LayoutBase(_CKFp8Layout):
|
|||||||
return qdata, params
|
return qdata, params
|
||||||
|
|
||||||
|
|
||||||
|
class TensorCoreMXFP8Layout(_CKMxfp8Layout):
|
||||||
|
@classmethod
|
||||||
|
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
|
||||||
|
if tensor.dim() != 2:
|
||||||
|
raise ValueError(f"MXFP8 requires 2D tensor, got {tensor.dim()}D")
|
||||||
|
|
||||||
|
orig_dtype = tensor.dtype
|
||||||
|
orig_shape = tuple(tensor.shape)
|
||||||
|
|
||||||
|
padded_shape = cls.get_padded_shape(orig_shape)
|
||||||
|
needs_padding = padded_shape != orig_shape
|
||||||
|
|
||||||
|
if stochastic_rounding > 0:
|
||||||
|
qdata, block_scale = comfy.float.stochastic_round_quantize_mxfp8_by_block(tensor, pad_32x=needs_padding, seed=stochastic_rounding)
|
||||||
|
else:
|
||||||
|
qdata, block_scale = ck.quantize_mxfp8(tensor, pad_32x=needs_padding)
|
||||||
|
|
||||||
|
params = cls.Params(
|
||||||
|
scale=block_scale,
|
||||||
|
orig_dtype=orig_dtype,
|
||||||
|
orig_shape=orig_shape,
|
||||||
|
)
|
||||||
|
return qdata, params
|
||||||
|
|
||||||
|
|
||||||
class TensorCoreNVFP4Layout(_CKNvfp4Layout):
|
class TensorCoreNVFP4Layout(_CKNvfp4Layout):
|
||||||
@classmethod
|
@classmethod
|
||||||
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
|
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
|
||||||
@ -137,6 +174,8 @@ register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout)
|
|||||||
register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout)
|
register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout)
|
||||||
register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout)
|
register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout)
|
||||||
register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout)
|
register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout)
|
||||||
|
if _CK_MXFP8_AVAILABLE:
|
||||||
|
register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout)
|
||||||
|
|
||||||
QUANT_ALGOS = {
|
QUANT_ALGOS = {
|
||||||
"float8_e4m3fn": {
|
"float8_e4m3fn": {
|
||||||
@ -157,6 +196,14 @@ QUANT_ALGOS = {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if _CK_MXFP8_AVAILABLE:
|
||||||
|
QUANT_ALGOS["mxfp8"] = {
|
||||||
|
"storage_t": torch.float8_e4m3fn,
|
||||||
|
"parameters": {"weight_scale", "input_scale"},
|
||||||
|
"comfy_tensor_layout": "TensorCoreMXFP8Layout",
|
||||||
|
"group_size": 32,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
# Re-exports for backward compatibility
|
# Re-exports for backward compatibility
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user