diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index c0b68fb8c..2f71703b6 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -130,9 +130,22 @@ class LowVramPatch: self.set_func = set_func def __call__(self, weight): + # Detect SageAttention and skip conversion for compatibility + sage_attention_active = False + try: + import comfy.cli_args + sage_attention_active = hasattr(comfy.cli_args.args, 'use_sage_attention') and \ + comfy.cli_args.args.use_sage_attention + except: + pass + intermediate_dtype = weight.dtype - if self.convert_func is not None: + + # Skip convert_func when SageAttention is active (compatibility mode) + if self.convert_func is not None and not sage_attention_active: weight = self.convert_func(weight.to(dtype=torch.float32, copy=True), inplace=True) + elif sage_attention_active and self.convert_func is not None: + logging.debug(f"Skipping convert_func for {self.key} (SageAttention compatibility)") if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops intermediate_dtype = torch.float32 @@ -140,10 +153,16 @@ class LowVramPatch: if self.set_func is None: return comfy.float.stochastic_rounding(out, weight.dtype, seed=string_to_seed(self.key)) else: - return self.set_func(out, seed=string_to_seed(self.key), return_weight=True) + # Skip set_func when SageAttention is active (compatibility mode) + if not sage_attention_active: + return self.set_func(out, seed=string_to_seed(self.key), return_weight=True) + else: + return comfy.float.stochastic_rounding(out, weight.dtype, seed=string_to_seed(self.key)) out = comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype) - if self.set_func is not None: + + # Skip set_func when SageAttention is active (compatibility mode) + if self.set_func is not None and not sage_attention_active: return self.set_func(out, seed=string_to_seed(self.key), return_weight=True).to(dtype=intermediate_dtype) else: return out