mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 14:20:49 +08:00
Add a new MPS-specific operations module to handle Float8 tensor support on Apple Silicon. Since MPS does not natively support Float8 dtypes, this implementation stores Float8 data as uint8 and uses a GPU-accelerated lookup table (LUT) for efficient dequantization, keeping data on the GPU.

- Add comfy/mps_ops.py: implement cached LUT generation and index-based dequantization for MPS.
- Modify comfy/quant_ops.py: view Float8 tensors as uint8 when moving them to MPS, and route dequantization to mps_ops.
- Modify comfy/float.py: add CPU staging for stochastic rounding to prevent MPS casting errors during quantization.
- Modify comfy/quant_ops.py: add a fallback for fp8_linear.

Signed-off-by: Macpaul Lin <macpaul@gmail.com>
78 lines
2.6 KiB
Python
78 lines
2.6 KiB
Python
import logging

import torch
|
|
|
|
_LUT_CACHE = {}
|
|
|
|
def get_lut(dtype, device):
|
|
"""
|
|
Get or create a lookup table for float8 dequantization on MPS.
|
|
Returns a Tensor[256] of dtype=torch.float16 on the specified device.
|
|
"""
|
|
key = (dtype, device)
|
|
if key in _LUT_CACHE:
|
|
return _LUT_CACHE[key]
|
|
|
|
# Generate all possible 8-bit values (0-255)
|
|
# We create them on CPU first as float8, then cast to float16, then move to MPS.
|
|
# This acts as our decoding table.
|
|
|
|
# Create uint8 pattern 0..255
|
|
byte_pattern = torch.arange(256, dtype=torch.uint8, device="cpu")
|
|
|
|
# View as the target float8 type
|
|
# Note: We must use .view() on a tensor that has the same number of bytes.
|
|
# We can't view uint8 as float8 directly if standard pytorch doesn't allow it easily,
|
|
# but we can create the float8 tensor from bytes.
|
|
|
|
# Actually, the easiest way to generate the LUT is:
|
|
# 1. Create bytes 0..255
|
|
# 2. View as float8 (on CPU, where it is supported)
|
|
# 3. Convert to float16 (on CPU)
|
|
# 4. Move float16 LUT to MPS
|
|
|
|
try:
|
|
f8_tensor = byte_pattern.view(dtype)
|
|
f16_lut = f8_tensor.to(torch.float16)
|
|
|
|
# Move to the requested MPS device
|
|
lut = f16_lut.to(device)
|
|
_LUT_CACHE[key] = lut
|
|
return lut
|
|
except Exception as e:
|
|
print(f"Failed to create MPS LUT for {dtype}: {e}")
|
|
# Fallback: return None or raise
|
|
raise e
|
|
|
|
def mps_dequantize(qdata, scale, orig_dtype, float8_dtype):
    """
    Dequantize a uint8 tensor (representing float8 data) using a LUT on MPS.

    Each stored byte is decoded by indexing the 256-entry table returned by
    ``get_lut``. The decoded values are then multiplied by ``scale``.

    Args:
        qdata: Tensor of shape (...) holding the raw float8 bytes, normally
            dtype=torch.uint8 (on MPS). Any other 1-byte dtype (e.g. a float8
            tensor) is reinterpreted bit-for-bit as uint8.
        scale: Tensor or Python number that broadcasts to qdata's shape
            (typically a scalar).
        orig_dtype: The target dtype (e.g. float16)
        float8_dtype: The float8 dtype the bytes encode
            (torch.float8_e4m3fn or torch.float8_e5m2)

    Returns:
        Tensor of shape (...) with dtype=orig_dtype on qdata's device.

    Raises:
        TypeError: If qdata's elements are not 1 byte wide.
    """
    if qdata.dtype != torch.uint8:
        if qdata.element_size() != 1:
            raise TypeError(
                f"mps_dequantize expects 1-byte qdata, got {qdata.dtype}"
            )
        # Reinterpret the bits. A numeric cast (e.g. float8 -> long) would
        # truncate the values instead of recovering the byte patterns.
        qdata = qdata.view(torch.uint8)

    lut = get_lut(float8_dtype, qdata.device)

    # The cached LUT is float16. Convert it to the requested output dtype;
    # .to() returns a new tensor when the dtype differs, so the cache is untouched.
    if lut.dtype != orig_dtype:
        lut = lut.to(dtype=orig_dtype)

    # Advanced indexing gathers one LUT entry per byte. A uint8 index tensor
    # would be treated as a mask, so cast the indices to int64 explicitly.
    output = lut[qdata.long()]

    if isinstance(scale, torch.Tensor):
        # Match dtype and device so the in-place multiply keeps orig_dtype and
        # does not fail when a non-scalar scale lives on a different device.
        scale = scale.to(device=output.device, dtype=orig_dtype)

    # `output` is a fresh tensor produced by indexing, so scaling in place
    # cannot modify the cached LUT.
    output.mul_(scale)
    return output