Refactor to try to lower mem usage. (#11840)

comfyanonymous 2026-01-12 21:01:52 -08:00 committed by GitHub
parent b3c0e4de57
commit 117e7a5853


@@ -69,26 +69,31 @@ def stochastic_rounding(value, dtype, seed=0):
 # TODO: improve this?
 def stochastic_float_to_fp4_e2m1(x, generator):
+    orig_shape = x.shape
     sign = torch.signbit(x).to(torch.uint8)
-    x_abs = x.abs()
-    exp = torch.floor(torch.log2(x_abs) + 1.0).clamp(0, 3)
+    exp = torch.floor(torch.log2(x.abs()) + 1.0).clamp(0, 3)
     x += (torch.rand(x.size(), dtype=x.dtype, layout=x.layout, device=x.device, generator=generator) - 0.5) * (2 ** (exp - 2.0)) * 1.25
-    x_abs = x.abs()
-    exp = torch.floor(torch.log2(x_abs) + 1.1925).clamp(0, 3)
+    x = x.abs()
+    exp = torch.floor(torch.log2(x) + 1.1925).clamp(0, 3)
     mantissa = torch.where(
         exp > 0,
-        (x_abs / (2.0 ** (exp - 1)) - 1.0) * 2.0,
-        (x_abs * 2.0)
+        (x / (2.0 ** (exp - 1)) - 1.0) * 2.0,
+        (x * 2.0),
+        out=x
     ).round().to(torch.uint8)
+    del x
-    fp4 = (sign << 3) | (exp.to(torch.uint8) << 1) | mantissa
+    exp = exp.to(torch.uint8)
+    fp4 = (sign << 3) | (exp << 1) | mantissa
+    del sign, exp, mantissa
     fp4_flat = fp4.view(-1)
     packed = (fp4_flat[0::2] << 4) | fp4_flat[1::2]
-    return packed.reshape(list(x.shape)[:-1] + [-1])
+    return packed.reshape(list(orig_shape)[:-1] + [-1])
 def to_blocked(input_matrix, flatten: bool = True) -> torch.Tensor:
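
For context, here is a minimal sketch of how the packed output of stochastic_float_to_fp4_e2m1 could be decoded back to floats, assuming the tensor's last dimension was even before packing. The helper name unpack_fp4_e2m1 is hypothetical and not part of this commit; it simply inverts the bit layout used above (bit 3 = sign, bits 2-1 = exponent, bit 0 = mantissa, two codes per byte with the even-index element in the high nibble).

import torch

# Hypothetical helper (not part of this commit): decode the uint8 tensor produced by
# stochastic_float_to_fp4_e2m1 back to float32, assuming the original last dimension
# (orig_last_dim) was even so that nibble pairs never straddle rows.
def unpack_fp4_e2m1(packed, orig_last_dim):
    # Split each byte into its two 4-bit codes: high nibble first, low nibble second.
    codes = torch.stack(((packed >> 4) & 0x0F, packed & 0x0F), dim=-1)
    codes = codes.reshape(list(packed.shape)[:-1] + [orig_last_dim])
    # Bit layout: bit 3 = sign, bits 2..1 = exponent, bit 0 = mantissa.
    sign = 1.0 - 2.0 * ((codes >> 3) & 1).to(torch.float32)
    exp = ((codes >> 1) & 0b11).to(torch.float32)
    mant = (codes & 1).to(torch.float32)
    # exp == 0 is the subnormal case (magnitude = mantissa / 2); otherwise
    # magnitude = 2**(exp - 1) * (1 + mantissa / 2), mirroring the encoder above.
    mag = torch.where(exp > 0, (2.0 ** (exp - 1.0)) * (1.0 + 0.5 * mant), 0.5 * mant)
    return sign * mag

On a tensor t with an even last dimension, unpack_fp4_e2m1(stochastic_float_to_fp4_e2m1(t, gen), t.shape[-1]) should reproduce t up to FP4 E2M1 quantization and the stochastic rounding noise added by the encoder.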