diff --git a/.github/workflows/test-launch.yml b/.github/workflows/test-launch.yml
index 5d665d6af..c56283c2d 100644
--- a/.github/workflows/test-launch.yml
+++ b/.github/workflows/test-launch.yml
@@ -17,7 +17,7 @@ jobs:
           path: "ComfyUI"
       - uses: actions/setup-python@v4
         with:
-          python-version: '3.8'
+          python-version: '3.9'
       - name: Install requirements
         run: |
           python -m pip install --upgrade pip
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 1f16dc8cc..ee39dc768 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -86,6 +86,13 @@ try:
 except:
     pass
 
+try:
+    import torch_npu  # noqa: F401
+    _ = torch.npu.device_count()
+    npu_available = torch.npu.is_available()
+except:
+    npu_available = False
+
 if args.cpu:
     cpu_state = CPUState.CPU
 
@@ -97,6 +104,12 @@ def is_intel_xpu():
         return True
     return False
 
+def is_ascend_npu():
+    global npu_available
+    if npu_available:
+        return True
+    return False
+
 def get_torch_device():
     global directml_enabled
     global cpu_state
@@ -110,6 +123,8 @@ def get_torch_device():
     else:
         if is_intel_xpu():
             return torch.device("xpu", torch.xpu.current_device())
+        elif is_ascend_npu():
+            return torch.device("npu", torch.npu.current_device())
         else:
             return torch.device(torch.cuda.current_device())
 
@@ -130,6 +145,12 @@ def get_total_memory(dev=None, torch_total_too=False):
             mem_reserved = stats['reserved_bytes.all.current']
             mem_total_torch = mem_reserved
             mem_total = torch.xpu.get_device_properties(dev).total_memory
+        elif is_ascend_npu():
+            stats = torch.npu.memory_stats(dev)
+            mem_reserved = stats['reserved_bytes.all.current']
+            _, mem_total_npu = torch.npu.mem_get_info(dev)
+            mem_total_torch = mem_reserved
+            mem_total = mem_total_npu
         else:
             stats = torch.cuda.memory_stats(dev)
             mem_reserved = stats['reserved_bytes.all.current']
@@ -209,7 +230,7 @@ try:
     if int(torch_version[0]) >= 2:
         if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
             ENABLE_PYTORCH_ATTENTION = True
-    if is_intel_xpu():
+    if is_intel_xpu() or is_ascend_npu():
         if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
             ENABLE_PYTORCH_ATTENTION = True
 except:
@@ -274,6 +295,8 @@ def get_torch_device_name(device):
             return "{}".format(device.type)
     elif is_intel_xpu():
         return "{} {}".format(device, torch.xpu.get_device_name(device))
+    elif is_ascend_npu():
+        return "{} {}".format(device, torch.npu.get_device_name(device))
     else:
         return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
 
@@ -873,6 +896,8 @@ def xformers_enabled():
         return False
     if is_intel_xpu():
         return False
+    if is_ascend_npu():
+        return False
     if directml_enabled:
         return False
     return XFORMERS_IS_AVAILABLE
@@ -897,6 +922,8 @@ def pytorch_attention_flash_attention():
             return True
         if is_intel_xpu():
             return True
+        if is_ascend_npu():
+            return True
     return False
 
 def mac_version():
@@ -936,6 +963,13 @@ def get_free_memory(dev=None, torch_free_too=False):
             mem_free_torch = mem_reserved - mem_active
             mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved
             mem_free_total = mem_free_xpu + mem_free_torch
+        elif is_ascend_npu():
+            stats = torch.npu.memory_stats(dev)
+            mem_active = stats['active_bytes.all.current']
+            mem_reserved = stats['reserved_bytes.all.current']
+            mem_free_npu, _ = torch.npu.mem_get_info(dev)
+            mem_free_torch = mem_reserved - mem_active
+            mem_free_total = mem_free_npu + mem_free_torch
         else:
             stats = torch.cuda.memory_stats(dev)
             mem_active = stats['active_bytes.all.current']
@@ -997,6 +1031,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
     if is_intel_xpu():
         return True
 
+    if is_ascend_npu():
+        return True
+
     if torch.version.hip:
         return True
 
@@ -1094,6 +1131,8 @@ def soft_empty_cache(force=False):
         torch.mps.empty_cache()
     elif is_intel_xpu():
         torch.xpu.empty_cache()
+    elif is_ascend_npu():
+        torch.npu.empty_cache()
     elif torch.cuda.is_available():
         if force or is_nvidia(): #This seems to make things worse on ROCm so I only do it for cuda
             torch.cuda.empty_cache()
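
For reviewers without Ascend hardware, here is a minimal standalone probe mirroring the detection pattern this patch adds. It is a sketch, not part of the patch: the script name and printouts are illustrative, and it assumes the torch_npu plugin, which registers the torch.npu namespace as a side effect of import.

    # check_npu.py -- illustrative probe, not part of this patch
    import torch

    try:
        import torch_npu  # noqa: F401  # side effect: registers torch.npu
        _ = torch.npu.device_count()    # raises if the runtime is unusable
        npu_available = torch.npu.is_available()
    except Exception:                   # package missing or runtime broken
        npu_available = False

    if npu_available:
        dev = torch.device("npu", torch.npu.current_device())
        print("Ascend NPU:", torch.npu.get_device_name(dev))
        free, total = torch.npu.mem_get_info(dev)  # (free, total) in bytes
        print("free: %.1f GiB of %.1f GiB" % (free / 2**30, total / 2**30))
    else:
        print("No Ascend NPU detected; ComfyUI falls back to its CUDA/XPU/CPU paths.")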
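
The new is_ascend_npu() branch in get_free_memory() uses the same accounting as the existing CUDA and XPU branches: device-level free memory from mem_get_info() plus whatever the caching allocator has reserved but is not actively using, since that reservation can be reclaimed for new tensors. A sketch of the arithmetic with made-up numbers:

    # Illustrative values only; the real ones come from torch.npu.memory_stats()
    # and torch.npu.mem_get_info() as in the hunk above.
    mem_active   = 2 * 2**30   # bytes held by live tensors
    mem_reserved = 5 * 2**30   # bytes held by the caching allocator
    mem_free_npu = 10 * 2**30  # bytes free at the device level

    mem_free_torch = mem_reserved - mem_active      # reclaimable from the cache
    mem_free_total = mem_free_npu + mem_free_torch  # usable by a newly loaded model
    assert mem_free_total == 13 * 2**30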