mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 04:52:31 +08:00
fix: MPS memory reporting and cache synchronization
- Split MPS from CPU in get_total_memory() and get_free_memory() to use torch.mps APIs (recommended_max_memory, driver_allocated_memory, current_allocated_memory) instead of relying solely on psutil
- Add torch.mps.synchronize() to synchronize() and soft_empty_cache()
- Add inter-node MPS cache flush when torch free memory is low
- Move MPS SHARED vram assignment before --lowvram/--novram checks so users can override it for large models
This commit is contained in:
parent
f21f6b2212
commit
33b967d30a
@ -211,9 +211,15 @@ def get_total_memory(dev=None, torch_total_too=False):
|
||||
if dev is None:
|
||||
dev = get_torch_device()
|
||||
|
||||
if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
|
||||
if hasattr(dev, 'type') and dev.type == 'cpu':
|
||||
mem_total = psutil.virtual_memory().total
|
||||
mem_total_torch = mem_total
|
||||
elif hasattr(dev, 'type') and dev.type == 'mps':
|
||||
mem_total = psutil.virtual_memory().total
|
||||
try:
|
||||
mem_total_torch = torch.mps.recommended_max_memory()
|
||||
except Exception:
|
||||
mem_total_torch = mem_total
|
||||
else:
|
||||
if directml_enabled:
|
||||
mem_total = 1024 * 1024 * 1024 #TODO
|
||||
@ -460,7 +466,10 @@ if cpu_state != CPUState.GPU:
|
||||
vram_state = VRAMState.DISABLED
|
||||
|
||||
if cpu_state == CPUState.MPS:
|
||||
vram_state = VRAMState.SHARED
|
||||
if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
|
||||
vram_state = set_vram_to
|
||||
else:
|
||||
vram_state = VRAMState.SHARED
|
||||
|
||||
logging.info(f"Set vram state to: {vram_state.name}")
|
||||
|
||||
@ -1485,9 +1494,18 @@ def get_free_memory(dev=None, torch_free_too=False):
|
||||
if dev is None:
|
||||
dev = get_torch_device()
|
||||
|
||||
if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
|
||||
if hasattr(dev, 'type') and dev.type == 'cpu':
|
||||
mem_free_total = psutil.virtual_memory().available
|
||||
mem_free_torch = mem_free_total
|
||||
elif hasattr(dev, 'type') and dev.type == 'mps':
|
||||
try:
|
||||
driver_mem = torch.mps.driver_allocated_memory()
|
||||
current_mem = torch.mps.current_allocated_memory()
|
||||
mem_free_torch = max(driver_mem - current_mem, 0)
|
||||
mem_free_total = psutil.virtual_memory().available + mem_free_torch
|
||||
except Exception:
|
||||
mem_free_total = psutil.virtual_memory().available
|
||||
mem_free_torch = 0
|
||||
else:
|
||||
if directml_enabled:
|
||||
mem_free_total = 1024 * 1024 * 1024 #TODO
|
||||
@ -1756,7 +1774,9 @@ def lora_compute_dtype(device):
|
||||
def synchronize():
|
||||
if cpu_mode():
|
||||
return
|
||||
if is_intel_xpu():
|
||||
if mps_mode():
|
||||
torch.mps.synchronize()
|
||||
elif is_intel_xpu():
|
||||
torch.xpu.synchronize()
|
||||
elif torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
@ -1766,6 +1786,7 @@ def soft_empty_cache(force=False):
|
||||
return
|
||||
global cpu_state
|
||||
if cpu_state == CPUState.MPS:
|
||||
torch.mps.synchronize()
|
||||
torch.mps.empty_cache()
|
||||
elif is_intel_xpu():
|
||||
torch.xpu.empty_cache()
|
||||
|
||||
@ -780,6 +780,11 @@ class PromptExecutor:
|
||||
if self.cache_type == CacheType.RAM_PRESSURE:
|
||||
comfy.model_management.free_memory(0, None, pins_required=ram_headroom, ram_required=ram_headroom)
|
||||
comfy.memory_management.extra_ram_release(ram_headroom)
|
||||
elif comfy.model_management.mps_mode():
|
||||
mem_free_total, mem_free_torch = comfy.model_management.get_free_memory(
|
||||
comfy.model_management.get_torch_device(), torch_free_too=True)
|
||||
if mem_free_torch < mem_free_total * 0.25:
|
||||
comfy.model_management.soft_empty_cache()
|
||||
else:
|
||||
# Only execute when the while-loop ends without break
|
||||
# Send cached UI for intermediate output nodes that weren't executed
|
||||
|
||||
Loading…
Reference in New Issue
Block a user