mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-31 05:53:42 +08:00
Use NORMAL_VRAM instead of SHARED for MPS memory management
On Apple Silicon, SHARED vram state is treated like HIGH_VRAM in several code paths, which keeps all models loaded on the GPU device. Since MPS uses unified memory (GPU memory IS system memory), this starves the system of memory for inference activations when running large models (e.g. Wan 2.2 14B I2V loads 13.6GB model + 21.6GB text encoder = 35GB, leaving only 13GB for activations on a 48GB system). This caused GPU hangs and kernel panics on macOS when generating video with 14B parameter models. Changes: - Set MPS vram_state to NORMAL_VRAM instead of SHARED, enabling smart memory management to offload unused models during inference - Remove SHARED from unet_inital_load_device HIGH_VRAM fast path, so model loading respects available memory on MPS - Remove MPS shortcut in text_encoder_initial_device that forced the text encoder to stay on GPU regardless of memory pressure On unified memory, "offloading to CPU" is nearly free (same physical memory pool, just a device pointer change), so NORMAL_VRAM behavior is strictly better — it lets ComfyUI intelligently manage what stays on the GPU device based on available memory. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
a0302cc6a8
commit
e1ec301db3
@ -432,7 +432,7 @@ if cpu_state != CPUState.GPU:
|
||||
vram_state = VRAMState.DISABLED
|
||||
|
||||
if cpu_state == CPUState.MPS:
|
||||
vram_state = VRAMState.SHARED
|
||||
vram_state = VRAMState.NORMAL_VRAM
|
||||
|
||||
logging.info(f"Set vram state to: {vram_state.name}")
|
||||
|
||||
@ -840,7 +840,7 @@ def unet_offload_device():
|
||||
|
||||
def unet_inital_load_device(parameters, dtype):
|
||||
torch_dev = get_torch_device()
|
||||
if vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.SHARED:
|
||||
if vram_state == VRAMState.HIGH_VRAM:
|
||||
return torch_dev
|
||||
|
||||
cpu_dev = torch.device("cpu")
|
||||
@ -957,9 +957,6 @@ def text_encoder_initial_device(load_device, offload_device, model_size=0):
|
||||
if load_device == offload_device or model_size <= 1024 * 1024 * 1024:
|
||||
return offload_device
|
||||
|
||||
if is_device_mps(load_device):
|
||||
return load_device
|
||||
|
||||
mem_l = get_free_memory(load_device)
|
||||
mem_o = get_free_memory(offload_device)
|
||||
if mem_l > (mem_o * 0.5) and model_size * 1.2 < mem_l:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user