From 1fc89ccdd9c521621c2ca2d21db95aa78dae7c01 Mon Sep 17 00:00:00 2001
From: Kiruno-lz
Date: Wed, 11 Feb 2026 19:25:56 +0800
Subject: [PATCH 1/2] fix: add MPS device compatibility handling for float8
 dtypes
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

On MPS devices, PyTorch does not support directly creating or converting
float8 dtypes, so moving a float8 tensor to an MPS device raises a
runtime error.

- In `stochastic_rounding`, fall back to the CPU via a recursive call
  when the device is MPS and the target dtype is float8.
- In `cast_to`, add MPS-specific handling so float8 tensors are
  converted to float16 before being moved to MPS.
- Tested two workflows: templates-6-key-frames and image_z_image_turbo.

More maintainable than https://github.com/Comfy-Org/ComfyUI/pull/12378
---
 comfy/float.py            |  6 ++++++
 comfy/model_management.py | 22 +++++++++++++++++-----
 2 files changed, 23 insertions(+), 5 deletions(-)

diff --git a/comfy/float.py b/comfy/float.py
index 88c47cd80..841b50803 100644
--- a/comfy/float.py
+++ b/comfy/float.py
@@ -55,6 +55,12 @@ def stochastic_rounding(value, dtype, seed=0):
     if dtype == torch.bfloat16:
         return value.to(dtype=torch.bfloat16)
     if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
+        # MPS workaround: if we're on MPS and trying to create float8, force CPU
+        if value.device.type == "mps":
+            cpu_value = value.to("cpu")
+            # Recursive call on CPU
+            return stochastic_rounding(cpu_value, dtype, seed)
+
         generator = torch.Generator(device=value.device)
         generator.manual_seed(seed)
         output = torch.empty_like(value, dtype=dtype)
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 304931eb0..270c07791 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -55,11 +55,6 @@
 cpu_state = CPUState.GPU
 
 total_vram = 0
-
-# Training Related State
-in_training = False
-
-
 def get_supported_float8_types():
     float8_types = []
     try:
@@ -1233,6 +1228,23 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
         return r
 
 
+    target_device = device
+    if target_device is not None and is_device_mps(target_device):
+        is_quantized = hasattr(weight, "storage_dtype")
+        real_dtype = weight.storage_dtype if is_quantized else weight.dtype
+        if real_dtype in FLOAT8_TYPES:
+            if is_quantized:
+                # MPS workaround: ensure everything is on CPU before dequantizing
+                # This prevents 'Trying to convert Float8_e4m3fn to the MPS backend' errors
+                # if the tensor or its scale are partially on MPS.
+                if weight.device.type != "cpu":
+                    weight = weight.to("cpu")
+                weight = weight.dequantize()
+            if weight.dtype in FLOAT8_TYPES:
+                weight = weight.float()
+            weight = weight.to(target_device, dtype=torch.float16)
+            dtype = torch.float16 if dtype is None else dtype
+
     if device is None or weight.device == device:
         if not copy:
             if dtype is None or weight.dtype == dtype:
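For reference, a minimal standalone sketch of the CPU-fallback pattern
PATCH 1/2 adds to `stochastic_rounding`, assuming a PyTorch build that
exposes the float8 dtypes. The name `round_to_float8` and the plain cast
are illustrative stand-ins, not ComfyUI's actual stochastic-rounding
kernel:

import torch

FLOAT8_DTYPES = (torch.float8_e4m3fn, torch.float8_e5m2)

def round_to_float8(value: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # Illustrative stand-in for comfy.float.stochastic_rounding: a plain
    # cast wrapped in the same MPS-to-CPU fallback the patch introduces.
    if dtype in FLOAT8_DTYPES and value.device.type == "mps":
        # MPS cannot allocate float8 storage; recurse on a CPU copy and
        # return the CPU-resident result, exactly as the patch does.
        return round_to_float8(value.to("cpu"), dtype)
    return value.to(dtype=dtype)

Under this sketch, round_to_float8(torch.randn(4, device="mps"),
torch.float8_e4m3fn) returns a CPU float8 tensor instead of raising the
MPS backend conversion error.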
From b7cbe8afa528b3c48238b5d8063ece7ff863eb22 Mon Sep 17 00:00:00 2001
From: Kiruno-lz
Date: Fri, 6 Mar 2026 16:17:21 +0800
Subject: [PATCH 2/2] fix: in `cast_to`, normalize a requested dtype of
 `None` or float8 to `float16`, preventing type mismatches when casting
 float8 weights

---
 comfy/model_management.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index 9145d06b1..de73d04ea 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -1221,7 +1221,8 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
             if weight.dtype in FLOAT8_TYPES:
                 weight = weight.float()
             weight = weight.to(target_device, dtype=torch.float16)
-            dtype = torch.float16 if dtype is None else dtype
+            if dtype is None or dtype in FLOAT8_TYPES:
+                dtype = torch.float16
 
     if device is None or weight.device == device:
         if not copy:
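A minimal sketch of the dtype rule PATCH 2/2 settles on, assuming
FLOAT8_TYPES is equivalent to the set built by
get_supported_float8_types(); the function name normalize_mps_cast_dtype
is illustrative:

import torch

# Assumed equivalent of the FLOAT8_TYPES set built by
# get_supported_float8_types(); getattr guards older PyTorch builds.
FLOAT8_TYPES = {t for t in (getattr(torch, "float8_e4m3fn", None),
                            getattr(torch, "float8_e5m2", None))
                if t is not None}

def normalize_mps_cast_dtype(dtype):
    # Once a float8 weight has been upcast to float16 for MPS, a requested
    # dtype of None *or* float8 must also collapse to float16; otherwise
    # the later weight.to(dtype) call would try to recreate float8 storage
    # on the MPS backend and raise at runtime.
    if dtype is None or dtype in FLOAT8_TYPES:
        return torch.float16
    return dtype

This is why the PATCH 1/2 version of the line
(`dtype = torch.float16 if dtype is None else dtype`) was insufficient:
a caller passing a float8 dtype explicitly would still hit the MPS
conversion error.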