From e1ec301db3a1b1ba36bab2a23eba074677ae30f9 Mon Sep 17 00:00:00 2001
From: Tashdid Khan
Date: Mon, 9 Feb 2026 23:50:46 -0500
Subject: [PATCH] Use NORMAL_VRAM instead of SHARED for MPS memory management
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

On Apple Silicon, SHARED vram state is treated like HIGH_VRAM in several
code paths, which keeps all models loaded on the GPU device. Since MPS uses
unified memory (GPU memory IS system memory), this starves the system of
memory for inference activations when running large models (e.g. Wan 2.2
14B I2V loads 13.6GB model + 21.6GB text encoder = 35GB, leaving only 13GB
for activations on a 48GB system).

This caused GPU hangs and kernel panics on macOS when generating video with
14B parameter models.

Changes:
- Set MPS vram_state to NORMAL_VRAM instead of SHARED, enabling smart
  memory management to offload unused models during inference
- Remove SHARED from unet_inital_load_device HIGH_VRAM fast path, so model
  loading respects available memory on MPS
- Remove MPS shortcut in text_encoder_initial_device that forced the text
  encoder to stay on GPU regardless of memory pressure

On unified memory, "offloading to CPU" is nearly free (same physical memory
pool, just a device pointer change), so NORMAL_VRAM behavior is strictly
better — it lets ComfyUI intelligently manage what stays on the GPU device
based on available memory.

Co-Authored-By: Claude Opus 4.6 --- comfy/model_management.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index b6291f340..a4a86b5a2 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -432,7 +432,7 @@ if cpu_state != CPUState.GPU: vram_state = VRAMState.DISABLED if cpu_state == CPUState.MPS: - vram_state = VRAMState.SHARED + vram_state = VRAMState.NORMAL_VRAM logging.info(f"Set vram state to: {vram_state.name}") @@ -840,7 +840,7 @@ def unet_offload_device(): def unet_inital_load_device(parameters, dtype): torch_dev = get_torch_device() - if vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.SHARED: + if vram_state == VRAMState.HIGH_VRAM: return torch_dev cpu_dev = torch.device("cpu") @@ -957,9 +957,6 @@ def text_encoder_initial_device(load_device, offload_device, model_size=0): if load_device == offload_device or model_size <= 1024 * 1024 * 1024: return offload_device - if is_device_mps(load_device): - return load_device - mem_l = get_free_memory(load_device) mem_o = get_free_memory(offload_device) if mem_l > (mem_o * 0.5) and model_size * 1.2 < mem_l: