Use NORMAL_VRAM instead of SHARED for MPS memory management

On Apple Silicon, SHARED vram state is treated like HIGH_VRAM in
several code paths, which keeps all models loaded on the GPU device.
Since MPS uses unified memory (GPU memory IS system memory), this
starves the system of memory for inference activations when running
large models (e.g. Wan 2.2 14B I2V loads a 13.6GB model plus a 21.6GB
text encoder, ~35GB total, leaving only ~13GB for activations on a
48GB system).
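The headroom arithmetic above can be checked directly. This is only the back-of-envelope sum using the figures cited in this message; real usage varies with dtype, quantization, and allocator overhead:

```python
# Rough unified-memory budget on a 48 GB Apple Silicon machine, using
# the figures from the commit message (not measured values).
total_gb = 48.0
model_gb = 13.6          # Wan 2.2 14B I2V diffusion model weights
text_encoder_gb = 21.6   # text encoder weights
loaded_gb = model_gb + text_encoder_gb
headroom_gb = total_gb - loaded_gb
print(f"loaded: {loaded_gb:.1f} GB, headroom: {headroom_gb:.1f} GB")
# → loaded: 35.2 GB, headroom: 12.8 GB
```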

This caused GPU hangs and kernel panics on macOS when generating video
with 14B parameter models.

Changes:
- Set MPS vram_state to NORMAL_VRAM instead of SHARED, enabling smart
  memory management to offload unused models during inference
- Remove SHARED from unet_inital_load_device HIGH_VRAM fast path, so
  model loading respects available memory on MPS
- Remove MPS shortcut in text_encoder_initial_device that forced the
  text encoder to stay on GPU regardless of memory pressure

On unified memory, "offloading to CPU" is nearly free (same physical
memory pool, just a device pointer change), so NORMAL_VRAM behavior
is strictly better — it lets ComfyUI intelligently manage what stays
on the GPU device based on available memory.
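With the MPS shortcut gone, the text encoder falls through to the same free-memory check as other platforms. A minimal sketch of that heuristic, simplified from `text_encoder_initial_device` in the diff below (the standalone function and string return values are illustrative, not ComfyUI's API):

```python
GiB = 1024 ** 3

def text_encoder_placement(load_free, offload_free, model_size):
    """Simplified sketch of the post-patch placement check: the text
    encoder starts on the GPU (load) device only when that device has
    enough free memory, both relative to the offload device and with
    20% headroom over the model size."""
    if model_size <= 1 * GiB:   # small encoders stay on the offload device
        return "offload_device"
    if load_free > offload_free * 0.5 and model_size * 1.2 < load_free:
        return "load_device"
    return "offload_device"

# Plenty of free GPU memory: a 21.6 GB encoder loads on the GPU.
print(text_encoder_placement(40 * GiB, 8 * GiB, int(21.6 * GiB)))
# → load_device
# Under memory pressure it starts on the offload device instead.
print(text_encoder_placement(10 * GiB, 30 * GiB, int(21.6 * GiB)))
# → offload_device
```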

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Tashdid Khan 2026-02-09 23:50:46 -05:00
parent a0302cc6a8
commit e1ec301db3

@@ -432,7 +432,7 @@ if cpu_state != CPUState.GPU:
     vram_state = VRAMState.DISABLED
 if cpu_state == CPUState.MPS:
-    vram_state = VRAMState.SHARED
+    vram_state = VRAMState.NORMAL_VRAM
 logging.info(f"Set vram state to: {vram_state.name}")
@@ -840,7 +840,7 @@ def unet_offload_device():
 def unet_inital_load_device(parameters, dtype):
     torch_dev = get_torch_device()
-    if vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.SHARED:
+    if vram_state == VRAMState.HIGH_VRAM:
         return torch_dev
     cpu_dev = torch.device("cpu")
@@ -957,9 +957,6 @@ def text_encoder_initial_device(load_device, offload_device, model_size=0):
     if load_device == offload_device or model_size <= 1024 * 1024 * 1024:
         return offload_device
-    if is_device_mps(load_device):
-        return load_device
     mem_l = get_free_memory(load_device)
     mem_o = get_free_memory(offload_device)
     if mem_l > (mem_o * 0.5) and model_size * 1.2 < mem_l: