diff --git a/comfy/float.py b/comfy/float.py
index 88c47cd80..841b50803 100644
--- a/comfy/float.py
+++ b/comfy/float.py
@@ -55,6 +55,12 @@ def stochastic_rounding(value, dtype, seed=0):
     if dtype == torch.bfloat16:
         return value.to(dtype=torch.bfloat16)
     if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
+        # MPS workaround: if we're on MPS and trying to create float8, force CPU
+        if value.device.type == "mps":
+            cpu_value = value.to("cpu")
+            # Recursive call on CPU
+            return stochastic_rounding(cpu_value, dtype, seed)
+
         generator = torch.Generator(device=value.device)
         generator.manual_seed(seed)
         output = torch.empty_like(value, dtype=dtype)
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 304931eb0..270c07791 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -55,11 +55,6 @@
 cpu_state = CPUState.GPU
 
 total_vram = 0
-
-# Training Related State
-in_training = False
-
-
 def get_supported_float8_types():
     float8_types = []
     try:
@@ -1233,6 +1228,23 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
 
         return r
 
+    target_device = device
+    if target_device is not None and is_device_mps(target_device):
+        is_quantized = hasattr(weight, "storage_dtype")
+        real_dtype = weight.storage_dtype if is_quantized else weight.dtype
+        if real_dtype in FLOAT8_TYPES:
+            if is_quantized:
+                # MPS workaround: ensure everything is on CPU before dequantizing
+                # This prevents 'Trying to convert Float8_e4m3fn to the MPS backend' errors
+                # if the tensor or its scale are partially on MPS.
+                if weight.device.type != "cpu":
+                    weight = weight.to("cpu")
+                weight = weight.dequantize()
+            if weight.dtype in FLOAT8_TYPES:
+                weight = weight.float()
+            weight = weight.to(target_device, dtype=torch.float16)
+            dtype = torch.float16 if dtype is None else dtype
+
     if device is None or weight.device == device:
         if not copy:
             if dtype is None or weight.dtype == dtype:
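
The float.py change means callers no longer need to guard stochastic_rounding
themselves: a float8 request on an MPS tensor is transparently redirected to the
CPU, and per the recursive call, the returned tensor lives on the CPU. A minimal
sketch of how a caller would exercise it, assuming the ComfyUI repo is on the
path, MPS is available, and the PyTorch build has float8 dtypes (2.1+):

    import torch
    from comfy.float import stochastic_rounding

    weight = torch.randn(4, 4, device="mps")  # float32 on MPS
    # The new branch moves the tensor to the CPU before float8 conversion,
    # since MPS has no float8 kernels; the result stays on the CPU.
    rounded = stochastic_rounding(weight, torch.float8_e4m3fn, seed=42)
    assert rounded.device.type == "cpu"
    assert rounded.dtype == torch.float8_e4m3fn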
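
In cast_to, the new branch rewrites any float8 move onto MPS as "dequantize or
widen on the CPU, then upload as float16", because the MPS backend rejects
float8 tensors outright. A rough standalone sketch of that data path (the
FLOAT8_TYPES name comes from the diff; the helper below is illustrative, not
ComfyUI API):

    import torch

    FLOAT8_TYPES = (torch.float8_e4m3fn, torch.float8_e5m2)

    def float8_weight_to_mps(weight: torch.Tensor) -> torch.Tensor:
        # Widen on the CPU first: the dtype change must happen before the
        # device transfer, since MPS cannot hold float8 tensors at all.
        if weight.device.type != "cpu":
            weight = weight.to("cpu")
        if weight.dtype in FLOAT8_TYPES:
            weight = weight.float()
        # float16 is the narrowest dtype MPS supports here, keeping memory
        # cost close to the original float8 intent.
        return weight.to("mps", dtype=torch.float16)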