fix: remove hardcoded local paths from MPS FP8 tests

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Tashdid Khan 2026-02-09 21:27:35 -05:00
parent edd44a6874
commit 45a2363e6a

View File

@@ -7,10 +7,12 @@ These tests verify that:
 2. The round-trip (quantize on CPU -> result on original device) is numerically sound.
 3. No "Placeholder storage has not been allocated on MPS device!" errors occur.
 """
-import sys
 import pytest
 import torch
+import comfy.float
+from comfy.quant_ops import TensorCoreFP8E4M3Layout, TensorCoreFP8E5M2Layout
 # Skip the entire module if MPS is not available
 pytestmark = pytest.mark.skipif(
     not torch.backends.mps.is_available(),
@@ -30,9 +32,6 @@ class TestStochasticRoundingMPS:
     def test_stochastic_rounding_fp8_e4m3fn_on_mps(self):
         """stochastic_rounding must not crash when input is on MPS and target dtype is float8_e4m3fn."""
-        sys.path.insert(0, "/Users/tkhan/comfyui/ComfyUI")
-        import comfy.float
         x = _make_mps_tensor((64, 64), dtype=torch.float32)
         result = comfy.float.stochastic_rounding(x, dtype=torch.float8_e4m3fn, seed=42)
@@ -41,9 +40,6 @@ class TestStochasticRoundingMPS:
     def test_stochastic_rounding_fp8_e5m2_on_mps(self):
         """stochastic_rounding must not crash when input is on MPS and target dtype is float8_e5m2."""
-        sys.path.insert(0, "/Users/tkhan/comfyui/ComfyUI")
-        import comfy.float
         x = _make_mps_tensor((64, 64), dtype=torch.float32)
         result = comfy.float.stochastic_rounding(x, dtype=torch.float8_e5m2, seed=42)
@@ -52,9 +48,6 @@ class TestStochasticRoundingMPS:
     def test_stochastic_rounding_fp8_result_on_cpu(self):
         """Result of FP8 rounding from MPS input should be on CPU (since MPS can't hold FP8)."""
-        sys.path.insert(0, "/Users/tkhan/comfyui/ComfyUI")
-        import comfy.float
         x = _make_mps_tensor((32, 32), dtype=torch.float32)
         result = comfy.float.stochastic_rounding(x, dtype=torch.float8_e4m3fn, seed=42)
@@ -63,9 +56,6 @@ class TestStochasticRoundingMPS:
     def test_stochastic_rounding_non_fp8_still_works(self):
         """Non-FP8 dtypes on MPS must still work as before (no regression)."""
-        sys.path.insert(0, "/Users/tkhan/comfyui/ComfyUI")
-        import comfy.float
         x = _make_mps_tensor((32, 32), dtype=torch.float32)
         r16 = comfy.float.stochastic_rounding(x, dtype=torch.float16, seed=0)
@@ -78,9 +68,6 @@ class TestStochasticRoundingMPS:
     def test_stochastic_rounding_fp8_numerical_sanity(self):
         """FP8 round-trip (float32 -> fp8 -> float32) should have bounded error."""
-        sys.path.insert(0, "/Users/tkhan/comfyui/ComfyUI")
-        import comfy.float
         x = torch.randn(128, 128, device="mps", dtype=torch.float32)
         x_clamped = torch.clamp(x, min=-448, max=448) # FP8 e4m3fn range
@@ -99,15 +86,9 @@ class TestStochasticRoundingMPS:
 class TestManualStochasticRoundMPS:
     """Tests for comfy.float.manual_stochastic_round_to_float8 on MPS device."""
-    def test_manual_round_fp8_on_mps_tensor_fails_without_fix(self):
-        """Calling manual_stochastic_round_to_float8 with MPS generator should fail or be handled."""
-        sys.path.insert(0, "/Users/tkhan/comfyui/ComfyUI")
-        import comfy.float
+    def test_manual_round_fp8_on_mps_tensor(self):
+        """stochastic_rounding handles MPS generator internally without 'Placeholder storage' error."""
         x = _make_mps_tensor((16, 16), dtype=torch.float32)
-        # The generator device matters - this is the root cause of the second error
-        # (Placeholder storage has not been allocated on MPS device!)
-        # After fix, stochastic_rounding should handle this internally
         result = comfy.float.stochastic_rounding(x, dtype=torch.float8_e4m3fn, seed=42)
         assert result.dtype == torch.float8_e4m3fn
@@ -117,9 +98,6 @@ class TestNVFP4StochasticRoundMPS:
     def test_nvfp4_stochastic_round_on_mps(self):
         """stochastic_round_quantize_nvfp4 creates FP8 block scales internally."""
-        sys.path.insert(0, "/Users/tkhan/comfyui/ComfyUI")
-        import comfy.float
         # NVFP4 requires 2D input with dimensions divisible by 16
         x = torch.randn(32, 32, device="mps", dtype=torch.float32)
         scale = torch.tensor(1.0, device="mps", dtype=torch.float32)
@@ -138,9 +116,6 @@ class TestQuantOpsMPS:
     def test_fp8_layout_quantize_on_mps(self):
         """TensorCoreFP8E4M3Layout.quantize must work with MPS tensors."""
-        sys.path.insert(0, "/Users/tkhan/comfyui/ComfyUI")
-        from comfy.quant_ops import TensorCoreFP8E4M3Layout
         x = _make_mps_tensor((64, 64), dtype=torch.bfloat16)
         qdata, params = TensorCoreFP8E4M3Layout.quantize(
             x, scale="recalculate", stochastic_rounding=42
@@ -151,9 +126,6 @@ class TestQuantOpsMPS:
     def test_fp8_layout_quantize_without_stochastic_on_mps(self):
         """TensorCoreFP8E4M3Layout.quantize with stochastic_rounding=0 uses ck.quantize_per_tensor_fp8."""
-        sys.path.insert(0, "/Users/tkhan/comfyui/ComfyUI")
-        from comfy.quant_ops import TensorCoreFP8E4M3Layout
         x = _make_mps_tensor((64, 64), dtype=torch.bfloat16)
         qdata, params = TensorCoreFP8E4M3Layout.quantize(
             x, scale="recalculate", stochastic_rounding=0
@@ -163,9 +135,6 @@ class TestQuantOpsMPS:
     def test_fp8_e5m2_layout_quantize_on_mps(self):
         """TensorCoreFP8E5M2Layout.quantize must work with MPS tensors."""
-        sys.path.insert(0, "/Users/tkhan/comfyui/ComfyUI")
-        from comfy.quant_ops import TensorCoreFP8E5M2Layout
         x = _make_mps_tensor((64, 64), dtype=torch.float32)
         qdata, params = TensorCoreFP8E5M2Layout.quantize(
             x, scale="recalculate", stochastic_rounding=42