Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-01-11 23:00:51 +08:00
Fix fp8 fast issue. (#11688)
Commit: b7d7cc1d49
Parent: 79e94544bd
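The patch touches the two QuantizedTensor call sites in fp8_linear: at both the input and weight wrapping steps, the layout argument changes from the TensorCoreFP8Layout class object to its registered string name, "TensorCoreFP8Layout".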
@@ -427,12 +427,12 @@ def fp8_linear(self, input):
         input = torch.clamp(input, min=-448, max=448, out=input)
         input_fp8 = input.to(dtype).contiguous()
         layout_params_input = TensorCoreFP8Layout.Params(scale=scale_input, orig_dtype=input_dtype, orig_shape=tuple(input_fp8.shape))
-        quantized_input = QuantizedTensor(input_fp8, TensorCoreFP8Layout, layout_params_input)
+        quantized_input = QuantizedTensor(input_fp8, "TensorCoreFP8Layout", layout_params_input)
 
         # Wrap weight in QuantizedTensor - this enables unified dispatch
         # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
         layout_params_weight = TensorCoreFP8Layout.Params(scale=scale_weight, orig_dtype=input_dtype, orig_shape=tuple(w.shape))
-        quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight)
+        quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight)
         o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
 
         uncast_bias_weight(self, w, bias, offload_stream)
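For context, the comments in the hunk describe the mechanism being patched: the raw fp8 tensor is wrapped so that a later F.linear call is rerouted to a quantized handler, and after this change the layout is identified by its registered string name. Below is a minimal, self-contained sketch of that pattern. It is not ComfyUI's actual quant_ops.py: the registry names, the Params helper, and the use of the simpler __torch_function__ hook (rather than the __torch_dispatch__ hook the comments mention) are all assumptions made to keep the example short and runnable.

```python
import torch

# Hypothetical registry mapping a layout's *string name* to its linear handler.
# Keying on the string is the pattern the patch moves both call sites onto.
LINEAR_HANDLERS = {}

def register_linear(layout_name):
    def decorator(fn):
        LINEAR_HANDLERS[layout_name] = fn
        return fn
    return decorator

class Params:
    # Mirrors the shape of TensorCoreFP8Layout.Params in the diff.
    def __init__(self, scale, orig_dtype, orig_shape):
        self.scale = scale
        self.orig_dtype = orig_dtype
        self.orig_shape = orig_shape

class QuantizedTensor:
    def __init__(self, data, layout, params):
        self.data = data
        # Accept either the layout class or its registered name; store the string.
        self.layout_name = layout if isinstance(layout, str) else layout.__name__
        self.params = params

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # PyTorch calls this whenever a QuantizedTensor is passed to a torch API;
        # we reroute F.linear to the handler registered under the layout name.
        kwargs = kwargs or {}
        if func is torch.nn.functional.linear:
            q = next(a for a in args if isinstance(a, QuantizedTensor))
            return LINEAR_HANDLERS[q.layout_name](*args, **kwargs)
        raise NotImplementedError(f"{func} not handled in this sketch")

def dequant(t):
    # Rescale the fp8 storage back to the original dtype.
    if isinstance(t, QuantizedTensor):
        return t.data.to(t.params.orig_dtype) * t.params.scale
    return t

@register_linear("TensorCoreFP8Layout")
def fp8_linear_handler(inp, weight, bias=None):
    # Stand-in for a real fused fp8 matmul: dequantize, then plain linear.
    return torch.nn.functional.linear(dequant(inp), dequant(weight), bias)

# Usage mirrors the patched call sites (requires a PyTorch build with float8 support):
x, w = torch.randn(2, 8), torch.randn(4, 8)
qi = QuantizedTensor(x.to(torch.float8_e4m3fn), "TensorCoreFP8Layout",
                     Params(scale=1.0, orig_dtype=torch.float32, orig_shape=tuple(x.shape)))
qw = QuantizedTensor(w.to(torch.float8_e4m3fn), "TensorCoreFP8Layout",
                     Params(scale=1.0, orig_dtype=torch.float32, orig_shape=tuple(w.shape)))
out = torch.nn.functional.linear(qi, qw, None)  # routes through fp8_linear_handler
```

The commit message does not spell out the rationale for the change; one common reason to key a dispatch registry on a string rather than a class object, as sketched above, is that name lookup stays stable where object identity (e.g., across module reloads) does not.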