Fix errors in model_management.py

This commit is contained in:
Max Tretikov 2024-06-14 00:17:47 -06:00
parent 14da37cdf0
commit a919272e3b

View File

@ -49,7 +49,7 @@ if args.deterministic:
directml_enabled = False
if args.directml is not None:
import torch_directml
import torch_directml # pylint: disable=import-error
directml_enabled = True
device_index = args.directml
@ -62,7 +62,7 @@ if args.directml is not None:
lowvram_available = False # TODO: need to find a way to get free memory in directml before this can be enabled by default.
try:
import intel_extension_for_pytorch as ipex
import intel_extension_for_pytorch as ipex # pylint: disable=import-error
if torch.xpu.is_available():
xpu_available = True
@ -123,10 +123,8 @@ def get_total_memory(dev=None, torch_total_too=False):
mem_total = 1024 * 1024 * 1024 # TODO
mem_total_torch = mem_total
elif is_intel_xpu():
stats = torch.xpu.memory_stats(dev)
mem_reserved = stats['reserved_bytes.all.current']
mem_total_torch = mem_reserved
mem_total = torch.xpu.get_device_properties(dev).total_memory
mem_total_torch = mem_total
else:
stats = torch.cuda.memory_stats(dev)
mem_reserved = stats['reserved_bytes.all.current']
@ -162,8 +160,7 @@ if args.disable_xformers:
XFORMERS_IS_AVAILABLE = False
else:
try:
import xformers
import xformers.ops
import xformers # pylint: disable=import-error
XFORMERS_IS_AVAILABLE = True
try:
@ -821,12 +818,8 @@ def get_free_memory(dev=None, torch_free_too=False):
mem_free_total = 1024 * 1024 * 1024 # TODO
mem_free_torch = mem_free_total
elif is_intel_xpu():
stats = torch.xpu.memory_stats(dev)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_torch = mem_reserved - mem_active
mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved
mem_free_total = mem_free_xpu + mem_free_torch
mem_free_total = torch.xpu.get_device_properties(dev).total_memory
mem_free_torch = mem_free_total
else:
stats = torch.cuda.memory_stats(dev)
mem_active = stats['active_bytes.all.current']