diff --git a/comfy/model_management.py b/comfy/model_management.py index 55283f1de..5c3033127 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -49,7 +49,7 @@ if args.deterministic: directml_enabled = False if args.directml is not None: - import torch_directml + import torch_directml # pylint: disable=import-error directml_enabled = True device_index = args.directml @@ -62,7 +62,7 @@ if args.directml is not None: lowvram_available = False # TODO: need to find a way to get free memory in directml before this can be enabled by default. try: - import intel_extension_for_pytorch as ipex + import intel_extension_for_pytorch as ipex # pylint: disable=import-error if torch.xpu.is_available(): xpu_available = True @@ -123,10 +123,8 @@ def get_total_memory(dev=None, torch_total_too=False): mem_total = 1024 * 1024 * 1024 # TODO mem_total_torch = mem_total elif is_intel_xpu(): - stats = torch.xpu.memory_stats(dev) - mem_reserved = stats['reserved_bytes.all.current'] - mem_total_torch = mem_reserved mem_total = torch.xpu.get_device_properties(dev).total_memory + mem_total_torch = mem_total else: stats = torch.cuda.memory_stats(dev) mem_reserved = stats['reserved_bytes.all.current'] @@ -162,8 +160,7 @@ if args.disable_xformers: XFORMERS_IS_AVAILABLE = False else: try: - import xformers - import xformers.ops + import xformers # pylint: disable=import-error XFORMERS_IS_AVAILABLE = True try: @@ -821,12 +818,8 @@ def get_free_memory(dev=None, torch_free_too=False): mem_free_total = 1024 * 1024 * 1024 # TODO mem_free_torch = mem_free_total elif is_intel_xpu(): - stats = torch.xpu.memory_stats(dev) - mem_active = stats['active_bytes.all.current'] - mem_reserved = stats['reserved_bytes.all.current'] - mem_free_torch = mem_reserved - mem_active - mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved - mem_free_total = mem_free_xpu + mem_free_torch + mem_free_total = torch.xpu.get_device_properties(dev).total_memory + mem_free_torch = mem_free_total else: stats = torch.cuda.memory_stats(dev) mem_active = stats['active_bytes.all.current']