diff --git a/comfy/float.py b/comfy/float.py
index 184b3d6d0..3b2ac5895 100644
--- a/comfy/float.py
+++ b/comfy/float.py
@@ -55,6 +55,12 @@ def stochastic_rounding(value, dtype, seed=0):
     if dtype == torch.bfloat16:
         return value.to(dtype=torch.bfloat16)
     if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
+        # MPS workaround: if we're on MPS and trying to create float8, force CPU
+        if value.device.type == "mps":
+            cpu_value = value.to("cpu")
+            # Recursive call on CPU
+            return stochastic_rounding(cpu_value, dtype, seed)
+
         generator = torch.Generator(device=value.device)
         generator.manual_seed(seed)
         output = torch.empty_like(value, dtype=dtype)
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 2c250dacc..5ea48709f 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -52,11 +52,6 @@ cpu_state = CPUState.GPU
 
 total_vram = 0
 
-
-# Training Related State
-in_training = False
-
-
 def get_supported_float8_types():
     float8_types = []
     try:
@@ -1287,6 +1282,24 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
 
 
 def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None):
+    target_device = device
+    if target_device is not None and is_device_mps(target_device):
+        is_quantized = hasattr(weight, "storage_dtype")
+        real_dtype = weight.storage_dtype if is_quantized else weight.dtype
+        if real_dtype in FLOAT8_TYPES:
+            if is_quantized:
+                # MPS workaround: ensure everything is on CPU before dequantizing
+                # This prevents 'Trying to convert Float8_e4m3fn to the MPS backend' errors
+                # if the tensor or its scale are partially on MPS.
+                if weight.device.type != "cpu":
+                    weight = weight.to("cpu")
+                weight = weight.dequantize()
+            if weight.dtype in FLOAT8_TYPES:
+                weight = weight.float()
+            weight = weight.to(target_device, dtype=torch.float16)
+            if dtype is None or dtype in FLOAT8_TYPES:
+                dtype = torch.float16
+
     if device is None or weight.device == device:
         if not copy:
             if dtype is None or weight.dtype == dtype: