From 904bf58e7d27eb254d20879e306042653debc4b3 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Wed, 21 Aug 2024 14:01:41 -0400
Subject: [PATCH 1/3] Make --fast work on pytorch nightly.

---
 comfy/ops.py | 21 +++++++++++++++++++--
 1 file changed, 19 insertions(+), 2 deletions(-)

diff --git a/comfy/ops.py b/comfy/ops.py
index bd84a804c..5fef7cee7 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -254,16 +254,33 @@ def fp8_linear(self, input):
         non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
         w = cast_to(self.weight, device=input.device, non_blocking=non_blocking).t()
 
+        scale_weight = self.scale_weight
+        scale_input = self.scale_input
+        if scale_weight is None:
+            scale_weight = torch.ones((1), device=input.device, dtype=torch.float32)
+            if scale_input is None:
+                scale_input = scale_weight
+        if scale_input is None:
+            scale_input = torch.ones((1), device=input.device, dtype=torch.float32)
+
         if self.bias is not None:
-            o, _ = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=cast_to_input(self.bias, input, non_blocking=non_blocking))
+            o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=cast_to_input(self.bias, input, non_blocking=non_blocking), scale_a=scale_input, scale_b=scale_weight)
         else:
-            o, _ = torch._scaled_mm(inn, w, out_dtype=input.dtype)
+            o = torch._scaled_mm(inn, w, out_dtype=input.dtype, scale_a=scale_input, scale_b=scale_weight)
+
+        if isinstance(o, tuple):
+            o = o[0]
 
         return o.reshape((-1, input.shape[1], self.weight.shape[0]))
 
     return None
 
 class fp8_ops(manual_cast):
     class Linear(manual_cast.Linear):
+        def reset_parameters(self):
+            self.scale_weight = None
+            self.scale_input = None
+            return None
+
         def forward_comfy_cast_weights(self, input):
             out = fp8_linear(self, input)
             if out is not None:

From 015f73dc4941ae6e01e01b934368f031c7fa8b8d Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Wed, 21 Aug 2024 16:17:15 -0400
Subject: [PATCH 2/3] Try a different type of flux fp16 fix.

---
 comfy/ldm/flux/layers.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py
index da0cf61b1..9820832ba 100644
--- a/comfy/ldm/flux/layers.py
+++ b/comfy/ldm/flux/layers.py
@@ -178,7 +178,7 @@ class DoubleStreamBlock(nn.Module):
         txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
 
         if txt.dtype == torch.float16:
-            txt = txt.clip(-65504, 65504)
+            txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
 
         return img, txt
 
@@ -233,7 +233,7 @@ class SingleStreamBlock(nn.Module):
         output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
         x += mod.gate * output
         if x.dtype == torch.float16:
-            x = x.clip(-65504, 65504)
+            x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
 
         return x
 

From a60620dcea1302ef5c7f555e5e16f70b39c234ef Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Wed, 21 Aug 2024 16:38:26 -0400
Subject: [PATCH 3/3] Fix slow performance on 10 series Nvidia GPUs.

---
 comfy/model_management.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index bcc937792..edbe6a8a4 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -668,6 +668,7 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.flo
     if bf16_supported and weight_dtype == torch.bfloat16:
         return None
 
+    fp16_supported = should_use_fp16(inference_device, prioritize_performance=True)
     for dt in supported_dtypes:
         if dt == torch.float16 and fp16_supported:
             return torch.float16
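
Notes on the three patches above — sketches for context, not part of the
commits themselves.

[PATCH 1/3]: recent PyTorch nightlies changed torch._scaled_mm to require
explicit scale_a/scale_b arguments and to return a plain tensor instead of
an (output, amax) tuple; the patch passes all-ones scales and unwraps the
tuple when an older build still returns one. A minimal sketch of the call,
assuming a CUDA device with fp8 support and illustrative shapes (fp8
matmuls want dimensions divisible by 16):

import torch

M, K, N = 32, 64, 48
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
# _scaled_mm expects its second operand column-major, hence the .t()
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()
# float32 per-tensor scales; all-ones means "no rescaling"
scale_a = torch.ones((1), device="cuda", dtype=torch.float32)
scale_b = torch.ones((1), device="cuda", dtype=torch.float32)

o = torch._scaled_mm(a, b, out_dtype=torch.float16, scale_a=scale_a, scale_b=scale_b)
if isinstance(o, tuple):  # older builds returned (output, amax)
    o = o[0]
print(o.shape)  # torch.Size([32, 48])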
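
[PATCH 2/3]: the flux fp16 fix replaces clip() with nan_to_num() because
clip() saturates +/-inf at the fp16 maximum (65504) but propagates NaN,
while nan_to_num() handles both failure modes. A self-contained
illustration of the difference:

import torch

x = torch.tensor([float("nan"), float("inf"), -float("inf"), 123.0],
                 dtype=torch.float16)

print(x.clip(-65504, 65504))
# tensor([    nan, 65504., -65504.,   123.], dtype=torch.float16)

print(torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504))
# tensor([     0., 65504., -65504.,   123.], dtype=torch.float16)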
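
[PATCH 3/3]: unet_manual_cast recomputes fp16_supported with
prioritize_performance=True right before the dtype-selection loop, so a
card where fp16 works but runs slowly (the 10 series) falls through to
fp32 instead of being manually cast to fp16. A simplified sketch of that
selection pattern — pick_cast_dtype and the boolean flags are stand-ins
here, not the real comfy.model_management API:

import torch

def pick_cast_dtype(supported_dtypes, fp16_supported, bf16_supported):
    # Walk the preference-ordered list; fall back to fp32 when neither
    # half-precision dtype is worth using on this device.
    for dt in supported_dtypes:
        if dt == torch.float16 and fp16_supported:
            return torch.float16
        if dt == torch.bfloat16 and bf16_supported:
            return torch.bfloat16
    return torch.float32

# On a 10 series card, prioritizing performance reports fp16 as not
# worthwhile, so the fallback is fp32:
print(pick_cast_dtype([torch.float16, torch.bfloat16], False, False))
# torch.float32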