Fix errors in model_management.py

This commit is contained in:
Max Tretikov 2024-06-14 00:17:47 -06:00
parent 14da37cdf0
commit a919272e3b

View File

@@ -49,7 +49,7 @@ if args.deterministic:
 directml_enabled = False
 if args.directml is not None:
-    import torch_directml
+    import torch_directml  # pylint: disable=import-error
     directml_enabled = True
     device_index = args.directml
@@ -62,7 +62,7 @@ if args.directml is not None:
 lowvram_available = False  # TODO: need to find a way to get free memory in directml before this can be enabled by default.
 try:
-    import intel_extension_for_pytorch as ipex
+    import intel_extension_for_pytorch as ipex  # pylint: disable=import-error
     if torch.xpu.is_available():
         xpu_available = True
@@ -123,10 +123,8 @@ def get_total_memory(dev=None, torch_total_too=False):
         mem_total = 1024 * 1024 * 1024  # TODO
         mem_total_torch = mem_total
     elif is_intel_xpu():
-        stats = torch.xpu.memory_stats(dev)
-        mem_reserved = stats['reserved_bytes.all.current']
-        mem_total_torch = mem_reserved
         mem_total = torch.xpu.get_device_properties(dev).total_memory
+        mem_total_torch = mem_total
     else:
         stats = torch.cuda.memory_stats(dev)
         mem_reserved = stats['reserved_bytes.all.current']
@@ -162,8 +160,7 @@ if args.disable_xformers:
     XFORMERS_IS_AVAILABLE = False
 else:
     try:
-        import xformers
-        import xformers.ops
+        import xformers  # pylint: disable=import-error
         XFORMERS_IS_AVAILABLE = True
         try:
@@ -821,12 +818,8 @@ def get_free_memory(dev=None, torch_free_too=False):
         mem_free_total = 1024 * 1024 * 1024  # TODO
         mem_free_torch = mem_free_total
     elif is_intel_xpu():
-        stats = torch.xpu.memory_stats(dev)
-        mem_active = stats['active_bytes.all.current']
-        mem_reserved = stats['reserved_bytes.all.current']
-        mem_free_torch = mem_reserved - mem_active
-        mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved
-        mem_free_total = mem_free_xpu + mem_free_torch
+        mem_free_total = torch.xpu.get_device_properties(dev).total_memory
+        mem_free_torch = mem_free_total
     else:
         stats = torch.cuda.memory_stats(dev)
         mem_active = stats['active_bytes.all.current']