From 99a1fb6027b7163592a83669b0b1c5aa4657c2b6 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 24 Dec 2024 18:05:19 -0500 Subject: [PATCH 1/5] Make fast fp8 take a bit less peak memory. --- comfy/ops.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index 8e0694232..06be6b48b 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -255,9 +255,10 @@ def fp8_linear(self, input): tensor_2d = True input = input.unsqueeze(1) - + input_shape = input.shape + input_dtype = input.dtype if len(input.shape) == 3: - w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype) + w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype) w = w.t() scale_weight = self.scale_weight @@ -269,23 +270,24 @@ def fp8_linear(self, input): if scale_input is None: scale_input = torch.ones((), device=input.device, dtype=torch.float32) - inn = torch.clamp(input, min=-448, max=448).reshape(-1, input.shape[2]).to(dtype) + input = torch.clamp(input, min=-448, max=448, out=input) + input = input.reshape(-1, input_shape[2]).to(dtype) else: scale_input = scale_input.to(input.device) - inn = (input * (1.0 / scale_input).to(input.dtype)).reshape(-1, input.shape[2]).to(dtype) + input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype) if bias is not None: - o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight) + o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight) else: - o = torch._scaled_mm(inn, w, out_dtype=input.dtype, scale_a=scale_input, scale_b=scale_weight) + o = torch._scaled_mm(input, w, out_dtype=input_dtype, scale_a=scale_input, scale_b=scale_weight) if isinstance(o, tuple): o = o[0] if tensor_2d: - return o.reshape(input.shape[0], -1) + return o.reshape(input_shape[0], -1) - return o.reshape((-1, input.shape[1], self.weight.shape[0])) + return o.reshape((-1, input_shape[1], self.weight.shape[0])) return None From 1ed75ab30ee2fdef6b3b41ad3061583a0fede723 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 25 Dec 2024 03:29:03 -0500 Subject: [PATCH 2/5] Update nightly pytorch instructions in readme for nvidia. --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 8311b7b7c..371421617 100644 --- a/README.md +++ b/README.md @@ -189,7 +189,7 @@ Nvidia users should install stable pytorch using this command: This is the command to install pytorch nightly instead which might have performance improvements: -```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124``` +```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126``` #### Troubleshooting From 0229228f3f75fc4b0d0d4cf3658138eedc2cc2eb Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 25 Dec 2024 04:50:34 -0500 Subject: [PATCH 3/5] Clean up the VAE dtypes code. 
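This change drops the global VAE_DTYPES list and lets vae_dtype() pick a dtype directly from the device: fp16 when should_use_fp16() allows it, bf16 on non-AMD hardware when should_use_bf16() allows it, and float32 otherwise. A minimal usage sketch (illustrative only, not part of the diff below; the allowed_dtypes list is made up for the example):

```python
# Illustrative sketch: how the reworked vae_dtype() is expected to be called.
import torch
import comfy.model_management as mm

device = mm.vae_device()
# The caller lists the dtypes its VAE implementation can work in; the
# --fp16-vae / --bf16-vae / --fp32-vae flags still override the heuristic.
dtype = mm.vae_dtype(device=device, allowed_dtypes=[torch.float16, torch.bfloat16])
print("VAE dtype:", dtype)
```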
--- comfy/model_management.py | 27 ++++++++++++--------------- comfy/sd.py | 4 ++-- 2 files changed, 14 insertions(+), 17 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 33891b929..8320c6ece 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -188,6 +188,12 @@ def is_nvidia(): return True return False +def is_amd(): + global cpu_state + if cpu_state == CPUState.GPU: + if torch.version.hip: + return True + return False MIN_WEIGHT_MEMORY_RATIO = 0.4 if is_nvidia(): @@ -198,27 +204,17 @@ if args.use_pytorch_cross_attention: ENABLE_PYTORCH_ATTENTION = True XFORMERS_IS_AVAILABLE = False -VAE_DTYPES = [torch.float32] - try: if is_nvidia(): if int(torch_version[0]) >= 2: if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False: ENABLE_PYTORCH_ATTENTION = True - if torch.cuda.is_bf16_supported() and torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8: - VAE_DTYPES = [torch.bfloat16] + VAE_DTYPES if is_intel_xpu(): if args.use_split_cross_attention == False and args.use_quad_cross_attention == False: ENABLE_PYTORCH_ATTENTION = True except: pass -if is_intel_xpu(): - VAE_DTYPES = [torch.bfloat16] + VAE_DTYPES - -if args.cpu_vae: - VAE_DTYPES = [torch.float32] - if ENABLE_PYTORCH_ATTENTION: torch.backends.cuda.enable_math_sdp(True) torch.backends.cuda.enable_flash_sdp(True) @@ -754,7 +750,6 @@ def vae_offload_device(): return torch.device("cpu") def vae_dtype(device=None, allowed_dtypes=[]): - global VAE_DTYPES if args.fp16_vae: return torch.float16 elif args.bf16_vae: @@ -763,12 +758,14 @@ def vae_dtype(device=None, allowed_dtypes=[]): return torch.float32 for d in allowed_dtypes: - if d == torch.float16 and should_use_fp16(device, prioritize_performance=False): - return d - if d in VAE_DTYPES: + if d == torch.float16 and should_use_fp16(device): return d - return VAE_DTYPES[0] + # NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32 + if d == torch.bfloat16 and (not is_amd()) and should_use_bf16(device): + return d + + return torch.float32 def get_autocast_device(dev): if hasattr(dev, 'type'): diff --git a/comfy/sd.py b/comfy/sd.py index de3ce677c..55f91116f 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -111,7 +111,7 @@ class CLIP: model_management.load_models_gpu([self.patcher], force_full_load=True) self.layer_idx = None self.use_clip_schedule = False - logging.debug("CLIP model load device: {}, offload device: {}, current: {}".format(load_device, offload_device, params['device'])) + logging.info("CLIP model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype)) def clone(self): n = CLIP(no_init=True) @@ -402,7 +402,7 @@ class VAE: self.output_device = model_management.intermediate_device() self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device) - logging.debug("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype)) + logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype)) def vae_encode_crop_pixels(self, pixels): downscale_ratio = self.spacial_compression_encode() From b486885e0866b1fc37b767a7ff04c1f40acb5ac4 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 25 Dec 2024 05:18:50 -0500 Subject: [PATCH 4/5] Disable bfloat16 on older mac. 
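The check added below builds a version tuple from platform.mac_ver() and compares it element-wise against (14,). A quick illustration of that comparison (my own sketch, not part of the diff):

```python
# Python compares tuples element by element, which is what the new
# mac_version() helper relies on.
print((13, 6, 1) < (14,))   # True  -> bf16 gets disabled on macOS 13.x
print((14, 5) < (14,))      # False -> macOS 14.x and newer keep bf16
print((14, 5) <= (15, 2))   # True  -> inside the attention upcast workaround range
```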
--- comfy/model_management.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 8320c6ece..ce241e17f 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -886,14 +886,19 @@ def pytorch_attention_flash_attention(): return True return False +def mac_version(): + try: + return tuple(int(n) for n in platform.mac_ver()[0].split(".")) + except: + return None + def force_upcast_attention_dtype(): upcast = args.force_upcast_attention - try: - macos_version = tuple(int(n) for n in platform.mac_ver()[0].split(".")) - if (14, 5) <= macos_version <= (15, 2): # black image bug on recent versions of macOS - upcast = True - except: - pass + + macos_version = mac_version() + if macos_version is not None and ((14, 5) <= macos_version <= (15, 2)): # black image bug on recent versions of macOS + upcast = True + if upcast: return torch.float32 else: @@ -1034,6 +1039,8 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma return False if mps_mode(): + if mac_version() < (14,): + return False return True if cpu_mode(): From 19a64d62918c68b800de7277472c3b039beaa126 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 25 Dec 2024 05:32:51 -0500 Subject: [PATCH 5/5] Cleanup some mac related code. --- comfy/model_management.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index ce241e17f..db2a61395 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -969,17 +969,13 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma if FORCE_FP16: return True - if device is not None: - if is_device_mps(device): - return True - if FORCE_FP32: return False if directml_enabled: return False - if mps_mode(): + if (device is not None and is_device_mps(device)) or mps_mode(): return True if cpu_mode(): @@ -1028,17 +1024,13 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma if is_device_cpu(device): #TODO ? bf16 works on CPU but is extremely slow return False - if device is not None: - if is_device_mps(device): - return True - if FORCE_FP32: return False if directml_enabled: return False - if mps_mode(): + if (device is not None and is_device_mps(device)) or mps_mode(): if mac_version() < (14,): return False return True
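A closing note on the first patch in the series: fp8_linear now clamps and reshapes the activation tensor in place (rebinding input) instead of building a separate inn copy alongside it, which is where the small peak-memory saving comes from. For readers unfamiliar with the underlying op, here is a standalone sketch of a torch._scaled_mm call of the same shape (illustrative only, not taken from the patches; it needs a PyTorch build and GPU with fp8 support, and the scales here are just ones):

```python
import torch

if torch.cuda.is_available():
    fp8 = torch.float8_e4m3fn
    x = torch.randn(16, 64, device="cuda")            # activations
    w = torch.randn(32, 64, device="cuda")            # (out_features, in_features)
    a = torch.clamp(x, min=-448, max=448).to(fp8)     # 448 is the e4m3fn max magnitude
    b = w.to(fp8).t()                                 # _scaled_mm expects mat2 column-major
    one = torch.ones((), device="cuda", dtype=torch.float32)
    out = torch._scaled_mm(a, b, out_dtype=torch.bfloat16, scale_a=one, scale_b=one)
    if isinstance(out, tuple):                        # some PyTorch versions return (out, amax)
        out = out[0]
    print(out.shape)                                  # torch.Size([16, 32])
```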