Fix pylint errors in attention

Author: doctorpangloss
Date:   2024-10-09 09:26:02 -07:00
Parent: bbe2ed330c
Commit: 388dad67d5


@@ -17,10 +17,10 @@ if model_management.xformers_enabled():
     import xformers.ops # pylint: disable=import-error
 if model_management.sage_attention_enabled():
-    from sageattention import sageattn
+    from sageattention import sageattn # pylint: disable=import-error
 if model_management.flash_attn_enabled():
-    from flash_attn import flash_attn_func
+    from flash_attn import flash_attn_func # pylint: disable=import-error
 from ...cli_args import args
 from ... import ops
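For context: pylint resolves imports against the environment it runs in, and sageattention and flash_attn are optional backends that are typically absent from that environment, so their imports raise import-error during linting even though the code only executes them when the backend is enabled. The inline disables silence exactly those lines while leaving import checking active everywhere else. A minimal self-contained sketch of the pattern, using importlib.util.find_spec as a hypothetical stand-in for this repository's model_management checks:

import importlib.util

# Hypothetical availability probe standing in for model_management's
# *_enabled() helpers; find_spec returns None when a package is missing.
SAGE_AVAILABLE = importlib.util.find_spec("sageattention") is not None

if SAGE_AVAILABLE:
    # pylint cannot resolve a package that is not installed in the lint
    # environment, so the disable is scoped to this one import line.
    from sageattention import sageattn  # pylint: disable=import-error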
@@ -377,7 +377,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
         mask_out[:, :, :mask.shape[-1]] = mask
         mask = mask_out[:, :, :mask.shape[-1]]
 
-    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
+    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask) # pylint: disable=possibly-used-before-assignment
 
     if skip_reshape:
         out = (
@@ -393,6 +393,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
     return out
 
 def pytorch_style_decl(func):
     @wraps(func)
     def wrapper(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
@@ -412,17 +413,20 @@ def pytorch_style_decl(func):
     return wrapper
 
 @pytorch_style_decl
 def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
     return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
 
 @pytorch_style_decl
 def attention_sagemaker(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
-    return sageattn(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
+    return sageattn(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) # pylint: disable=possibly-used-before-assignment
 
 @pytorch_style_decl
 def attention_flash_attn(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
-    return flash_attn_func(q, k, v)
+    return flash_attn_func(q, k, v) # pylint: disable=possibly-used-before-assignment
 
 optimized_attention = attention_basic
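The possibly-used-before-assignment disables address the other half of the same problem: sageattn and flash_attn_func are bound only inside their if guards at import time, so pylint (3.2 and later, which introduced this check) cannot prove the names exist at the call sites, even though each call is only reachable when the corresponding backend check passed. A minimal standard-library sketch of the warning and the inline fix; the names here are illustrative, not from the repository:

import os

if os.name == "posix":
    # Bound only on one branch, exactly like sageattn/flash_attn_func above.
    from posixpath import join as guarded_join

def join_paths(a, b):
    if os.name == "posix":
        # Reachable only when the binding above ran, but static analysis
        # cannot connect the two conditions, hence the inline disable,
        # mirroring the call sites in the commit.
        return guarded_join(a, b)  # pylint: disable=possibly-used-before-assignment
    return os.path.join(a, b)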