Mirror of https://github.com/comfyanonymous/ComfyUI.git

Fix pylint errors in attention

commit 388dad67d5 (parent bbe2ed330c)
@@ -17,10 +17,10 @@ if model_management.xformers_enabled():
-    import xformers.ops
+    import xformers.ops # pylint: disable=import-error

 if model_management.sage_attention_enabled():
-    from sageattention import sageattn
+    from sageattention import sageattn # pylint: disable=import-error

 if model_management.flash_attn_enabled():
-    from flash_attn import flash_attn_func
+    from flash_attn import flash_attn_func # pylint: disable=import-error

 from ...cli_args import args
 from ... import ops
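Context on the pragmas: xformers, sageattention, and flash_attn are optional backends, imported only when the matching model_management flag is set, and pylint typically runs in an environment where some of them are missing, so it reports E0401 import-error on each guarded import. A minimal sketch of the pattern, with a hypothetical backend_enabled() standing in for the model_management checks:

import importlib.util

def backend_enabled(name):
    # True only if the optional package is importable in this environment.
    return importlib.util.find_spec(name) is not None

if backend_enabled("xformers"):
    # pylint resolves imports statically; when xformers is absent from the
    # lint environment it reports import-error here, hence the targeted
    # per-line disable rather than a global one.
    import xformers.ops # pylint: disable=import-error

A per-line disable keeps the check active for every other import in the file, which is why the commit does not turn off import-error globally.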
@@ -377,7 +377,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
         mask_out[:, :, :mask.shape[-1]] = mask
         mask = mask_out[:, :, :mask.shape[-1]]

-    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
+    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask) # pylint: disable=possibly-used-before-assignment

     if skip_reshape:
         out = (
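The disable on the call site is for a different message: `xformers` is bound only when the module-level condition holds, so pylint cannot prove the name exists inside attention_xformers and flags the use as possibly-used-before-assignment (E0606 in recent pylint releases). A stripped-down repro, with illustrative names that are not from the repository:

import os

if os.environ.get("FAST_JSON"):
    import json as fast_json

def dump(payload):
    # Callers only reach this function when FAST_JSON is set; pylint cannot
    # see that invariant, so the use is flagged and disabled locally.
    return fast_json.dumps(payload) # pylint: disable=possibly-used-before-assignment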
@@ -393,6 +393,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):

     return out

+
 def pytorch_style_decl(func):
     @wraps(func)
     def wrapper(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
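The pytorch_style_decl body is mostly outside this hunk; the diff only shows that it wraps each backend via functools.wraps with a wrapper sharing the backend signature. As a rough, hypothetical reconstruction, such a wrapper usually factors out head splitting and merging around backends that, like torch.nn.functional.scaled_dot_product_attention, expect a (batch, heads, seq, dim) layout; none of the reshape details below are taken from the repository:

from functools import wraps

def pytorch_style_decl(func):
    @wraps(func)
    def wrapper(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
        if not skip_reshape:
            # Split heads out of the channel dim: (b, n, heads*d) -> (b, heads, n, d).
            b, n, _ = q.shape
            q, k, v = (t.view(b, n, heads, -1).transpose(1, 2) for t in (q, k, v))
        out = func(q, k, v, heads, mask=mask, attn_precision=attn_precision,
                   skip_reshape=True)
        # Merge heads back: (b, heads, n, d) -> (b, n, heads*d).
        out = out.transpose(1, 2)
        return out.reshape(out.shape[0], out.shape[1], -1)
    return wrapper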
@@ -412,17 +413,20 @@ def pytorch_style_decl(func):
     return wrapper


 @pytorch_style_decl
 def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
     return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)

+
 @pytorch_style_decl
 def attention_sagemaker(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
-    return sageattn(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
+    return sageattn(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) # pylint: disable=possibly-used-before-assignment

+
 @pytorch_style_decl
 def attention_flash_attn(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
-    return flash_attn_func(q, k, v)
+    return flash_attn_func(q, k, v) # pylint: disable=possibly-used-before-assignment


+
 optimized_attention = attention_basic
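attention_basic is the portable default; this commit does not show how the faster paths get selected, but a plausible wiring, assumed here and mirroring the guarded imports at the top of the file, would rebind optimized_attention behind the same model_management flags:

# Assumed wiring, not part of this commit: prefer the fastest backend
# whose optional dependency is actually available.
optimized_attention = attention_basic
if model_management.xformers_enabled():
    optimized_attention = attention_xformers
elif model_management.flash_attn_enabled():
    optimized_attention = attention_flash_attn
elif model_management.sage_attention_enabled():
    optimized_attention = attention_sagemaker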