diff --git a/comfy/model_management.py b/comfy/model_management.py index ce079cf2f..c997f4055 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -460,7 +460,7 @@ if cpu_state != CPUState.GPU: vram_state = VRAMState.DISABLED if cpu_state == CPUState.MPS: - vram_state = VRAMState.SHARED + vram_state = VRAMState.NORMAL_VRAM logging.info(f"Set vram state to: {vram_state.name}") @@ -900,7 +900,7 @@ def unet_inital_load_device(parameters, dtype): return cpu_dev torch_dev = get_torch_device() - if vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.SHARED: + if vram_state == VRAMState.HIGH_VRAM: return torch_dev if DISABLE_SMART_MEMORY or vram_state == VRAMState.NO_VRAM: @@ -1019,9 +1019,6 @@ def text_encoder_initial_device(load_device, offload_device, model_size=0): if load_device == offload_device or model_size <= 1024 * 1024 * 1024: return offload_device - if is_device_mps(load_device): - return load_device - mem_l = get_free_memory(load_device) mem_o = get_free_memory(offload_device) if mem_l > (mem_o * 0.5) and model_size * 1.2 < mem_l: