feat(mps): implement native-like Float8 support via LUT dequantization

Add a new MPS-specific operations module to handle Float8 tensor support
on Apple Silicon. Since MPS does not natively support Float8 dtypes, this
implementation uses a uint8 storage strategy combined with a GPU-accelerated
Lookup Table (LUT) for efficient dequantization, keeping data on the GPU.

- Add comfy/mps_ops.py: Implement cached LUT generation and index-based
  dequantization for MPS.
- Modify comfy/quant_ops.py: Add logic to view Float8 tensors as uint8
  when moving to MPS, and route dequantization to mps_ops.
- Modify comfy/float.py: Add CPU staging for stochastic rounding to
  prevent MPS casting errors during quantization.
- Modify comfy/quant_ops.py: Add fallback for fp8_linear.

Signed-off-by: Macpaul Lin <macpaul@gmail.com>
This commit is contained in:
Macpaul Lin 2026-01-07 01:31:21 +08:00
parent 4f3f9e72a9
commit d8c65eb448
3 changed files with 113 additions and 5 deletions

View File

@ -55,13 +55,26 @@ def stochastic_rounding(value, dtype, seed=0):
if dtype == torch.bfloat16:
return value.to(dtype=torch.bfloat16)
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
generator = torch.Generator(device=value.device)
# MPS workaround: perform float8 conversion on CPU
target_device = value.device
use_cpu_staging = (target_device.type == "mps")
output_device = "cpu" if use_cpu_staging else target_device
output = torch.empty_like(value, dtype=dtype, device=output_device)
generator = torch.Generator(device=target_device)
generator.manual_seed(seed)
output = torch.empty_like(value, dtype=dtype)
num_slices = max(1, (value.numel() / (4096 * 4096)))
slice_size = max(1, round(value.shape[0] / num_slices))
for i in range(0, value.shape[0], slice_size):
output[i:i+slice_size].copy_(manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator))
res = manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator)
if use_cpu_staging:
res = res.cpu()
output[i:i+slice_size].copy_(res)
if use_cpu_staging:
return output.to(target_device)
return output
return value.to(dtype=dtype)

77
comfy/mps_ops.py Normal file
View File

@ -0,0 +1,77 @@
import torch
_LUT_CACHE = {}
def get_lut(dtype, device):
"""
Get or create a lookup table for float8 dequantization on MPS.
Returns a Tensor[256] of dtype=torch.float16 on the specified device.
"""
key = (dtype, device)
if key in _LUT_CACHE:
return _LUT_CACHE[key]
# Generate all possible 8-bit values (0-255)
# We create them on CPU first as float8, then cast to float16, then move to MPS.
# This acts as our decoding table.
# Create uint8 pattern 0..255
byte_pattern = torch.arange(256, dtype=torch.uint8, device="cpu")
# View as the target float8 type
# Note: We must use .view() on a tensor that has the same number of bytes.
# We can't view uint8 as float8 directly if standard pytorch doesn't allow it easily,
# but we can create the float8 tensor from bytes.
# Actually, the easiest way to generate the LUT is:
# 1. Create bytes 0..255
# 2. View as float8 (on CPU, where it is supported)
# 3. Convert to float16 (on CPU)
# 4. Move float16 LUT to MPS
try:
f8_tensor = byte_pattern.view(dtype)
f16_lut = f8_tensor.to(torch.float16)
# Move to the requested MPS device
lut = f16_lut.to(device)
_LUT_CACHE[key] = lut
return lut
except Exception as e:
print(f"Failed to create MPS LUT for {dtype}: {e}")
# Fallback: return None or raise
raise e
def mps_dequantize(qdata, scale, orig_dtype, float8_dtype):
"""
Dequantize a uint8 tensor (representing float8 data) using a LUT on MPS.
Args:
qdata: Tensor of shape (...) with dtype=torch.uint8 (on MPS)
scale: Tensor (scalar)
orig_dtype: The target dtype (e.g. float16)
float8_dtype: The original float8 dtype (torch.float8_e4m3fn or torch.float8_e5m2)
Returns:
Tensor of shape (...) with dtype=orig_dtype
"""
lut = get_lut(float8_dtype, qdata.device)
# Use index_select or advanced indexing.
# Advanced indexing lut[qdata.long()] is generally efficient.
# We explicitly cast to long (int64) for indexing.
# Note: Flattening might be slightly faster depending on shape, but simple indexing is safest.
# We want the LUT to be in the target orig_dtype (likely float16 or bfloat16)
if lut.dtype != orig_dtype:
lut = lut.to(dtype=orig_dtype)
output = lut[qdata.long()]
# Apply scale
# Scale might need to be cast to orig_dtype too
if isinstance(scale, torch.Tensor):
scale = scale.to(dtype=orig_dtype)
output.mul_(scale)
return output

View File

@ -2,6 +2,7 @@ import torch
import logging
from typing import Tuple, Dict
import comfy.float
import comfy.mps_ops
_LAYOUT_REGISTRY = {}
_GENERIC_UTILS = {}
@ -269,8 +270,18 @@ def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=
if target_device != current_device:
logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}")
new_q_data = qt._qdata.to(device=target_device)
new_params = _move_layout_params_to_device(qt._layout_params, target_device)
# MPS Hack: Convert Float8 to Uint8 before moving if native float8 is unsupported
qdata = qt._qdata
layout_params = qt._layout_params.copy()
if target_device.type == "mps" and qdata.element_size() == 1 and qdata.is_floating_point(): # Catch float8
layout_params["mps_float8_dtype"] = qdata.dtype
qdata = qdata.view(torch.uint8)
new_q_data = qdata.to(device=target_device)
new_params = _move_layout_params_to_device(layout_params, target_device)
if target_dtype is not None:
new_params["orig_dtype"] = target_dtype
new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params)
@ -431,6 +442,13 @@ class TensorCoreFP8Layout(QuantizedLayout):
@staticmethod
def dequantize(qdata, scale, orig_dtype, **kwargs):
if qdata.device.type == "mps":
if qdata.dtype == torch.uint8:
return comfy.mps_ops.mps_dequantize(qdata, scale, orig_dtype, kwargs.get("mps_float8_dtype", torch.float8_e4m3fn))
elif qdata.is_floating_point() and qdata.element_size() == 1:
# It is MPS Float8. View as uint8.
return comfy.mps_ops.mps_dequantize(qdata.view(torch.uint8), scale, orig_dtype, qdata.dtype)
plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
plain_tensor.mul_(scale)
return plain_tensor