fix: run text encoders on MPS GPU instead of CPU for Apple Silicon

On Apple Silicon, `vram_state` is set to `VRAMState.SHARED` because
CPU and GPU share unified memory. However, `text_encoder_device()`
only checked for `HIGH_VRAM` and `NORMAL_VRAM`, causing all text
encoders to fall back to CPU on MPS devices.
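
For context, a condensed sketch of how `vram_state` ends up as `SHARED`
(the enum mirrors `comfy/model_management.py`, but the branching here is
heavily simplified and omits the DirectML/XPU/--cpu paths):

    # Condensed sketch, not the verbatim module code.
    import torch
    from enum import Enum

    class VRAMState(Enum):
        DISABLED = 0
        NO_VRAM = 1
        LOW_VRAM = 2
        NORMAL_VRAM = 3
        HIGH_VRAM = 4
        SHARED = 5  # CPU and GPU share one memory pool (unified memory)

    vram_state = VRAMState.NORMAL_VRAM
    if torch.backends.mps.is_available():
        # Apple Silicon has no dedicated VRAM, only unified memory.
        vram_state = VRAMState.SHARED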

Adding `VRAMState.SHARED` to the condition allows non-quantized text
encoders (e.g. bf16 Gemma 3 12B) to run on the MPS GPU, providing a
significant speedup for text encoding and prompt generation.

Note: quantized models (fp4/fp8) that use float8_e4m3fn internally
will still fall back to CPU via the `supports_cast()` check in
`CLIP.__init__()`, since MPS does not support fp8 dtypes.
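
Illustrative probe (not part of this change) showing why fp8 weights
cannot live on MPS; assumes a PyTorch 2.x build, where float8_e4m3fn
exists on CPU but has no MPS kernels:

    import torch

    # fp8 tensors can be created on CPU...
    torch.zeros(4, dtype=torch.float8_e4m3fn, device="cpu")

    # ...but MPS rejects the dtype, which is what supports_cast()
    # detects and why quantized encoders stay on CPU.
    try:
        torch.zeros(4, dtype=torch.float8_e4m3fn, device="mps")
    except (TypeError, RuntimeError) as e:
        print(f"MPS rejects fp8: {e}")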

Made-with: Cursor
Author: Anton Bukov
Date:   2026-03-07 01:16:49 +04:00
Parent: 34e55f0061
Commit: 65202e091e


@@ -939,7 +939,7 @@ def text_encoder_offload_device():
 def text_encoder_device():
     if args.gpu_only:
         return get_torch_device()
-    elif vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM:
+    elif vram_state in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM, VRAMState.SHARED):
         if should_use_fp16(prioritize_performance=False):
             return get_torch_device()
         else:
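
Hypothetical post-patch check on an Apple Silicon machine launched with
default flags (no --gpu-only / --cpu); not part of this change:

    from comfy import model_management as mm

    # With SHARED now accepted, the text encoder device should be MPS.
    print(mm.text_encoder_device())  # expected: device(type='mps')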