mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-15 00:30:55 +08:00
Refactor to try to lower mem usage. (#11840)
This commit is contained in:
parent
b3c0e4de57
commit
117e7a5853
@ -69,26 +69,31 @@ def stochastic_rounding(value, dtype, seed=0):
|
||||
|
||||
# TODO: improve this?
|
||||
def stochastic_float_to_fp4_e2m1(x, generator):
    """Quantize ``x`` to 4-bit codes with random dithering and pack two per byte.

    Each element becomes a 4-bit code laid out as ``sign << 3 | exp << 1 |
    mantissa`` (1 sign bit, 2 exponent bits, 1 mantissa bit). Consecutive
    pairs of codes, in flattened order, are packed into one ``uint8``: the
    even-indexed element goes in the high nibble, the odd-indexed one in the
    low nibble.

    Args:
        x: Floating-point tensor to quantize. It is MODIFIED IN PLACE (noise
            is added to it and it is reused as scratch space to keep peak
            memory low), so pass a tensor the caller no longer needs.
            The total number of elements is expected to be even, since
            elements are packed in pairs.
        generator: ``torch.Generator`` used to draw the rounding noise; it
            must live on the same device as ``x``.

    Returns:
        A ``torch.uint8`` tensor with the leading dimensions of ``x`` and the
        packed codes in the last dimension.
    """
    orig_shape = x.shape
    # The sign is taken before any noise is added; only the magnitude is
    # dithered below.
    sign = torch.signbit(x).to(torch.uint8)

    # First exponent estimate, used only to scale the noise to the local
    # step size. The 1.25 and 1.1925 constants are kept exactly as they were.
    exp = torch.floor(torch.log2(x.abs()) + 1.0).clamp(0, 3)
    x += (torch.rand(x.size(), dtype=x.dtype, layout=x.layout, device=x.device, generator=generator) - 0.5) * (2 ** (exp - 2.0)) * 1.25

    # Rebind x to the magnitude of the dithered value and recompute the
    # exponent from it.
    x = x.abs()
    exp = torch.floor(torch.log2(x) + 1.1925).clamp(0, 3)

    # exp > 0: normal numbers, the mantissa is the fractional part in half
    # steps. exp == 0: subnormal range, the value is counted in steps of 0.5.
    # The result is written into x to avoid another full-size allocation.
    mantissa = torch.where(
        exp > 0,
        (x / (2.0 ** (exp - 1)) - 1.0) * 2.0,
        (x * 2.0),
        out=x
    ).round_().clamp_(0, 2).to(torch.uint8)
    del x

    exp = exp.to(torch.uint8)

    # The rounded mantissa may be 2 (a carry into the exponent). Add it
    # rather than OR it so the carry propagates, and saturate at the largest
    # magnitude code (7). OR-ing would turn 6 | 2 back into 6, i.e. encode an
    # out-of-range value as a *smaller* magnitude, and larger mantissas would
    # spill into the sign bit and the neighbouring nibble.
    code = (exp << 1).add_(mantissa).clamp_(max=7)
    del exp, mantissa

    fp4 = (sign << 3) | code
    del sign, code

    # Pack pairs of 4-bit codes: even index -> high nibble, odd -> low nibble.
    fp4_flat = fp4.view(-1)
    packed = (fp4_flat[0::2] << 4) | fp4_flat[1::2]
    return packed.reshape(list(orig_shape)[:-1] + [-1])
|
||||
|
||||
|
||||
def to_blocked(input_matrix, flatten: bool = True) -> torch.Tensor:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user