mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-16 22:58:19 +08:00
fix: remove hardcoded local paths from MPS FP8 tests
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
edd44a6874
commit
45a2363e6a
@ -7,10 +7,12 @@ These tests verify that:
|
||||
2. The round-trip (quantize on CPU -> result on original device) is numerically sound.
|
||||
3. No "Placeholder storage has not been allocated on MPS device!" errors occur.
|
||||
"""
|
||||
import sys
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import comfy.float
|
||||
from comfy.quant_ops import TensorCoreFP8E4M3Layout, TensorCoreFP8E5M2Layout
|
||||
|
||||
# Skip the entire module if MPS is not available
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not torch.backends.mps.is_available(),
|
||||
@ -30,9 +32,6 @@ class TestStochasticRoundingMPS:
|
||||
|
||||
def test_stochastic_rounding_fp8_e4m3fn_on_mps(self):
|
||||
"""stochastic_rounding must not crash when input is on MPS and target dtype is float8_e4m3fn."""
|
||||
sys.path.insert(0, "/Users/tkhan/comfyui/ComfyUI")
|
||||
import comfy.float
|
||||
|
||||
x = _make_mps_tensor((64, 64), dtype=torch.float32)
|
||||
result = comfy.float.stochastic_rounding(x, dtype=torch.float8_e4m3fn, seed=42)
|
||||
|
||||
@ -41,9 +40,6 @@ class TestStochasticRoundingMPS:
|
||||
|
||||
def test_stochastic_rounding_fp8_e5m2_on_mps(self):
|
||||
"""stochastic_rounding must not crash when input is on MPS and target dtype is float8_e5m2."""
|
||||
sys.path.insert(0, "/Users/tkhan/comfyui/ComfyUI")
|
||||
import comfy.float
|
||||
|
||||
x = _make_mps_tensor((64, 64), dtype=torch.float32)
|
||||
result = comfy.float.stochastic_rounding(x, dtype=torch.float8_e5m2, seed=42)
|
||||
|
||||
@ -52,9 +48,6 @@ class TestStochasticRoundingMPS:
|
||||
|
||||
def test_stochastic_rounding_fp8_result_on_cpu(self):
|
||||
"""Result of FP8 rounding from MPS input should be on CPU (since MPS can't hold FP8)."""
|
||||
sys.path.insert(0, "/Users/tkhan/comfyui/ComfyUI")
|
||||
import comfy.float
|
||||
|
||||
x = _make_mps_tensor((32, 32), dtype=torch.float32)
|
||||
result = comfy.float.stochastic_rounding(x, dtype=torch.float8_e4m3fn, seed=42)
|
||||
|
||||
@ -63,9 +56,6 @@ class TestStochasticRoundingMPS:
|
||||
|
||||
def test_stochastic_rounding_non_fp8_still_works(self):
|
||||
"""Non-FP8 dtypes on MPS must still work as before (no regression)."""
|
||||
sys.path.insert(0, "/Users/tkhan/comfyui/ComfyUI")
|
||||
import comfy.float
|
||||
|
||||
x = _make_mps_tensor((32, 32), dtype=torch.float32)
|
||||
|
||||
r16 = comfy.float.stochastic_rounding(x, dtype=torch.float16, seed=0)
|
||||
@ -78,9 +68,6 @@ class TestStochasticRoundingMPS:
|
||||
|
||||
def test_stochastic_rounding_fp8_numerical_sanity(self):
|
||||
"""FP8 round-trip (float32 -> fp8 -> float32) should have bounded error."""
|
||||
sys.path.insert(0, "/Users/tkhan/comfyui/ComfyUI")
|
||||
import comfy.float
|
||||
|
||||
x = torch.randn(128, 128, device="mps", dtype=torch.float32)
|
||||
x_clamped = torch.clamp(x, min=-448, max=448) # FP8 e4m3fn range
|
||||
|
||||
@ -99,15 +86,9 @@ class TestStochasticRoundingMPS:
|
||||
class TestManualStochasticRoundMPS:
|
||||
"""Tests for comfy.float.manual_stochastic_round_to_float8 on MPS device."""
|
||||
|
||||
def test_manual_round_fp8_on_mps_tensor_fails_without_fix(self):
|
||||
"""Calling manual_stochastic_round_to_float8 with MPS generator should fail or be handled."""
|
||||
sys.path.insert(0, "/Users/tkhan/comfyui/ComfyUI")
|
||||
import comfy.float
|
||||
|
||||
def test_manual_round_fp8_on_mps_tensor(self):
|
||||
"""stochastic_rounding handles MPS generator internally without 'Placeholder storage' error."""
|
||||
x = _make_mps_tensor((16, 16), dtype=torch.float32)
|
||||
# The generator device matters - this is the root cause of the second error
|
||||
# (Placeholder storage has not been allocated on MPS device!)
|
||||
# After fix, stochastic_rounding should handle this internally
|
||||
result = comfy.float.stochastic_rounding(x, dtype=torch.float8_e4m3fn, seed=42)
|
||||
assert result.dtype == torch.float8_e4m3fn
|
||||
|
||||
@ -117,9 +98,6 @@ class TestNVFP4StochasticRoundMPS:
|
||||
|
||||
def test_nvfp4_stochastic_round_on_mps(self):
|
||||
"""stochastic_round_quantize_nvfp4 creates FP8 block scales internally."""
|
||||
sys.path.insert(0, "/Users/tkhan/comfyui/ComfyUI")
|
||||
import comfy.float
|
||||
|
||||
# NVFP4 requires 2D input with dimensions divisible by 16
|
||||
x = torch.randn(32, 32, device="mps", dtype=torch.float32)
|
||||
scale = torch.tensor(1.0, device="mps", dtype=torch.float32)
|
||||
@ -138,9 +116,6 @@ class TestQuantOpsMPS:
|
||||
|
||||
def test_fp8_layout_quantize_on_mps(self):
|
||||
"""TensorCoreFP8E4M3Layout.quantize must work with MPS tensors."""
|
||||
sys.path.insert(0, "/Users/tkhan/comfyui/ComfyUI")
|
||||
from comfy.quant_ops import TensorCoreFP8E4M3Layout
|
||||
|
||||
x = _make_mps_tensor((64, 64), dtype=torch.bfloat16)
|
||||
qdata, params = TensorCoreFP8E4M3Layout.quantize(
|
||||
x, scale="recalculate", stochastic_rounding=42
|
||||
@ -151,9 +126,6 @@ class TestQuantOpsMPS:
|
||||
|
||||
def test_fp8_layout_quantize_without_stochastic_on_mps(self):
|
||||
"""TensorCoreFP8E4M3Layout.quantize with stochastic_rounding=0 uses ck.quantize_per_tensor_fp8."""
|
||||
sys.path.insert(0, "/Users/tkhan/comfyui/ComfyUI")
|
||||
from comfy.quant_ops import TensorCoreFP8E4M3Layout
|
||||
|
||||
x = _make_mps_tensor((64, 64), dtype=torch.bfloat16)
|
||||
qdata, params = TensorCoreFP8E4M3Layout.quantize(
|
||||
x, scale="recalculate", stochastic_rounding=0
|
||||
@ -163,9 +135,6 @@ class TestQuantOpsMPS:
|
||||
|
||||
def test_fp8_e5m2_layout_quantize_on_mps(self):
|
||||
"""TensorCoreFP8E5M2Layout.quantize must work with MPS tensors."""
|
||||
sys.path.insert(0, "/Users/tkhan/comfyui/ComfyUI")
|
||||
from comfy.quant_ops import TensorCoreFP8E5M2Layout
|
||||
|
||||
x = _make_mps_tensor((64, 64), dtype=torch.float32)
|
||||
qdata, params = TensorCoreFP8E5M2Layout.quantize(
|
||||
x, scale="recalculate", stochastic_rounding=42
|
||||
|
||||
Loading…
Reference in New Issue
Block a user