import torch


def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None):
    # Scale the mantissa so one unit in the last place maps to 1.0. Normal
    # numbers carry an implicit leading 1; subnormals (exponent field == 0)
    # use the fixed subnormal scale instead.
    mantissa_scaled = torch.where(
        normal_mask,
        (abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0) * (2**MANTISSA_BITS),
        (abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS)))
    )

    # Stochastic rounding: add uniform noise in [0, 1) and floor, so the
    # probability of rounding up equals the fractional part of the mantissa.
    mantissa_scaled += torch.rand(mantissa_scaled.size(), dtype=mantissa_scaled.dtype, layout=mantissa_scaled.layout, device=mantissa_scaled.device, generator=generator)
    return mantissa_scaled.floor() / (2**MANTISSA_BITS)
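

# A minimal sanity-check sketch (an addition, not part of the upstream file):
# stochastic rounding should be unbiased, so the mean of many rounded samples
# should approach the unrounded input. `_check_calc_mantissa_unbiased` is a
# hypothetical helper name.
def _check_calc_mantissa_unbiased():
    MANTISSA_BITS, EXPONENT_BIAS = 3, 7  # float8_e4m3fn parameters
    g = torch.Generator().manual_seed(0)
    x = torch.full((100000,), 1.3)  # normal value in [1, 2), so exponent field == EXPONENT_BIAS
    exponent = torch.floor(torch.log2(x)) + EXPONENT_BIAS
    normal_mask = exponent != 0
    m = calc_mantissa(x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=g)
    # 1.3 sits between the representable steps 1.25 and 1.375; the mean of
    # 1 + m should land close to 1.3 rather than snapping to either step.
    assert abs((1.0 + m).mean().item() - 1.3) < 1e-2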


# Not 100% sure about this
def manual_stochastic_round_to_float8(x, dtype, generator=None):
    if dtype == torch.float8_e4m3fn:
        EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 4, 3, 7
    elif dtype == torch.float8_e5m2:
        EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 5, 2, 15
    else:
        raise ValueError("Unsupported dtype")

    x = x.half()
    sign = torch.sign(x)
    abs_x = x.abs()
    sign = torch.where(abs_x == 0, 0, sign)  # keep exact zeros signless

    # Combine exponent calculation and clamping
    exponent = torch.clamp(
        torch.floor(torch.log2(abs_x)) + EXPONENT_BIAS,
        0, 2**EXPONENT_BITS - 1
    )

    # Combine mantissa calculation and rounding; exponent field 0 marks subnormals
    normal_mask = ~(exponent == 0)

    abs_x[:] = calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=generator)

    # Reassemble sign, exponent, and rounded mantissa into the final magnitude
    sign *= torch.where(
        normal_mask,
        (2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + abs_x),
        (2.0 ** (-EXPONENT_BIAS + 1)) * abs_x
    )

    # Clamp to the representable range of the target float8 format
    inf = torch.finfo(dtype)
    torch.clamp(sign, min=inf.min, max=inf.max, out=sign)
    return sign
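

# Hedged sketch (an addition): averaging many stochastic float8 roundings of
# the same value should recover that value, which is the property the manual
# implementation aims for. `_demo_unbiased_float8` is a hypothetical name and
# assumes a PyTorch build that ships the float8 dtypes.
def _demo_unbiased_float8(value=1.3, n=100000):
    g = torch.Generator().manual_seed(0)
    x = torch.full((n,), value)
    # Note: the function returns fp16 values already snapped to the float8
    # grid, not tensors of dtype float8.
    rounded = manual_stochastic_round_to_float8(x, torch.float8_e4m3fn, generator=g)
    print(rounded.float().mean().item())  # expected to be close to `value`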


def stochastic_rounding(value, dtype, seed=0):
    if dtype == torch.float32:
        return value.to(dtype=torch.float32)
    if dtype == torch.float16:
        return value.to(dtype=torch.float16)
    if dtype == torch.bfloat16:
        return value.to(dtype=torch.bfloat16)
    if dtype == torch.int8:
        return value.to(dtype=torch.int8)
    if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
        generator = torch.Generator(device=value.device)
        generator.manual_seed(seed)
        output = torch.empty_like(value, dtype=dtype)
        # Process the tensor in slices along dim 0 so the fp16 temporaries
        # created by manual_stochastic_round_to_float8 stay bounded in size.
        num_slices = max(1, (value.numel() / (4096 * 4096)))
        slice_size = max(1, round(value.shape[0] / num_slices))
        for i in range(0, value.shape[0], slice_size):
            output[i:i+slice_size].copy_(manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator))
        return output

    return value.to(dtype=dtype)
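

# Hedged usage sketch (an addition to the file): quantize a weight tensor to
# float8 with a fixed seed so the result is reproducible across runs.
if __name__ == "__main__":
    w = torch.randn(8, 16)
    w_fp8 = stochastic_rounding(w, torch.float8_e4m3fn, seed=42)
    print(w_fp8.dtype, w_fp8.shape)  # torch.float8_e4m3fn torch.Size([8, 16])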