mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-22 01:23:43 +08:00
Support mxfp8
This commit is contained in:
parent
44f1246c89
commit
d59731cf26
16
comfy/ops.py
16
comfy/ops.py
@ -801,6 +801,22 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
orig_shape=(self.out_features, self.in_features),
|
orig_shape=(self.out_features, self.in_features),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif self.quant_format == "mxfp8":
|
||||||
|
# MXFP8: E8M0 block scales stored as uint8 in safetensors
|
||||||
|
block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys,
|
||||||
|
dtype=torch.uint8)
|
||||||
|
|
||||||
|
if block_scale is None:
|
||||||
|
raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}")
|
||||||
|
|
||||||
|
block_scale = block_scale.view(torch.float8_e8m0fnu)
|
||||||
|
|
||||||
|
params = layout_cls.Params(
|
||||||
|
scale=block_scale,
|
||||||
|
orig_dtype=MixedPrecisionOps._compute_dtype,
|
||||||
|
orig_shape=(self.out_features, self.in_features),
|
||||||
|
)
|
||||||
|
|
||||||
elif self.quant_format == "nvfp4":
|
elif self.quant_format == "nvfp4":
|
||||||
# NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale)
|
# NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale)
|
||||||
tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
|
tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
|
||||||
|
|||||||
@ -8,6 +8,7 @@ try:
|
|||||||
QuantizedLayout,
|
QuantizedLayout,
|
||||||
TensorCoreFP8Layout as _CKFp8Layout,
|
TensorCoreFP8Layout as _CKFp8Layout,
|
||||||
TensorCoreNVFP4Layout as _CKNvfp4Layout,
|
TensorCoreNVFP4Layout as _CKNvfp4Layout,
|
||||||
|
TensorCoreMXFP8Layout,
|
||||||
register_layout_op,
|
register_layout_op,
|
||||||
register_layout_class,
|
register_layout_class,
|
||||||
get_layout_class,
|
get_layout_class,
|
||||||
@ -137,6 +138,7 @@ register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout)
|
|||||||
register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout)
|
register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout)
|
||||||
register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout)
|
register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout)
|
||||||
register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout)
|
register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout)
|
||||||
|
register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout)
|
||||||
|
|
||||||
QUANT_ALGOS = {
|
QUANT_ALGOS = {
|
||||||
"float8_e4m3fn": {
|
"float8_e4m3fn": {
|
||||||
@ -155,6 +157,12 @@ QUANT_ALGOS = {
|
|||||||
"comfy_tensor_layout": "TensorCoreNVFP4Layout",
|
"comfy_tensor_layout": "TensorCoreNVFP4Layout",
|
||||||
"group_size": 16,
|
"group_size": 16,
|
||||||
},
|
},
|
||||||
|
"mxfp8": {
|
||||||
|
"storage_t": torch.float8_e4m3fn,
|
||||||
|
"parameters": {"weight_scale", "input_scale"},
|
||||||
|
"comfy_tensor_layout": "TensorCoreMXFP8Layout",
|
||||||
|
"group_size": 32,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -169,6 +177,7 @@ __all__ = [
|
|||||||
"TensorCoreFP8E4M3Layout",
|
"TensorCoreFP8E4M3Layout",
|
||||||
"TensorCoreFP8E5M2Layout",
|
"TensorCoreFP8E5M2Layout",
|
||||||
"TensorCoreNVFP4Layout",
|
"TensorCoreNVFP4Layout",
|
||||||
|
"TensorCoreMXFP8Layout",
|
||||||
"QUANT_ALGOS",
|
"QUANT_ALGOS",
|
||||||
"register_layout_op",
|
"register_layout_op",
|
||||||
]
|
]
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user