mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 04:52:31 +08:00
fix: MPS synchronization and --lowvram/--novram support
- Add torch.mps.synchronize() to the synchronize() function, mirroring the existing CUDA and XPU paths.
- Add torch.mps.synchronize() before torch.mps.empty_cache() in soft_empty_cache() to ensure pending MPS operations complete before cached memory is released.
- Allow the --lowvram and --novram flags to take effect on MPS devices. Previously, MPS unconditionally set vram_state to SHARED regardless of user flags; it now respects set_vram_to when LOW_VRAM or NO_VRAM is requested, while still defaulting to SHARED otherwise.
This commit is contained in:
parent
33b967d30a
commit
209d00c3b2
@ -211,15 +211,9 @@ def get_total_memory(dev=None, torch_total_too=False):
|
||||
if dev is None:
|
||||
dev = get_torch_device()
|
||||
|
||||
if hasattr(dev, 'type') and dev.type == 'cpu':
|
||||
if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
|
||||
mem_total = psutil.virtual_memory().total
|
||||
mem_total_torch = mem_total
|
||||
elif hasattr(dev, 'type') and dev.type == 'mps':
|
||||
mem_total = psutil.virtual_memory().total
|
||||
try:
|
||||
mem_total_torch = torch.mps.recommended_max_memory()
|
||||
except Exception:
|
||||
mem_total_torch = mem_total
|
||||
else:
|
||||
if directml_enabled:
|
||||
mem_total = 1024 * 1024 * 1024 #TODO
|
||||
@ -1494,18 +1488,9 @@ def get_free_memory(dev=None, torch_free_too=False):
|
||||
if dev is None:
|
||||
dev = get_torch_device()
|
||||
|
||||
if hasattr(dev, 'type') and dev.type == 'cpu':
|
||||
if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
|
||||
mem_free_total = psutil.virtual_memory().available
|
||||
mem_free_torch = mem_free_total
|
||||
elif hasattr(dev, 'type') and dev.type == 'mps':
|
||||
try:
|
||||
driver_mem = torch.mps.driver_allocated_memory()
|
||||
current_mem = torch.mps.current_allocated_memory()
|
||||
mem_free_torch = max(driver_mem - current_mem, 0)
|
||||
mem_free_total = psutil.virtual_memory().available + mem_free_torch
|
||||
except Exception:
|
||||
mem_free_total = psutil.virtual_memory().available
|
||||
mem_free_torch = 0
|
||||
else:
|
||||
if directml_enabled:
|
||||
mem_free_total = 1024 * 1024 * 1024 #TODO
|
||||
|
||||
Loading…
Reference in New Issue
Block a user