mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 14:50:49 +08:00
Add a new MPS-specific operations module to handle Float8 tensor support on Apple Silicon. Since MPS does not natively support Float8 dtypes, this implementation stores Float8 data as uint8 and combines it with a GPU-accelerated lookup table (LUT) for efficient dequantization, keeping data on the GPU.
- Add comfy/mps_ops.py: implement cached LUT generation and index-based dequantization for MPS.
- Modify comfy/quant_ops.py: view Float8 tensors as uint8 when moving them to MPS, and route dequantization to mps_ops.
- Modify comfy/float.py: add CPU staging for stochastic rounding to prevent MPS casting errors during quantization.
- Modify comfy/quant_ops.py: add a fallback for fp8_linear.
Signed-off-by: Macpaul Lin <macpaul@gmail.com>
81 lines
2.8 KiB
Python
81 lines
2.8 KiB
Python
import torch
|
|
|
|
def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None):
    """Stochastically round the mantissa of ``abs_x`` to ``MANTISSA_BITS`` bits.

    Args:
        abs_x: Non-negative magnitudes to round.
        exponent: Biased exponent for each element of ``abs_x``.
        normal_mask: True where the element is treated as a normal number,
            False where it is treated as a subnormal.
        MANTISSA_BITS: Number of explicit mantissa bits of the target format.
        EXPONENT_BIAS: Exponent bias of the target format.
        generator: Optional ``torch.Generator`` for the uniform noise.

    Returns:
        For normal elements, the rounded fractional mantissa (without the
        implicit leading 1). For subnormal elements, the rounded value in units
        of ``2 ** (1 - EXPONENT_BIAS)``. Both are multiples of
        ``1 / 2 ** MANTISSA_BITS``.
    """
    steps = 2**MANTISSA_BITS

    # Normal numbers: drop the implicit leading 1 and count the remaining
    # fraction in mantissa steps.
    normal_scaled = (abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0) * steps
    # Subnormal numbers: count the value in units of the smallest subnormal.
    subnormal_scaled = abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS))
    scaled = torch.where(normal_mask, normal_scaled, subnormal_scaled)

    # Adding U[0, 1) noise before flooring rounds up with probability equal
    # to the fractional part, i.e. stochastic rounding.
    scaled += torch.rand(
        scaled.size(),
        dtype=scaled.dtype,
        layout=scaled.layout,
        device=scaled.device,
        generator=generator,
    )
    return scaled.floor() / steps
|
|
|
|
# NOTE: the correctness of this manual float8 rounding has not been fully verified.
|
|
def manual_stochastic_round_to_float8(x, dtype, generator=None):
    """Stochastically round ``x`` onto the value grid of a float8 format.

    All arithmetic is done in float16. The result is a float16 tensor (NOT
    ``dtype``) whose values lie on the float8 grid and are clamped to the
    finite range of ``dtype``. The caller is expected to do the final cast.

    Args:
        x: Input tensor; converted with ``x.half()`` before rounding.
        dtype: ``torch.float8_e4m3fn`` or ``torch.float8_e5m2``.
        generator: Optional ``torch.Generator`` used for the rounding noise
            (forwarded to ``calc_mantissa``).

    Returns:
        A float16 tensor with the same shape as ``x``.

    Raises:
        ValueError: If ``dtype`` is not one of the two supported float8 types.
    """
    if dtype == torch.float8_e4m3fn:
        EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 4, 3, 7
    elif dtype == torch.float8_e5m2:
        EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 5, 2, 15
    else:
        raise ValueError("Unsupported dtype")

    x = x.half()
    abs_x = x.abs()
    # Force an exact +0 sign for zero inputs.
    sign = torch.where(abs_x == 0, 0, torch.sign(x))

    # Biased exponent, clamped to the format's exponent field. log2(0) is
    # -inf, so zeros clamp to 0 and take the subnormal path.
    exponent = torch.clamp(
        torch.floor(torch.log2(abs_x)) + EXPONENT_BIAS,
        0, 2**EXPONENT_BITS - 1,
    )
    # A biased exponent of 0 marks subnormals (no implicit leading 1).
    normal_mask = exponent != 0

    # Bind to a new name instead of writing back through ``abs_x[:]``:
    # slice assignment raises IndexError on 0-dim tensors.
    mantissa = calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=generator)

    magnitude = torch.where(
        normal_mask,
        (2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + mantissa),
        (2.0 ** (-EXPONENT_BIAS + 1)) * mantissa,
    )
    result = sign * magnitude

    # Saturate to the largest finite float8 magnitude instead of overflowing.
    finfo = torch.finfo(dtype)
    return torch.clamp(result, min=finfo.min, max=finfo.max)
|
|
|
|
|
|
|
|
def stochastic_rounding(value, dtype, seed=0):
|
|
if dtype == torch.float32:
|
|
return value.to(dtype=torch.float32)
|
|
if dtype == torch.float16:
|
|
return value.to(dtype=torch.float16)
|
|
if dtype == torch.bfloat16:
|
|
return value.to(dtype=torch.bfloat16)
|
|
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
|
|
# MPS workaround: perform float8 conversion on CPU
|
|
target_device = value.device
|
|
use_cpu_staging = (target_device.type == "mps")
|
|
|
|
output_device = "cpu" if use_cpu_staging else target_device
|
|
output = torch.empty_like(value, dtype=dtype, device=output_device)
|
|
|
|
generator = torch.Generator(device=target_device)
|
|
generator.manual_seed(seed)
|
|
|
|
num_slices = max(1, (value.numel() / (4096 * 4096)))
|
|
slice_size = max(1, round(value.shape[0] / num_slices))
|
|
for i in range(0, value.shape[0], slice_size):
|
|
res = manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator)
|
|
if use_cpu_staging:
|
|
res = res.cpu()
|
|
output[i:i+slice_size].copy_(res)
|
|
|
|
if use_cpu_staging:
|
|
return output.to(target_device)
|
|
return output
|
|
|
|
return value.to(dtype=dtype)
|