diff --git a/comfy/float.py b/comfy/float.py index e638b1ff7..c806af76b 100644 --- a/comfy/float.py +++ b/comfy/float.py @@ -69,26 +69,31 @@ def stochastic_rounding(value, dtype, seed=0): # TODO: improve this? def stochastic_float_to_fp4_e2m1(x, generator): + orig_shape = x.shape sign = torch.signbit(x).to(torch.uint8) - x_abs = x.abs() - exp = torch.floor(torch.log2(x_abs) + 1.0).clamp(0, 3) + exp = torch.floor(torch.log2(x.abs()) + 1.0).clamp(0, 3) x += (torch.rand(x.size(), dtype=x.dtype, layout=x.layout, device=x.device, generator=generator) - 0.5) * (2 ** (exp - 2.0)) * 1.25 - x_abs = x.abs() - exp = torch.floor(torch.log2(x_abs) + 1.1925).clamp(0, 3) + x = x.abs() + exp = torch.floor(torch.log2(x) + 1.1925).clamp(0, 3) mantissa = torch.where( exp > 0, - (x_abs / (2.0 ** (exp - 1)) - 1.0) * 2.0, - (x_abs * 2.0) + (x / (2.0 ** (exp - 1)) - 1.0) * 2.0, + (x * 2.0), + out=x ).round().to(torch.uint8) + del x - fp4 = (sign << 3) | (exp.to(torch.uint8) << 1) | mantissa + exp = exp.to(torch.uint8) + + fp4 = (sign << 3) | (exp << 1) | mantissa + del sign, exp, mantissa fp4_flat = fp4.view(-1) packed = (fp4_flat[0::2] << 4) | fp4_flat[1::2] - return packed.reshape(list(x.shape)[:-1] + [-1]) + return packed.reshape(list(orig_shape)[:-1] + [-1]) def to_blocked(input_matrix, flatten: bool = True) -> torch.Tensor: