kiruno 2026-03-25 01:22:31 +08:00 committed by GitHub
commit dd05437a9c
2 changed files with 24 additions and 5 deletions


@@ -55,6 +55,12 @@ def stochastic_rounding(value, dtype, seed=0):
    if dtype == torch.bfloat16:
        return value.to(dtype=torch.bfloat16)
    if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
        # MPS workaround: if we're on MPS and trying to create float8, force CPU
        if value.device.type == "mps":
            cpu_value = value.to("cpu")
            # Recursive call on CPU
            return stochastic_rounding(cpu_value, dtype, seed)
        generator = torch.Generator(device=value.device)
        generator.manual_seed(seed)
        output = torch.empty_like(value, dtype=dtype)
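Note: the fallback above works because the blocker is the device, not the rounding. PyTorch's MPS backend cannot create or convert float8 tensors, so the function re-enters itself with a CPU copy and takes the float8 path there. A minimal standalone sketch of the failure and the fallback pattern (not the patched function itself):

import torch

x = torch.randn(4, device="mps" if torch.backends.mps.is_available() else "cpu")
try:
    # On MPS this raises: float8 is not implemented for that backend.
    y = x.to(dtype=torch.float8_e4m3fn)
except Exception as e:
    # Fallback: round-trip through the CPU, where float8 conversion works.
    print(f"float8 on MPS failed ({e}); retrying on CPU")
    y = x.to("cpu").to(dtype=torch.float8_e4m3fn)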


@@ -52,11 +52,6 @@ cpu_state = CPUState.GPU
total_vram = 0
# Training Related State
in_training = False
def get_supported_float8_types():
    float8_types = []
    try:
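The hunk is truncated inside get_supported_float8_types. For context, a helper like this usually guards each dtype lookup so older torch builds without float8 still import; a sketch under that assumption (not the file's actual body, which the diff cuts off):

import torch

def get_supported_float8_types():
    # Collect whichever float8 dtypes this torch build exposes; each lookup
    # is guarded because older versions predate float8 entirely.
    float8_types = []
    for name in ("float8_e4m3fn", "float8_e4m3fnuz", "float8_e5m2", "float8_e5m2fnuz"):
        try:
            float8_types.append(getattr(torch, name))
        except AttributeError:
            pass
    return float8_types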
@@ -1287,6 +1282,24 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None):
    target_device = device
    if target_device is not None and is_device_mps(target_device):
        is_quantized = hasattr(weight, "storage_dtype")
        real_dtype = weight.storage_dtype if is_quantized else weight.dtype
        if real_dtype in FLOAT8_TYPES:
            if is_quantized:
                # MPS workaround: ensure everything is on CPU before dequantizing.
                # This prevents 'Trying to convert Float8_e4m3fn to the MPS backend'
                # errors if the tensor or its scale are partially on MPS.
                if weight.device.type != "cpu":
                    weight = weight.to("cpu")
                weight = weight.dequantize()
            if weight.dtype in FLOAT8_TYPES:
                weight = weight.float()
            weight = weight.to(target_device, dtype=torch.float16)
            if dtype is None or dtype in FLOAT8_TYPES:
                dtype = torch.float16
    if device is None or weight.device == device:
        if not copy:
            if dtype is None or weight.dtype == dtype:
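Net effect of the new branch: any float8 weight bound for MPS is normalized to float16 before it touches the device, with dequantization pinned to the CPU. A minimal standalone sketch of that path for a plain (non-quantized) tensor, assuming only that MPS rejects float8 as the patch's comments state (float8_weight_to_mps is a hypothetical name; FLOAT8_TYPES is redefined locally so the snippet runs on its own):

import torch

FLOAT8_TYPES = {torch.float8_e4m3fn, torch.float8_e5m2}

def float8_weight_to_mps(weight):
    # Widen float8 on the CPU first, since MPS cannot hold float8 at all,
    # then move to the device as float16, mirroring the patch's MPS path.
    if weight.dtype in FLOAT8_TYPES:
        weight = weight.to("cpu").float()
    return weight.to("mps", dtype=torch.float16)

if torch.backends.mps.is_available():
    w8 = torch.randn(8).to(torch.float8_e4m3fn)   # float8 lives on CPU
    w16 = float8_weight_to_mps(w8)
    print(w16.dtype, w16.device)                  # torch.float16 mps:0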