Fix fp8 fast issue. (#11688)

comfyanonymous 2026-01-06 22:39:06 -08:00 committed by GitHub
parent 79e94544bd
commit b7d7cc1d49

@@ -427,12 +427,12 @@ def fp8_linear(self, input):
 input = torch.clamp(input, min=-448, max=448, out=input)
 input_fp8 = input.to(dtype).contiguous()
 layout_params_input = TensorCoreFP8Layout.Params(scale=scale_input, orig_dtype=input_dtype, orig_shape=tuple(input_fp8.shape))
-quantized_input = QuantizedTensor(input_fp8, TensorCoreFP8Layout, layout_params_input)
+quantized_input = QuantizedTensor(input_fp8, "TensorCoreFP8Layout", layout_params_input)
 # Wrap weight in QuantizedTensor - this enables unified dispatch
 # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
 layout_params_weight = TensorCoreFP8Layout.Params(scale=scale_weight, orig_dtype=input_dtype, orig_shape=tuple(w.shape))
-quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight)
+quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight)
 o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
 uncast_bias_weight(self, w, bias, offload_stream)
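For readers unfamiliar with the mechanism the diff's comments mention, the sketch below shows the general __torch_dispatch__ wrapper-subclass pattern in isolation: a tensor subclass carries its payload plus a layout name passed as a string (matching the fixed call sites above) and intercepts ops such as F.linear. The WrappedTensor class and its unwrap-and-forward handler are a simplified assumption for illustration only, not ComfyUI's actual QuantizedTensor or quant_ops code.

# Hypothetical sketch of a __torch_dispatch__ wrapper subclass; names here
# (WrappedTensor, _layout_name, _params) are illustrative, not ComfyUI's API.
import torch

class WrappedTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, data, layout_name, params):
        # Create a storage-less wrapper that mirrors the payload's metadata.
        return torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=data.dtype, device=data.device
        )

    def __init__(self, data, layout_name, params):
        self._data = data
        self._layout_name = layout_name  # string key, as in the fixed call sites
        self._params = params

    def __repr__(self):
        return f"WrappedTensor(layout={self._layout_name}, shape={tuple(self.shape)})"

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Unwrap any WrappedTensor arguments and run the op on the plain
        # payloads; a real handler would call a fused fp8 kernel here instead.
        def unwrap(x):
            return x._data if isinstance(x, WrappedTensor) else x
        args = torch.utils._pytree.tree_map(unwrap, args)
        kwargs = torch.utils._pytree.tree_map(unwrap, kwargs)
        return func(*args, **kwargs)

# Usage: F.linear on wrapped tensors is routed through __torch_dispatch__.
x = WrappedTensor(torch.randn(4, 8), "TensorCoreFP8Layout", params=None)
w = WrappedTensor(torch.randn(16, 8), "TensorCoreFP8Layout", params=None)
out = torch.nn.functional.linear(x, w)

Passing the layout as a string rather than the class object is the whole visible change in this commit; the sketch only illustrates where that identifier ends up living on the wrapper.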