Rename quant dtype parameter

This commit is contained in:
lspindler 2025-10-28 07:33:42 +01:00
parent 59a2e8c74e
commit 135d3025ea

View File

@@ -122,7 +122,7 @@ class TestTensorCoreFP8Layout(unittest.TestCase):
 qdata, layout_params = TensorCoreFP8Layout.quantize(
     float_tensor,
     scale=scale,
-    fp8_dtype=torch.float8_e4m3fn
+    dtype=torch.float8_e4m3fn
 )
 self.assertEqual(qdata.dtype, torch.float8_e4m3fn)
@@ -139,7 +139,7 @@ class TestTensorCoreFP8Layout(unittest.TestCase):
 qdata, layout_params = TensorCoreFP8Layout.quantize(
     float_tensor,
     scale=scale,
-    fp8_dtype=torch.float8_e4m3fn
+    dtype=torch.float8_e4m3fn
 )
 dequantized = TensorCoreFP8Layout.dequantize(qdata, **layout_params)