fix: 添加MPS设备对float8数据类型的兼容性处理

在MPS设备上,PyTorch不支持直接创建或转换float8数据类型。这会导致在尝试将float8张量移至MPS设备时出现运行时错误。

- 在 `stochastic_rounding` 函数中添加递归回退到CPU的逻辑,当检测到MPS设备且目标类型为float8时。
- 在 `cast_to` 函数中添加针对MPS设备的特殊处理,确保float8张量在移动到MPS前先转换为float16。
- 测试了两个workflow:templates-6-key-frames和image_z_image_turbo

相比 https://github.com/Comfy-Org/ComfyUI/pull/12378 更具可维护性
This commit is contained in:
Kiruno-lz 2026-02-11 19:25:56 +08:00
parent 76a7fa96db
commit 1fc89ccdd9
2 changed files with 23 additions and 5 deletions

View File

@@ -55,6 +55,12 @@ def stochastic_rounding(value, dtype, seed=0):
    if dtype == torch.bfloat16:
        return value.to(dtype=torch.bfloat16)
    if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
        # MPS workaround: if we're on MPS and trying to create float8, force CPU
        if value.device.type == "mps":
            cpu_value = value.to("cpu")
            # Recursive call on CPU
            return stochastic_rounding(cpu_value, dtype, seed)
        generator = torch.Generator(device=value.device)
        generator.manual_seed(seed)
        output = torch.empty_like(value, dtype=dtype)

View File

@@ -55,11 +55,6 @@ cpu_state = CPUState.GPU
total_vram = 0
# Training Related State
in_training = False
def get_supported_float8_types():
    float8_types = []
    try:
@@ -1233,6 +1228,23 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
        return r
    target_device = device
    if target_device is not None and is_device_mps(target_device):
        is_quantized = hasattr(weight, "storage_dtype")
        real_dtype = weight.storage_dtype if is_quantized else weight.dtype
        if real_dtype in FLOAT8_TYPES:
            if is_quantized:
                # MPS workaround: ensure everything is on CPU before dequantizing
                # This prevents 'Trying to convert Float8_e4m3fn to the MPS backend' errors
                # if the tensor or its scale are partially on MPS.
                if weight.device.type != "cpu":
                    weight = weight.to("cpu")
                weight = weight.dequantize()
            if weight.dtype in FLOAT8_TYPES:
                weight = weight.float()
            weight = weight.to(target_device, dtype=torch.float16)
            dtype = torch.float16 if dtype is None else dtype
    if device is None or weight.device == device:
        if not copy:
            if dtype is None or weight.dtype == dtype: