diff --git a/comfy/float.py b/comfy/float.py index 88c47cd80..c6e943805 100644 --- a/comfy/float.py +++ b/comfy/float.py @@ -55,6 +55,11 @@ def stochastic_rounding(value, dtype, seed=0): if dtype == torch.bfloat16: return value.to(dtype=torch.bfloat16) if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2: + # MPS does not support FP8 dtypes — perform rounding on CPU and return the result there. + on_mps = value.device.type == "mps" + if on_mps: + value = value.cpu() + generator = torch.Generator(device=value.device) generator.manual_seed(seed) output = torch.empty_like(value, dtype=dtype) @@ -159,6 +164,12 @@ def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0): """Round up x to the nearest multiple.""" return ((x + multiple - 1) // multiple) * multiple + # MPS does not support FP8 dtypes used for block scales — perform on CPU. + on_mps = x.device.type == "mps" + if on_mps: + x = x.cpu() + per_tensor_scale = per_tensor_scale.cpu() if isinstance(per_tensor_scale, torch.Tensor) else per_tensor_scale + generator = torch.Generator(device=x.device) generator.manual_seed(seed) @@ -179,6 +190,12 @@ def stochastic_round_quantize_nvfp4_by_block(x, per_tensor_scale, pad_16x, seed= """Round up x to the nearest multiple.""" return ((x + multiple - 1) // multiple) * multiple + # MPS does not support FP8 dtypes used for block scales — perform on CPU. + on_mps = x.device.type == "mps" + if on_mps: + x = x.cpu() + per_tensor_scale = per_tensor_scale.cpu() if isinstance(per_tensor_scale, torch.Tensor) else per_tensor_scale + orig_shape = x.shape # Handle padding diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 15a4f457b..1682d1049 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -71,6 +71,12 @@ class _TensorCoreFP8LayoutBase(_CKFp8Layout): if not isinstance(scale, torch.Tensor): scale = torch.tensor(scale, device=tensor.device, dtype=torch.float32) + # MPS does not support FP8 dtypes — move to CPU for quantization. + on_mps = tensor.device.type == "mps" + if on_mps: + tensor = tensor.cpu() + scale = scale.cpu() + if stochastic_rounding > 0: if inplace_ops: tensor *= (1.0 / scale).to(tensor.dtype) diff --git a/tests/test_fp8_mps.py b/tests/test_fp8_mps.py new file mode 100644 index 000000000..1f4c58b7b --- /dev/null +++ b/tests/test_fp8_mps.py @@ -0,0 +1,178 @@ +""" +Tests for FP8 quantization on MPS (Apple Silicon) devices. + +MPS does not natively support float8_e4m3fn or float8_e5m2 dtypes. +These tests verify that: + 1. FP8 operations correctly fall back to CPU when on MPS. + 2. The round-trip (quantize on CPU -> result on original device) is numerically sound. + 3. No "Placeholder storage has not been allocated on MPS device!" errors occur. +""" +import sys +import pytest +import torch + +# Skip the entire module if MPS is not available +pytestmark = pytest.mark.skipif( + not torch.backends.mps.is_available(), + reason="MPS backend not available" +) + +# ── helpers ────────────────────────────────────────────────────────────────── + +def _make_mps_tensor(shape=(256, 256), dtype=torch.float32): + return torch.randn(shape, device="mps", dtype=dtype) + + +# ── Tests for comfy.float ──────────────────────────────────────────────────── + +class TestStochasticRoundingMPS: + """Tests for comfy.float.stochastic_rounding on MPS device.""" + + def test_stochastic_rounding_fp8_e4m3fn_on_mps(self): + """stochastic_rounding must not crash when input is on MPS and target dtype is float8_e4m3fn.""" + sys.path.insert(0, "/Users/tkhan/comfyui/ComfyUI") + import comfy.float + + x = _make_mps_tensor((64, 64), dtype=torch.float32) + result = comfy.float.stochastic_rounding(x, dtype=torch.float8_e4m3fn, seed=42) + + assert result.dtype == torch.float8_e4m3fn + assert result.shape == x.shape + + def test_stochastic_rounding_fp8_e5m2_on_mps(self): + """stochastic_rounding must not crash when input is on MPS and target dtype is float8_e5m2.""" + sys.path.insert(0, "/Users/tkhan/comfyui/ComfyUI") + import comfy.float + + x = _make_mps_tensor((64, 64), dtype=torch.float32) + result = comfy.float.stochastic_rounding(x, dtype=torch.float8_e5m2, seed=42) + + assert result.dtype == torch.float8_e5m2 + assert result.shape == x.shape + + def test_stochastic_rounding_fp8_result_on_cpu(self): + """Result of FP8 rounding from MPS input should be on CPU (since MPS can't hold FP8).""" + sys.path.insert(0, "/Users/tkhan/comfyui/ComfyUI") + import comfy.float + + x = _make_mps_tensor((32, 32), dtype=torch.float32) + result = comfy.float.stochastic_rounding(x, dtype=torch.float8_e4m3fn, seed=42) + + # FP8 tensors cannot live on MPS, so result must be on CPU + assert result.device.type == "cpu" + + def test_stochastic_rounding_non_fp8_still_works(self): + """Non-FP8 dtypes on MPS must still work as before (no regression).""" + sys.path.insert(0, "/Users/tkhan/comfyui/ComfyUI") + import comfy.float + + x = _make_mps_tensor((32, 32), dtype=torch.float32) + + r16 = comfy.float.stochastic_rounding(x, dtype=torch.float16, seed=0) + assert r16.dtype == torch.float16 + assert r16.device.type == "mps" + + rbf16 = comfy.float.stochastic_rounding(x, dtype=torch.bfloat16, seed=0) + assert rbf16.dtype == torch.bfloat16 + assert rbf16.device.type == "mps" + + def test_stochastic_rounding_fp8_numerical_sanity(self): + """FP8 round-trip (float32 -> fp8 -> float32) should have bounded error.""" + sys.path.insert(0, "/Users/tkhan/comfyui/ComfyUI") + import comfy.float + + x = torch.randn(128, 128, device="mps", dtype=torch.float32) + x_clamped = torch.clamp(x, min=-448, max=448) # FP8 e4m3fn range + + fp8 = comfy.float.stochastic_rounding(x_clamped, dtype=torch.float8_e4m3fn, seed=123) + # Convert back to float32 for comparison + reconstructed = fp8.to(torch.float32) + + # Max relative error should be bounded (FP8 e4m3fn has ~0.125 relative precision) + x_cpu = x_clamped.cpu() + max_abs_err = (reconstructed - x_cpu).abs().max().item() + # FP8 e4m3fn max value is 448, min subnormal ~0.001953 + # For random normal data, error should be well under 1.0 + assert max_abs_err < 2.0, f"FP8 round-trip error too large: {max_abs_err}" + + +class TestManualStochasticRoundMPS: + """Tests for comfy.float.manual_stochastic_round_to_float8 on MPS device.""" + + def test_manual_round_fp8_on_mps_tensor_fails_without_fix(self): + """Calling manual_stochastic_round_to_float8 with MPS generator should fail or be handled.""" + sys.path.insert(0, "/Users/tkhan/comfyui/ComfyUI") + import comfy.float + + x = _make_mps_tensor((16, 16), dtype=torch.float32) + # The generator device matters - this is the root cause of the second error + # (Placeholder storage has not been allocated on MPS device!) + # After fix, stochastic_rounding should handle this internally + result = comfy.float.stochastic_rounding(x, dtype=torch.float8_e4m3fn, seed=42) + assert result.dtype == torch.float8_e4m3fn + + +class TestNVFP4StochasticRoundMPS: + """Tests for NVFP4 stochastic rounding on MPS - also creates FP8 tensors internally.""" + + def test_nvfp4_stochastic_round_on_mps(self): + """stochastic_round_quantize_nvfp4 creates FP8 block scales internally.""" + sys.path.insert(0, "/Users/tkhan/comfyui/ComfyUI") + import comfy.float + + # NVFP4 requires 2D input with dimensions divisible by 16 + x = torch.randn(32, 32, device="mps", dtype=torch.float32) + scale = torch.tensor(1.0, device="mps", dtype=torch.float32) + + # This should not crash - internally creates float8_e4m3fn block scales + qdata, block_scale = comfy.float.stochastic_round_quantize_nvfp4( + x, scale, pad_16x=False, seed=42 + ) + assert qdata.dtype == torch.uint8 + + +# ── Tests for comfy.quant_ops (integration) ────────────────────────────────── + +class TestQuantOpsMPS: + """Tests for the quantization ops layer that calls into comfy.float.""" + + def test_fp8_layout_quantize_on_mps(self): + """TensorCoreFP8E4M3Layout.quantize must work with MPS tensors.""" + sys.path.insert(0, "/Users/tkhan/comfyui/ComfyUI") + from comfy.quant_ops import TensorCoreFP8E4M3Layout + + x = _make_mps_tensor((64, 64), dtype=torch.bfloat16) + qdata, params = TensorCoreFP8E4M3Layout.quantize( + x, scale="recalculate", stochastic_rounding=42 + ) + + assert qdata.dtype == torch.float8_e4m3fn + assert params.orig_dtype == torch.bfloat16 + + def test_fp8_layout_quantize_without_stochastic_on_mps(self): + """TensorCoreFP8E4M3Layout.quantize with stochastic_rounding=0 uses ck.quantize_per_tensor_fp8.""" + sys.path.insert(0, "/Users/tkhan/comfyui/ComfyUI") + from comfy.quant_ops import TensorCoreFP8E4M3Layout + + x = _make_mps_tensor((64, 64), dtype=torch.bfloat16) + qdata, params = TensorCoreFP8E4M3Layout.quantize( + x, scale="recalculate", stochastic_rounding=0 + ) + + assert qdata.dtype == torch.float8_e4m3fn + + def test_fp8_e5m2_layout_quantize_on_mps(self): + """TensorCoreFP8E5M2Layout.quantize must work with MPS tensors.""" + sys.path.insert(0, "/Users/tkhan/comfyui/ComfyUI") + from comfy.quant_ops import TensorCoreFP8E5M2Layout + + x = _make_mps_tensor((64, 64), dtype=torch.float32) + qdata, params = TensorCoreFP8E5M2Layout.quantize( + x, scale="recalculate", stochastic_rounding=42 + ) + + assert qdata.dtype == torch.float8_e5m2 + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "--tb=short"])