Merge branch 'comfyanonymous:master' into master

patientx 2024-12-18 10:16:43 +03:00 committed by GitHub
commit c062723ca5
4 changed files with 51 additions and 6 deletions

View File

@@ -104,6 +104,7 @@ attn_group = parser.add_mutually_exclusive_group()
 attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
 attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
 attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
+attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")
 parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")

View File

@@ -15,6 +15,9 @@ if model_management.xformers_enabled():
     import xformers
     import xformers.ops
 
+if model_management.sage_attention_enabled():
+    from sageattention import sageattn
+
 from comfy.cli_args import args
 import comfy.ops
 ops = comfy.ops.disable_weight_init
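The sageattention import is gated behind model_management.sage_attention_enabled(), so the package stays an optional dependency and is only imported when --use-sage-attention is passed; if the flag is set but the package is missing, the import fails at startup. A small pre-flight check a user could run before enabling the flag (an illustrative sketch, not part of this commit):

# Check whether the optional sageattention package is importable
# before launching with --use-sage-attention.
import importlib.util

if importlib.util.find_spec("sageattention") is None:
    print("sageattention is not installed; install it before using --use-sage-attention")
else:
    print("sageattention is available")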
@@ -447,20 +450,54 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
     return out
 
+def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
+    if skip_reshape:
+        b, _, _, dim_head = q.shape
+        tensor_layout="HND"
+    else:
+        b, _, dim_head = q.shape
+        dim_head //= heads
+        q, k, v = map(
+            lambda t: t.view(b, -1, heads, dim_head),
+            (q, k, v),
+        )
+        tensor_layout="NHD"
+
+    if mask is not None:
+        # add a batch dimension if there isn't already one
+        if mask.ndim == 2:
+            mask = mask.unsqueeze(0)
+        # add a heads dimension if there isn't already one
+        if mask.ndim == 3:
+            mask = mask.unsqueeze(1)
+
+    out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
+    if tensor_layout == "HND":
+        out = (
+            out.transpose(1, 2).reshape(b, -1, heads * dim_head)
+        )
+    else:
+        out = out.reshape(b, -1, heads * dim_head)
+    return out
+
 optimized_attention = attention_basic
 
-if model_management.xformers_enabled():
-    logging.info("Using xformers cross attention")
+if model_management.sage_attention_enabled():
+    logging.info("Using sage attention")
+    optimized_attention = attention_sage
+elif model_management.xformers_enabled():
+    logging.info("Using xformers attention")
     optimized_attention = attention_xformers
 elif model_management.pytorch_attention_enabled():
-    logging.info("Using pytorch cross attention")
+    logging.info("Using pytorch attention")
     optimized_attention = attention_pytorch
 else:
     if args.use_split_cross_attention:
-        logging.info("Using split optimization for cross attention")
+        logging.info("Using split optimization for attention")
         optimized_attention = attention_split
     else:
-        logging.info("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
+        logging.info("Using sub quadratic optimization for attention, if you have memory or speed issues try using: --use-split-cross-attention")
         optimized_attention = attention_sub_quad
 
 optimized_attention_masked = optimized_attention
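attention_sage handles two input layouts: with skip_reshape=True the tensors already arrive as (batch, heads, seq, dim_head) and are passed with tensor_layout="HND" (the output is transposed back afterwards); otherwise the flat (batch, seq, heads * dim_head) tensors are viewed as (batch, seq, heads, dim_head) and passed with tensor_layout="NHD". Masks are broadcast up to 4-D by adding batch and heads dimensions as needed, and the updated dispatcher gives sage attention top priority ahead of xformers and PyTorch attention. A runnable sketch of the same reshape and mask handling, with torch.nn.functional.scaled_dot_product_attention standing in for sageattn so it runs on CPU without the sageattention package (the stand-in kernel and helper name are assumptions for illustration only):

# Sketch of attention_sage's reshape/mask handling; SDPA replaces the real
# sage kernel, so inputs are transposed to (batch, heads, seq, dim_head)
# instead of being handed over in the "NHD" layout.
import torch
import torch.nn.functional as F

def attention_sage_like(q, k, v, heads, mask=None, skip_reshape=False):
    if skip_reshape:
        # inputs already (batch, heads, seq, dim_head)
        b, _, _, dim_head = q.shape
    else:
        # inputs are (batch, seq, heads * dim_head)
        b, _, dim = q.shape
        dim_head = dim // heads
        q, k, v = (t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v))
    if mask is not None:
        if mask.ndim == 2:   # add a batch dimension
            mask = mask.unsqueeze(0)
        if mask.ndim == 3:   # add a heads dimension
            mask = mask.unsqueeze(1)
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    # collapse heads back to (batch, seq, heads * dim_head)
    return out.transpose(1, 2).reshape(b, -1, heads * dim_head)

q = k = v = torch.randn(2, 16, 8 * 64)      # batch=2, seq=16, 8 heads of 64
mask = torch.zeros(16, 16)                  # 2-D additive mask, broadcast to 4-D
print(attention_sage_like(q, k, v, heads=8, mask=mask).shape)  # torch.Size([2, 16, 512])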

View File

@@ -850,6 +850,8 @@ def cast_to_device(tensor, device, dtype, copy=False):
     non_blocking = device_supports_non_blocking(device)
     return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy)
 
+def sage_attention_enabled():
+    return args.use_sage_attention
 
 def xformers_enabled():
     global directml_enabled
View File

@@ -4,7 +4,7 @@ import torch.nn.functional as F
 from dataclasses import dataclass
 from typing import Optional, Any
 
-from comfy.ldm.modules.attention import optimized_attention
+from comfy.ldm.modules.attention import optimized_attention_for_device
 import comfy.model_management
 import comfy.ldm.common_dit
@@ -81,6 +81,7 @@ class Attention(nn.Module):
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         freqs_cis: Optional[torch.Tensor] = None,
+        optimized_attention=None,
     ):
         batch_size, seq_length, _ = hidden_states.shape
@@ -124,6 +125,7 @@ class TransformerBlock(nn.Module):
         x: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         freqs_cis: Optional[torch.Tensor] = None,
+        optimized_attention=None,
     ):
         # Self Attention
         residual = x
@@ -132,6 +134,7 @@ class TransformerBlock(nn.Module):
             hidden_states=x,
             attention_mask=attention_mask,
             freqs_cis=freqs_cis,
+            optimized_attention=optimized_attention,
         )
         x = residual + x
@@ -180,6 +183,7 @@ class Llama2_(nn.Module):
             mask += causal_mask
         else:
             mask = causal_mask
+        optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
 
         intermediate = None
         if intermediate_output is not None:
@@ -191,6 +195,7 @@
                 x=x,
                 attention_mask=mask,
                 freqs_cis=freqs_cis,
+                optimized_attention=optimized_attention,
             )
             if i == intermediate_output:
                 intermediate = x.clone()
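In this file the attention backend is no longer a fixed imported optimized_attention function; Llama2_.forward calls optimized_attention_for_device once per pass (based on the device, whether a mask is present, and small_input=True) and threads the chosen callable down through TransformerBlock and Attention via the new optimized_attention parameter. A self-contained sketch of that "select once, pass down" pattern (module and selector names are illustrative assumptions, with SDPA as a stand-in backend):

# Sketch of per-forward attention selection: the top-level module picks an
# attention callable once and hands it to every block explicitly.
import torch
import torch.nn as nn
import torch.nn.functional as F

def pick_attention(device, mask=False, small_input=False):
    # Stand-in selector; the arguments mirror optimized_attention_for_device
    # but are unused here since only one backend is implemented.
    def sdpa(q, k, v, heads, mask=None):
        b, n, dim = q.shape
        q, k, v = (t.view(b, n, heads, dim // heads).transpose(1, 2) for t in (q, k, v))
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        return out.transpose(1, 2).reshape(b, n, dim)
    return sdpa

class Block(nn.Module):
    def __init__(self, dim, heads):
        super().__init__()
        self.heads = heads
        self.qkv = nn.Linear(dim, dim * 3)

    def forward(self, x, mask=None, optimized_attention=None):
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        return x + optimized_attention(q, k, v, self.heads, mask=mask)

class TinyModel(nn.Module):
    def __init__(self, dim=64, heads=4, depth=2):
        super().__init__()
        self.layers = nn.ModuleList(Block(dim, heads) for _ in range(depth))

    def forward(self, x, mask=None):
        # choose the attention function once, based on the runtime device
        optimized_attention = pick_attention(x.device, mask=mask is not None, small_input=True)
        for layer in self.layers:
            x = layer(x, mask=mask, optimized_attention=optimized_attention)
        return x

x = torch.randn(2, 16, 64)
print(TinyModel()(x).shape)  # torch.Size([2, 16, 64])

Passing the callable explicitly keeps backend selection in one place instead of re-deciding it inside every attention layer.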