Mirror of https://github.com/comfyanonymous/ComfyUI.git
Fix pylint errors in attention
commit 388dad67d5
parent bbe2ed330c
@@ -17,10 +17,10 @@ if model_management.xformers_enabled():
     import xformers.ops # pylint: disable=import-error
 
 if model_management.sage_attention_enabled():
-    from sageattention import sageattn
+    from sageattention import sageattn # pylint: disable=import-error
 
 if model_management.flash_attn_enabled():
-    from flash_attn import flash_attn_func
+    from flash_attn import flash_attn_func # pylint: disable=import-error
 
 from ...cli_args import args
 from ... import ops
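The import-error disables exist because these backends are optional: when xformers, sageattention, or flash_attn is not installed in the environment where pylint runs, the import line itself triggers E0401, even though the surrounding model_management check keeps it from executing at runtime. A minimal, self-contained sketch of the same pattern (the availability check below is an illustrative stand-in, not the repository's model_management API):

import importlib.util

# Illustrative stand-in for model_management.sage_attention_enabled():
# treat the backend as enabled only if the package is importable.
SAGE_ATTENTION_AVAILABLE = importlib.util.find_spec("sageattention") is not None

if SAGE_ATTENTION_AVAILABLE:
    # Without the trailing comment, pylint reports import-error (E0401)
    # whenever sageattention is missing from the lint environment.
    from sageattention import sageattn  # pylint: disable=import-error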
@@ -377,7 +377,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
         mask_out[:, :, :mask.shape[-1]] = mask
         mask = mask_out[:, :, :mask.shape[-1]]
 
-    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
+    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask) # pylint: disable=possibly-used-before-assignment
 
     if skip_reshape:
         out = (
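possibly-used-before-assignment is the flip side of the guarded import: xformers is only bound inside the if model_management.xformers_enabled() block near the top of the file, so pylint cannot prove the name exists when attention_xformers calls it, even though this path is presumably only taken when that same check passed. A self-contained sketch of the call-site half of the pattern (illustrative names, not the repository's code):

import importlib.util

XFORMERS_AVAILABLE = importlib.util.find_spec("xformers") is not None

if XFORMERS_AVAILABLE:
    import xformers.ops  # pylint: disable=import-error

def attention_xformers_sketch(q, k, v, mask=None):
    # Meant to run only when XFORMERS_AVAILABLE is True, but pylint's static
    # view cannot rule out the unbound case, hence the inline disable at the
    # use site rather than a file-wide suppression.
    return xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)  # pylint: disable=possibly-used-before-assignment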
@@ -393,6 +393,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
 
     return out
 
+
 def pytorch_style_decl(func):
     @wraps(func)
     def wrapper(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
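Only the signature of pytorch_style_decl and its use of functools.wraps appear in this diff. As a shape sketch only (not the repository's implementation), a decorator with this signature wraps each backend-specific attention function behind a common interface, and @wraps keeps the wrapped function's name and docstring intact:

from functools import wraps

def pytorch_style_decl_sketch(func):
    # Shape sketch: the real decorator presumably adds backend-specific
    # pre/post-processing (e.g. tensor reshaping); this version only forwards
    # its arguments unchanged.
    @wraps(func)  # preserve func.__name__ and func.__doc__ on the wrapper
    def wrapper(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
        return func(q, k, v, heads, mask=mask, attn_precision=attn_precision, skip_reshape=skip_reshape)
    return wrapper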
@@ -412,17 +413,20 @@ def pytorch_style_decl(func):
 
     return wrapper
 
+
 @pytorch_style_decl
 def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
     return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
 
+
 @pytorch_style_decl
 def attention_sagemaker(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
-    return sageattn(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
+    return sageattn(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) # pylint: disable=possibly-used-before-assignment
 
+
 @pytorch_style_decl
 def attention_flash_attn(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
-    return flash_attn_func(q, k, v)
+    return flash_attn_func(q, k, v) # pylint: disable=possibly-used-before-assignment
 
 
 optimized_attention = attention_basic
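attention_basic stays the module-level default for optimized_attention; the backend-specific functions above are presumably bound in its place elsewhere, based on the same model_management checks used for the imports. A self-contained sketch of that style of dispatch, with plain booleans and placeholder functions standing in for the real checks and implementations:

# Illustrative stand-ins; these exact names are not from the repository.
XFORMERS_ENABLED = False
SAGE_ATTENTION_ENABLED = False
FLASH_ATTN_ENABLED = False

def attention_basic_placeholder(q, k, v, heads, mask=None):
    return "basic"

def attention_xformers_placeholder(q, k, v, heads, mask=None):
    return "xformers"

def attention_sagemaker_placeholder(q, k, v, heads, mask=None):
    return "sage"

def attention_flash_attn_placeholder(q, k, v, heads, mask=None):
    return "flash"

# Pick the first enabled backend, falling back to the basic implementation.
optimized_attention = attention_basic_placeholder
if XFORMERS_ENABLED:
    optimized_attention = attention_xformers_placeholder
elif SAGE_ATTENTION_ENABLED:
    optimized_attention = attention_sagemaker_placeholder
elif FLASH_ATTN_ENABLED:
    optimized_attention = attention_flash_attn_placeholder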