replaced vram_state = VRAMState.SHARED condition with is_device_mps()

This commit is contained in:
fromfirstbyte 2026-06-25 09:00:24 +03:00
parent abbe30c49e
commit 489fc75791

View File

@ -938,7 +938,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
if lowvram_model_memory == 0:
lowvram_model_memory = 0.1
if vram_set_state == VRAMState.NO_VRAM or (set_vram_to == VRAMState.LOW_VRAM and vram_state == VRAMState.SHARED and not force_full_load):
if vram_set_state == VRAMState.NO_VRAM or (set_vram_to == VRAMState.LOW_VRAM and is_device_mps(torch_dev) and not force_full_load):
lowvram_model_memory = 0.1
loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)