diff --git a/comfy/float.py b/comfy/float.py
index 521316fd2..848c6ff68 100644
--- a/comfy/float.py
+++ b/comfy/float.py
@@ -55,13 +55,26 @@ def stochastic_rounding(value, dtype, seed=0):
     if dtype == torch.bfloat16:
         return value.to(dtype=torch.bfloat16)
     if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
-        generator = torch.Generator(device=value.device)
+        # MPS workaround: stage the float8 output buffer on CPU and move it back at the end
+        target_device = value.device
+        use_cpu_staging = (target_device.type == "mps")
+
+        output_device = "cpu" if use_cpu_staging else target_device
+        output = torch.empty_like(value, dtype=dtype, device=output_device)
+
+        generator = torch.Generator(device=target_device)
         generator.manual_seed(seed)
-        output = torch.empty_like(value, dtype=dtype)
+
         num_slices = max(1, (value.numel() / (4096 * 4096)))
         slice_size = max(1, round(value.shape[0] / num_slices))
         for i in range(0, value.shape[0], slice_size):
-            output[i:i+slice_size].copy_(manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator))
+            res = manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator)
+            if use_cpu_staging:
+                res = res.cpu()
+            output[i:i+slice_size].copy_(res)
+
+        if use_cpu_staging:
+            return output.to(target_device)
         return output
 
     return value.to(dtype=dtype)
diff --git a/comfy/mps_ops.py b/comfy/mps_ops.py
new file mode 100644
index 000000000..a5930356b
--- /dev/null
+++ b/comfy/mps_ops.py
@@ -0,0 +1,60 @@
+import torch
+
+_LUT_CACHE = {}
+
+def get_lut(dtype, device):
+    """
+    Get or create a lookup table for float8 dequantization on MPS.
+    Returns a Tensor[256] of dtype=torch.float16 on the specified device.
+    """
+    key = (dtype, device)
+    if key in _LUT_CACHE:
+        return _LUT_CACHE[key]
+
+    # Build the decoding table from every possible 8-bit pattern:
+    # 1. Create bytes 0..255 on CPU
+    # 2. Reinterpret (.view) them as the target float8 dtype (supported on CPU)
+    # 3. Convert to float16 (still on CPU)
+    # 4. Move the float16 LUT to the requested device
+    byte_pattern = torch.arange(256, dtype=torch.uint8, device="cpu")
+
+    try:
+        f8_tensor = byte_pattern.view(dtype)
+        f16_lut = f8_tensor.to(torch.float16)
+
+        # Move to the requested device and cache for reuse
+        lut = f16_lut.to(device)
+        _LUT_CACHE[key] = lut
+        return lut
+    except Exception as e:
+        print(f"Failed to create MPS LUT for {dtype}: {e}")
+        raise
+
+def mps_dequantize(qdata, scale, orig_dtype, float8_dtype):
+    """
+    Dequantize a uint8 tensor (representing float8 data) using a LUT on MPS.
+
+    Args:
+        qdata: Tensor of shape (...) with dtype=torch.uint8 (on MPS)
+        scale: Scalar tensor (or Python float)
+        orig_dtype: The target dtype (e.g. float16)
+        float8_dtype: The original float8 dtype (torch.float8_e4m3fn or torch.float8_e5m2)
+
+    Returns:
+        Tensor of shape (...) with dtype=orig_dtype
+    """
+    lut = get_lut(float8_dtype, qdata.device)
+
+    # The LUT must match the target orig_dtype (likely float16 or bfloat16)
+    if lut.dtype != orig_dtype:
+        lut = lut.to(dtype=orig_dtype)
+
+    # Advanced indexing decodes each byte through the LUT; indices must be int64
+    output = lut[qdata.long()]
+
+    # Apply the dequantization scale, casting it to orig_dtype if it is a tensor
+    if isinstance(scale, torch.Tensor):
+        scale = scale.to(dtype=orig_dtype)
+
+    output.mul_(scale)
+    return output
diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py
index 8324be42a..08a8a996d 100644
--- a/comfy/quant_ops.py
+++ b/comfy/quant_ops.py
@@ -44,6 +44,7 @@ except ImportError as e:
         return None
 
 import comfy.float
+import comfy.mps_ops
 
 # ==============================================================================
 # FP8 Layouts with Comfy-Specific Extensions
@@ -51,7 +52,11 @@ import comfy.float
 
 class _TensorCoreFP8LayoutBase(_CKFp8Layout):
     FP8_DTYPE = None  # Must be overridden in subclass
-
+
+    # Storage format:
+    # - qdata: FP8 tensor (torch.float8_e4m3fn or torch.float8_e5m2)
+    # - scale: Scalar tensor (float32) for dequantization
+    # - orig_dtype: Original dtype before quantization (for casting back)
     @classmethod
     def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
         if cls.FP8_DTYPE is None:
@@ -83,6 +88,19 @@ class _TensorCoreFP8LayoutBase(_CKFp8Layout):
         params = cls.Params(scale=scale.float(), orig_dtype=orig_dtype, orig_shape=orig_shape)
         return qdata, params
 
+    @staticmethod
+    def dequantize(qdata, scale, orig_dtype, **kwargs):
+        if qdata.device.type == "mps":
+            if qdata.dtype == torch.uint8:
+                return comfy.mps_ops.mps_dequantize(qdata, scale, orig_dtype, kwargs.get("mps_float8_dtype", torch.float8_e4m3fn))
+            elif qdata.is_floating_point() and qdata.element_size() == 1:
+                # Float8 stored on MPS: reinterpret as uint8 so the LUT path can index it.
+                return comfy.mps_ops.mps_dequantize(qdata.view(torch.uint8), scale, orig_dtype, qdata.dtype)
+
+        plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
+        plain_tensor.mul_(scale)
+        return plain_tensor
+
 
 class TensorCoreFP8E4M3Layout(_TensorCoreFP8LayoutBase):
     FP8_DTYPE = torch.float8_e4m3fn
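A quick way to sanity-check the LUT path outside of ComfyUI's runtime is to decode random byte patterns on CPU and compare against a direct float8 cast. This is an illustrative sketch only, not part of the patch; it assumes the repository root is on PYTHONPATH so that comfy.mps_ops imports, and it uses equal_nan because some float8 byte patterns decode to NaN:

    import torch
    import comfy.mps_ops as mps_ops

    device = torch.device("cpu")  # swap in torch.device("mps") on Apple silicon
    qdata = torch.randint(0, 256, (4, 8), dtype=torch.uint8, device=device)
    scale = torch.tensor(0.5)

    # Decode each byte through the 256-entry LUT and apply the scale.
    out = mps_ops.mps_dequantize(qdata, scale, torch.float16, torch.float8_e4m3fn)

    # Reference: reinterpret the same bytes as float8 and cast directly (supported on CPU).
    ref = qdata.cpu().view(torch.float8_e4m3fn).to(torch.float16) * 0.5
    print(torch.allclose(out.cpu(), ref, equal_nan=True))  # expected: True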
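And a hedged usage sketch of the CPU-staged stochastic rounding path in comfy/float.py (again illustrative only; on Apple silicon it relies on the patch's assumption that the finished float8 tensor can be moved back to the MPS device):

    import torch
    from comfy.float import stochastic_rounding

    device = "mps" if torch.backends.mps.is_available() else "cpu"
    weight = torch.randn(512, 512, device=device)

    # On MPS the float8 output buffer is filled on CPU slice by slice, then moved back.
    q = stochastic_rounding(weight, torch.float8_e4m3fn, seed=42)
    print(q.dtype, q.device)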