From c34403b574ee398d778566d24266070fa3511155 Mon Sep 17 00:00:00 2001 From: doctorpangloss <@hiddenswitch.com> Date: Wed, 9 Oct 2024 11:21:19 -0700 Subject: [PATCH] Fix invalid device here --- comfy/model_management.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 0f43a9b80..a9468a5ba 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -37,6 +37,9 @@ from .model_management_types import ModelManageable model_management_lock = RLock() +# https://github.com/sayakpaul/diffusers-torchao/blob/bade7a6abb1cab9ef44782e6bcfab76d0237ae1f/inference/benchmark_image.py#L3 +# This setting optimizes performance on NVIDIA GPUs with Ampere architecture (e.g., A100, RTX 30 series) or newer. +torch.set_float32_matmul_precision("high") class VRAMState(Enum): DISABLED = 0 # No vram present: no need to move models to vram @@ -128,10 +131,7 @@ def get_torch_device(): return torch.device("xpu", torch.xpu.current_device()) else: try: - # https://github.com/sayakpaul/diffusers-torchao/blob/bade7a6abb1cab9ef44782e6bcfab76d0237ae1f/inference/benchmark_image.py#L3 - # This setting optimizes performance on NVIDIA GPUs with Ampere architecture (e.g., A100, RTX 30 series) or newer. - torch.set_float32_matmul_precision("high") - return torch.device(torch.cuda.current_device()) + return torch.device(f"cuda:{torch.cuda.current_device()}") except: warnings.warn("torch.cuda.current_device() did not return a device, returning a CPU torch device") return torch.device("cpu") @@ -1221,12 +1221,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma if is_intel_xpu(): return True - if device is None: - device = torch.device("cuda") - try: - props = torch.cuda.get_device_properties(device) - if props.major >= 8: + props_major = min(torch.cuda.get_device_properties(torch.device(f"cuda:{i}")).major for i in range(torch.cuda.device_count())) + if props_major >= 8: return True except AssertionError: logging.warning("Torch was not compiled with CUDA support")