Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-02-16 00:12:33 +08:00)
Rename quant dtype parameter
This commit is contained in:
commit 59a2e8c74e
parent f287d02419
@@ -379,7 +379,7 @@ def fp8_linear(self, input):
     # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
     layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
     quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight)
-    quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, fp8_dtype=dtype)
+    quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, dtype=dtype)
     o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)

     if tensor_2d:
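The comment in the hunk above relies on PyTorch's __torch_dispatch__ mechanism: because the inputs are wrapped in a tensor subclass, a plain torch.nn.functional.linear call gets rerouted to custom handling. Below is a minimal, self-contained sketch of that pattern; DispatchDemoTensor and its unwrap logic are illustrative assumptions, not ComfyUI's actual QuantizedTensor.

    import torch
    from torch.utils._pytree import tree_map

    class DispatchDemoTensor(torch.Tensor):
        # Wrapper subclass: carries a payload tensor and intercepts all aten ops.
        @staticmethod
        def __new__(cls, data):
            return torch.Tensor._make_wrapper_subclass(
                cls, data.shape, dtype=data.dtype, device=data.device)

        def __init__(self, data):
            self._data = data

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            # Every op on a wrapped tensor lands here first; a real quantized
            # layout would pick a specialized kernel for linear/matmul ops.
            def unwrap(t):
                return t._data if isinstance(t, DispatchDemoTensor) else t
            return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))

    x = DispatchDemoTensor(torch.randn(2, 4))
    w = DispatchDemoTensor(torch.randn(3, 4))
    out = torch.nn.functional.linear(x, w)  # routed through __torch_dispatch__
    print(out.shape)  # torch.Size([2, 3])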
@@ -347,20 +347,20 @@ class TensorCoreFP8Layout(QuantizedLayout):
     - orig_dtype: Original dtype before quantization (for casting back)
     """
     @classmethod
-    def quantize(cls, tensor, scale=None, fp8_dtype=torch.float8_e4m3fn):
+    def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn):
         orig_dtype = tensor.dtype

         if scale is None:
-            scale = torch.amax(tensor.abs()) / torch.finfo(fp8_dtype).max
+            scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max

         if not isinstance(scale, torch.Tensor):
             scale = torch.tensor(scale)
         scale = scale.to(device=tensor.device, dtype=torch.float32)

-        lp_amax = torch.finfo(fp8_dtype).max
+        lp_amax = torch.finfo(dtype).max
         tensor_scaled = tensor.float() / scale
         torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
-        qdata = tensor_scaled.to(fp8_dtype, memory_format=torch.contiguous_format)
+        qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format)

         layout_params = {
             'scale': scale,
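For reference, the quantize path in this hunk is per-tensor absmax scaling: pick a scale so the largest magnitude maps to the FP8 format's maximum, clamp, and cast down. Here is a standalone sketch of that round trip; the function names are ours for illustration, and it assumes a PyTorch build with float8 dtypes (2.1+).

    import torch

    def absmax_fp8_quantize(tensor, scale=None, dtype=torch.float8_e4m3fn):
        # Per-tensor scale from the absolute max, as in the diff above.
        # (Assumes tensor is not all zeros, otherwise scale would be 0.)
        if scale is None:
            scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max
        if not isinstance(scale, torch.Tensor):
            scale = torch.tensor(scale)
        scale = scale.to(device=tensor.device, dtype=torch.float32)

        lp_amax = torch.finfo(dtype).max
        tensor_scaled = tensor.float() / scale
        torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
        qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format)
        return qdata, scale

    def dequantize(qdata, scale, orig_dtype):
        # Cast back up and undo the scaling (this is what orig_dtype is kept for).
        return (qdata.to(torch.float32) * scale).to(orig_dtype)

    x = torch.randn(64, 64, dtype=torch.float16)
    q, s = absmax_fp8_quantize(x)
    x_hat = dequantize(q, s, x.dtype)
    print((x - x_hat).abs().max())  # small quantization error

The rename itself is purely cosmetic: since the layout class is already FP8-specific, the fp8_ prefix on the parameter was redundant, and dtype matches the keyword used at the from_float call site.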