diff --git a/comfy/model_management.py b/comfy/model_management.py index 22e5f98b4..1f16dc8cc 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -188,6 +188,12 @@ def is_nvidia(): return True return False +def is_amd(): + global cpu_state + if cpu_state == CPUState.GPU: + if torch.version.hip: + return True + return False MIN_WEIGHT_MEMORY_RATIO = 0.4 if is_nvidia(): @@ -198,27 +204,17 @@ if args.use_pytorch_cross_attention: ENABLE_PYTORCH_ATTENTION = True XFORMERS_IS_AVAILABLE = False -VAE_DTYPES = [torch.float32] - try: if is_nvidia(): if int(torch_version[0]) >= 2: if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False: ENABLE_PYTORCH_ATTENTION = True - if torch.cuda.is_bf16_supported() and torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8: - VAE_DTYPES = [torch.bfloat16] + VAE_DTYPES if is_intel_xpu(): if args.use_split_cross_attention == False and args.use_quad_cross_attention == False: ENABLE_PYTORCH_ATTENTION = True except: pass -if is_intel_xpu(): - VAE_DTYPES = [torch.bfloat16] + VAE_DTYPES - -if args.cpu_vae: - VAE_DTYPES = [torch.float32] - if ENABLE_PYTORCH_ATTENTION: torch.backends.cuda.enable_math_sdp(True) torch.backends.cuda.enable_flash_sdp(True) @@ -767,7 +763,6 @@ def vae_offload_device(): return torch.device("cpu") def vae_dtype(device=None, allowed_dtypes=[]): - global VAE_DTYPES if args.fp16_vae: return torch.float16 elif args.bf16_vae: @@ -776,12 +771,14 @@ def vae_dtype(device=None, allowed_dtypes=[]): return torch.float32 for d in allowed_dtypes: - if d == torch.float16 and should_use_fp16(device, prioritize_performance=False): - return d - if d in VAE_DTYPES: + if d == torch.float16 and should_use_fp16(device): return d - return VAE_DTYPES[0] + # NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32 + if d == torch.bfloat16 and (not is_amd()) and should_use_bf16(device): + return d + + return torch.float32 def get_autocast_device(dev): if hasattr(dev, 'type'): @@ -902,14 +899,19 @@ def pytorch_attention_flash_attention(): return True return False +def mac_version(): + try: + return tuple(int(n) for n in platform.mac_ver()[0].split(".")) + except: + return None + def force_upcast_attention_dtype(): upcast = args.force_upcast_attention - try: - macos_version = tuple(int(n) for n in platform.mac_ver()[0].split(".")) - if (14, 5) <= macos_version <= (15, 2): # black image bug on recent versions of macOS - upcast = True - except: - pass + + macos_version = mac_version() + if macos_version is not None and ((14, 5) <= macos_version <= (15, 2)): # black image bug on recent versions of macOS + upcast = True + if upcast: return torch.float32 else: @@ -980,17 +982,13 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma if FORCE_FP16: return True - if device is not None: - if is_device_mps(device): - return True - if FORCE_FP32: return False if directml_enabled: return False - if mps_mode(): + if (device is not None and is_device_mps(device)) or mps_mode(): return True if cpu_mode(): @@ -1039,17 +1037,15 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma if is_device_cpu(device): #TODO ? bf16 works on CPU but is extremely slow return False - if device is not None: - if is_device_mps(device): - return True - if FORCE_FP32: return False if directml_enabled: return False - if mps_mode(): + if (device is not None and is_device_mps(device)) or mps_mode(): + if mac_version() < (14,): + return False return True if cpu_mode(): diff --git a/comfy/ops.py b/comfy/ops.py index 8e0694232..06be6b48b 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -255,9 +255,10 @@ def fp8_linear(self, input): tensor_2d = True input = input.unsqueeze(1) - + input_shape = input.shape + input_dtype = input.dtype if len(input.shape) == 3: - w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype) + w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype) w = w.t() scale_weight = self.scale_weight @@ -269,23 +270,24 @@ def fp8_linear(self, input): if scale_input is None: scale_input = torch.ones((), device=input.device, dtype=torch.float32) - inn = torch.clamp(input, min=-448, max=448).reshape(-1, input.shape[2]).to(dtype) + input = torch.clamp(input, min=-448, max=448, out=input) + input = input.reshape(-1, input_shape[2]).to(dtype) else: scale_input = scale_input.to(input.device) - inn = (input * (1.0 / scale_input).to(input.dtype)).reshape(-1, input.shape[2]).to(dtype) + input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype) if bias is not None: - o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight) + o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight) else: - o = torch._scaled_mm(inn, w, out_dtype=input.dtype, scale_a=scale_input, scale_b=scale_weight) + o = torch._scaled_mm(input, w, out_dtype=input_dtype, scale_a=scale_input, scale_b=scale_weight) if isinstance(o, tuple): o = o[0] if tensor_2d: - return o.reshape(input.shape[0], -1) + return o.reshape(input_shape[0], -1) - return o.reshape((-1, input.shape[1], self.weight.shape[0])) + return o.reshape((-1, input_shape[1], self.weight.shape[0])) return None diff --git a/comfy/sd.py b/comfy/sd.py index de3ce677c..55f91116f 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -111,7 +111,7 @@ class CLIP: model_management.load_models_gpu([self.patcher], force_full_load=True) self.layer_idx = None self.use_clip_schedule = False - logging.debug("CLIP model load device: {}, offload device: {}, current: {}".format(load_device, offload_device, params['device'])) + logging.info("CLIP model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype)) def clone(self): n = CLIP(no_init=True) @@ -402,7 +402,7 @@ class VAE: self.output_device = model_management.intermediate_device() self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device) - logging.debug("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype)) + logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype)) def vae_encode_crop_pixels(self, pixels): downscale_ratio = self.spacial_compression_encode()