Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-10 06:10:50 +08:00)
feat(mps): implement native-like Float8 support via LUT dequantization
Add a new MPS-specific operations module to handle Float8 tensor support
on Apple Silicon. Since MPS does not natively support Float8 dtypes, this
implementation uses a uint8 storage strategy combined with a GPU-accelerated
lookup table (LUT) for efficient dequantization, keeping the data on the GPU.

- Add comfy/mps_ops.py: implement cached LUT generation and index-based
  dequantization for MPS.
- Modify comfy/quant_ops.py: add logic to view Float8 tensors as uint8 when
  moving to MPS, and route dequantization to mps_ops.
- Modify comfy/float.py: add CPU staging for stochastic rounding to prevent
  MPS casting errors during quantization.
- Modify comfy/quant_ops.py: add a fallback for fp8_linear.

Signed-off-by: Macpaul Lin <macpaul@gmail.com>
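The core idea, as a minimal standalone sketch (the tensor names are illustrative, and it assumes a PyTorch build where float8 dtypes work on CPU but not on MPS):

    import torch

    # Quantize on CPU, where float8 is supported, then ship the raw bytes to MPS.
    w_fp8 = torch.randn(4, 4).to(torch.float8_e4m3fn)
    w_mps = w_fp8.view(torch.uint8).to("mps")

    # 256-entry decoding table: byte pattern -> float16 value, built on CPU.
    lut = torch.arange(256, dtype=torch.uint8).view(torch.float8_e4m3fn)
    lut = lut.to(torch.float16).to("mps")

    # Dequantize entirely on the GPU with a single gather.
    w_dec = lut[w_mps.long()]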
parent 5943fbf457
commit ef7b4a717a
comfy/float.py

@@ -55,13 +55,26 @@ def stochastic_rounding(value, dtype, seed=0):
     if dtype == torch.bfloat16:
         return value.to(dtype=torch.bfloat16)
     if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
-        generator = torch.Generator(device=value.device)
+        # MPS workaround: perform float8 conversion on CPU
+        target_device = value.device
+        use_cpu_staging = (target_device.type == "mps")
+
+        output_device = "cpu" if use_cpu_staging else target_device
+        output = torch.empty_like(value, dtype=dtype, device=output_device)
+
+        generator = torch.Generator(device=target_device)
         generator.manual_seed(seed)
-        output = torch.empty_like(value, dtype=dtype)
+
         num_slices = max(1, (value.numel() / (4096 * 4096)))
         slice_size = max(1, round(value.shape[0] / num_slices))
         for i in range(0, value.shape[0], slice_size):
-            output[i:i+slice_size].copy_(manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator))
+            res = manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator)
+            if use_cpu_staging:
+                res = res.cpu()
+            output[i:i+slice_size].copy_(res)
+
+        if use_cpu_staging:
+            return output.to(target_device)
         return output
 
     return value.to(dtype=dtype)
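With this change, quantizing an MPS-resident tensor stages the float8 work on CPU and moves the result back at the end. A rough usage sketch (made-up shapes, and assuming the float8 result can be transferred back to MPS as storage):

    import torch
    import comfy.float

    w = torch.randn(8192, 1024, device="mps")
    q = comfy.float.stochastic_rounding(w, torch.float8_e4m3fn, seed=42)
    # The rounding itself ran on CPU in slices; the float8 output is then
    # moved back to the original MPS device.
    assert q.dtype == torch.float8_e4m3fn and q.device.type == "mps"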
comfy/mps_ops.py (new file, 77 lines)

@@ -0,0 +1,77 @@
import torch

_LUT_CACHE = {}


def get_lut(dtype, device):
    """
    Get or create a lookup table for float8 dequantization on MPS.
    Returns a Tensor[256] of dtype=torch.float16 on the specified device.
    """
    key = (dtype, device)
    if key in _LUT_CACHE:
        return _LUT_CACHE[key]

    # Build the decoding table on CPU, where float8 dtypes are supported:
    # enumerate all 256 possible byte patterns, reinterpret them as the
    # target float8 dtype, widen to float16, then move the table to MPS.
    byte_pattern = torch.arange(256, dtype=torch.uint8, device="cpu")

    try:
        # .view(dtype) reinterprets the raw bytes; this is valid because
        # uint8 and the float8 dtypes are all 1 byte per element.
        f8_tensor = byte_pattern.view(dtype)
        f16_lut = f8_tensor.to(torch.float16)

        # Move the finished LUT to the requested (MPS) device and cache it.
        lut = f16_lut.to(device)
        _LUT_CACHE[key] = lut
        return lut
    except Exception as e:
        print(f"Failed to create MPS LUT for {dtype}: {e}")
        raise

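A quick sanity check for the table (a sketch, not part of the new file): every byte decoded through the LUT must agree with PyTorch's own float8-to-float16 cast on CPU.

    import torch

    lut = get_lut(torch.float8_e4m3fn, "cpu")  # pass "mps" on Apple Silicon
    all_bytes = torch.arange(256, dtype=torch.uint8)
    expected = all_bytes.view(torch.float8_e4m3fn).to(torch.float16)
    # equal_nan=True because e4m3fn bytes 0x7f and 0xff decode to NaN.
    assert torch.allclose(lut[all_bytes.long()], expected, equal_nan=True)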
def mps_dequantize(qdata, scale, orig_dtype, float8_dtype):
    """
    Dequantize a uint8 tensor (representing float8 data) using a LUT on MPS.

    Args:
        qdata: Tensor of shape (...) with dtype=torch.uint8 (on MPS)
        scale: Scalar tensor or Python number
        orig_dtype: The target dtype (e.g. torch.float16)
        float8_dtype: The original float8 dtype (torch.float8_e4m3fn or torch.float8_e5m2)

    Returns:
        Tensor of shape (...) with dtype=orig_dtype
    """
    lut = get_lut(float8_dtype, qdata.device)

    # Keep the LUT in the target dtype so the gather below needs no extra cast.
    if lut.dtype != orig_dtype:
        lut = lut.to(dtype=orig_dtype)

    # Advanced indexing (lut[qdata.long()]) decodes every byte in a single
    # GPU gather; indices must be int64, hence the explicit .long() cast.
    output = lut[qdata.long()]

    # Apply the dequantization scale, matching the output dtype if needed.
    if isinstance(scale, torch.Tensor):
        scale = scale.to(dtype=orig_dtype)
    output.mul_(scale)
    return output
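Typical usage looks like the following (the tensors and scale here are hypothetical):

    import torch
    import comfy.mps_ops

    qdata = torch.randint(0, 256, (128, 256), dtype=torch.uint8, device="mps")
    scale = torch.tensor(0.02)
    w = comfy.mps_ops.mps_dequantize(qdata, scale, torch.float16, torch.float8_e4m3fn)
    assert w.dtype == torch.float16 and w.device.type == "mps"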
comfy/quant_ops.py

@@ -44,6 +44,7 @@ except ImportError as e:
         return None
 
 import comfy.float
+import comfy.mps_ops
 
 # ==============================================================================
 # FP8 Layouts with Comfy-Specific Extensions
@@ -51,7 +52,13 @@ import comfy.float
 
 class _TensorCoreFP8LayoutBase(_CKFp8Layout):
     FP8_DTYPE = None  # Must be overridden in subclass
 
+    """
+    Storage format:
+    - qdata: FP8 tensor (torch.float8_e4m3fn or torch.float8_e5m2)
+    - scale: Scalar tensor (float32) for dequantization
+    - orig_dtype: Original dtype before quantization (for casting back)
+    """
     @classmethod
     def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
         if cls.FP8_DTYPE is None:
@@ -83,6 +90,19 @@ class _TensorCoreFP8LayoutBase(_CKFp8Layout):
         params = cls.Params(scale=scale.float(), orig_dtype=orig_dtype, orig_shape=orig_shape)
         return qdata, params
 
+    @staticmethod
+    def dequantize(qdata, scale, orig_dtype, **kwargs):
+        if qdata.device.type == "mps":
+            if qdata.dtype == torch.uint8:
+                return comfy.mps_ops.mps_dequantize(qdata, scale, orig_dtype, kwargs.get("mps_float8_dtype", torch.float8_e4m3fn))
+            elif qdata.is_floating_point() and qdata.element_size() == 1:
+                # It is MPS float8 storage; view it as uint8 for the LUT path.
+                return comfy.mps_ops.mps_dequantize(qdata.view(torch.uint8), scale, orig_dtype, qdata.dtype)
+
+        plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
+        plain_tensor.mul_(scale)
+        return plain_tensor
+
 
 class TensorCoreFP8E4M3Layout(_TensorCoreFP8LayoutBase):
     FP8_DTYPE = torch.float8_e4m3fn
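Putting the routing together (a sketch against the code above; the tensors are hypothetical): on MPS with uint8 storage, dequantize dispatches to comfy.mps_ops.mps_dequantize, while on other devices it falls through to the plain _to_copy-and-scale path. Note that uint8 storage decodes as e4m3fn unless the caller passes mps_float8_dtype in kwargs, so E5M2 callers must supply it explicitly.

    import torch
    import comfy.quant_ops

    qdata = torch.randint(0, 256, (64, 64), dtype=torch.uint8, device="mps")
    scale = torch.tensor(0.01)
    w = comfy.quant_ops.TensorCoreFP8E4M3Layout.dequantize(qdata, scale, torch.float16)
    # w is float16 on MPS, decoded through the LUT and scaled in place.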