From 388dad67d57009c612595fee082187742ebdacfd Mon Sep 17 00:00:00 2001
From: doctorpangloss <@hiddenswitch.com>
Date: Wed, 9 Oct 2024 09:26:02 -0700
Subject: [PATCH] Fix pylint errors in attention

---
 comfy/ldm/modules/attention.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index c704ed12d..51004fe6b 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -17,10 +17,10 @@ if model_management.xformers_enabled():
     import xformers.ops  # pylint: disable=import-error
 
 if model_management.sage_attention_enabled():
-    from sageattention import sageattn
+    from sageattention import sageattn  # pylint: disable=import-error
 
 if model_management.flash_attn_enabled():
-    from flash_attn import flash_attn_func
+    from flash_attn import flash_attn_func  # pylint: disable=import-error
 
 from ...cli_args import args
 from ... import ops
@@ -377,7 +377,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
         mask_out[:, :, :mask.shape[-1]] = mask
         mask = mask_out[:, :, :mask.shape[-1]]
 
-    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
+    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)  # pylint: disable=possibly-used-before-assignment
 
     if skip_reshape:
         out = (
@@ -393,6 +393,7 @@
 
     return out
 
+
 def pytorch_style_decl(func):
     @wraps(func)
     def wrapper(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
@@ -412,17 +413,20 @@ def pytorch_style_decl(func):
 
     return wrapper
 
+
 @pytorch_style_decl
 def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
     return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
 
+
 @pytorch_style_decl
 def attention_sagemaker(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
-    return sageattn(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
+    return sageattn(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)  # pylint: disable=possibly-used-before-assignment
 
+
 @pytorch_style_decl
 def attention_flash_attn(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
-    return flash_attn_func(q, k, v)
+    return flash_attn_func(q, k, v)  # pylint: disable=possibly-used-before-assignment
 
 
 optimized_attention = attention_basic