mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-28 18:07:25 +08:00
Add SeedVR2 attention and sampler support
This commit is contained in:
parent
15a500ff6b
commit
2ebacd019d
@ -32,6 +32,14 @@ except ImportError as e:
|
|||||||
raise e
|
raise e
|
||||||
exit(-1)
|
exit(-1)
|
||||||
|
|
||||||
|
SAGE_ATTENTION_VARLEN_IS_AVAILABLE = False
|
||||||
|
try:
|
||||||
|
from sageattention import sageattn_varlen
|
||||||
|
SAGE_ATTENTION_VARLEN_IS_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
if model_management.sage_attention_enabled():
|
||||||
|
logging.warning("SageAttention variable-length attention is unavailable, using pytorch var-len attention instead.")
|
||||||
|
|
||||||
SAGE_ATTENTION3_IS_AVAILABLE = False
|
SAGE_ATTENTION3_IS_AVAILABLE = False
|
||||||
try:
|
try:
|
||||||
from sageattn3 import sageattn3_blackwell
|
from sageattn3 import sageattn3_blackwell
|
||||||
@ -40,6 +48,7 @@ except ImportError:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
FLASH_ATTENTION_IS_AVAILABLE = False
|
FLASH_ATTENTION_IS_AVAILABLE = False
|
||||||
|
FLASH_ATTENTION_VARLEN_IS_AVAILABLE = False
|
||||||
try:
|
try:
|
||||||
from flash_attn import flash_attn_func
|
from flash_attn import flash_attn_func
|
||||||
FLASH_ATTENTION_IS_AVAILABLE = True
|
FLASH_ATTENTION_IS_AVAILABLE = True
|
||||||
@ -48,6 +57,20 @@ except ImportError:
|
|||||||
logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
|
logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
|
||||||
exit(-1)
|
exit(-1)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from flash_attn import flash_attn_varlen_func
|
||||||
|
FLASH_ATTENTION_VARLEN_IS_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
if model_management.flash_attention_enabled() and FLASH_ATTENTION_IS_AVAILABLE:
|
||||||
|
logging.warning("Flash Attention variable-length attention is unavailable, using pytorch var-len attention instead.")
|
||||||
|
|
||||||
|
FLASH_ATTENTION3_IS_AVAILABLE = False
|
||||||
|
try:
|
||||||
|
from flash_attn_interface import flash_attn_varlen_func as flash_attn3_varlen_func
|
||||||
|
FLASH_ATTENTION3_IS_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
REGISTERED_ATTENTION_FUNCTIONS = {}
|
REGISTERED_ATTENTION_FUNCTIONS = {}
|
||||||
def register_attention_function(name: str, func: Callable):
|
def register_attention_function(name: str, func: Callable):
|
||||||
# avoid replacing existing functions
|
# avoid replacing existing functions
|
||||||
@ -735,28 +758,434 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
|||||||
)
|
)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
_VAR_ATTENTION_NESTED_API_NAME = "nested_tensor_from_jagged"
|
||||||
|
_VAR_ATTENTION_GUARD_MESSAGE = (
|
||||||
|
"SeedVR2 var_attention_pytorch: torch.nested.nested_tensor_from_jagged "
|
||||||
|
"is required by this attention path; the installed PyTorch build "
|
||||||
|
"does not provide it"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _var_attention_max_seqlen(cu_seqlens):
|
||||||
|
return int((cu_seqlens[1:] - cu_seqlens[:-1]).max().item())
|
||||||
|
|
||||||
|
|
||||||
|
def _var_attention_qkv(q, k, v, heads, skip_reshape):
|
||||||
|
if skip_reshape:
|
||||||
|
return q, k, v, q.shape[-1]
|
||||||
|
total_tokens, embed_dim = q.shape
|
||||||
|
head_dim = embed_dim // heads
|
||||||
|
return (
|
||||||
|
q.view(total_tokens, heads, head_dim),
|
||||||
|
k.view(k.shape[0], heads, head_dim),
|
||||||
|
v.view(v.shape[0], heads, head_dim),
|
||||||
|
head_dim,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _var_attention_output(out, heads, head_dim, skip_output_reshape):
|
||||||
|
if skip_output_reshape:
|
||||||
|
return out
|
||||||
|
return out.reshape(-1, heads * head_dim)
|
||||||
|
|
||||||
|
|
||||||
|
def _use_blackwell_attention():
|
||||||
|
device = model_management.get_torch_device()
|
||||||
|
if device.type != "cuda":
|
||||||
|
return False
|
||||||
|
major, minor = torch.cuda.get_device_capability(device)
|
||||||
|
return (major, minor) >= (12, 0)
|
||||||
|
|
||||||
|
|
||||||
|
def var_attention_pytorch(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, skip_reshape=False, skip_output_reshape=False):
|
||||||
|
_nested = getattr(torch, "nested", None)
|
||||||
|
if _nested is None or not hasattr(_nested, _VAR_ATTENTION_NESTED_API_NAME):
|
||||||
|
raise RuntimeError(_VAR_ATTENTION_GUARD_MESSAGE)
|
||||||
|
|
||||||
|
if not skip_reshape:
|
||||||
|
# assumes 2D q, k,v [total_tokens, embed_dim]
|
||||||
|
total_tokens, embed_dim = q.shape
|
||||||
|
head_dim = embed_dim // heads
|
||||||
|
q = q.view(total_tokens, heads, head_dim)
|
||||||
|
k = k.view(k.shape[0], heads, head_dim)
|
||||||
|
v = v.view(v.shape[0], heads, head_dim)
|
||||||
|
|
||||||
|
q = torch.nested.nested_tensor_from_jagged(q, offsets=cu_seqlens_q.long())
|
||||||
|
k = torch.nested.nested_tensor_from_jagged(k, offsets=cu_seqlens_k.long())
|
||||||
|
v = torch.nested.nested_tensor_from_jagged(v, offsets=cu_seqlens_k.long())
|
||||||
|
|
||||||
|
q = q.transpose(1, 2)
|
||||||
|
k = k.transpose(1, 2)
|
||||||
|
v = v.transpose(1, 2)
|
||||||
|
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
|
||||||
|
|
||||||
|
out = out.transpose(1, 2)
|
||||||
|
if not skip_output_reshape:
|
||||||
|
return out.values().reshape(-1, heads * (q.shape[-1]))
|
||||||
|
return out.values()
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_split_cu_seqlens(name, cu_seqlens, token_count):
|
||||||
|
if cu_seqlens.dtype not in (torch.int32, torch.int64):
|
||||||
|
raise ValueError(f"{name} must use an integer dtype")
|
||||||
|
if cu_seqlens.ndim != 1 or cu_seqlens.numel() < 2:
|
||||||
|
raise ValueError(f"{name} must be a 1D tensor with at least two offsets")
|
||||||
|
if cu_seqlens[0].item() != 0:
|
||||||
|
raise ValueError(f"{name} must start at 0")
|
||||||
|
if (cu_seqlens[1:] <= cu_seqlens[:-1]).any().item():
|
||||||
|
raise ValueError(f"{name} must be strictly increasing")
|
||||||
|
if cu_seqlens[-1].item() != token_count:
|
||||||
|
raise ValueError(f"{name} does not match token count")
|
||||||
|
|
||||||
|
|
||||||
|
def _split_indices(cu_seqlens):
|
||||||
|
return cu_seqlens[1:-1].to(device="cpu", dtype=torch.long)
|
||||||
|
|
||||||
|
|
||||||
|
def var_attention_pytorch_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, skip_reshape=False, skip_output_reshape=False):
|
||||||
|
q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)
|
||||||
|
|
||||||
|
_validate_split_cu_seqlens("cu_seqlens_q", cu_seqlens_q, q.shape[0])
|
||||||
|
_validate_split_cu_seqlens("cu_seqlens_k", cu_seqlens_k, k.shape[0])
|
||||||
|
if cu_seqlens_k[-1].item() != v.shape[0]:
|
||||||
|
raise ValueError("cu_seqlens_k does not match v token count")
|
||||||
|
|
||||||
|
q_split_indices = _split_indices(cu_seqlens_q)
|
||||||
|
k_split_indices = _split_indices(cu_seqlens_k)
|
||||||
|
q_splits = torch.tensor_split(q, q_split_indices, dim=0)
|
||||||
|
k_splits = torch.tensor_split(k, k_split_indices, dim=0)
|
||||||
|
v_splits = torch.tensor_split(v, k_split_indices, dim=0)
|
||||||
|
if len(q_splits) != len(k_splits) or len(q_splits) != len(v_splits):
|
||||||
|
raise ValueError("cu_seqlens_q and cu_seqlens_k must describe the same sequence count")
|
||||||
|
|
||||||
|
out = []
|
||||||
|
for q_i, k_i, v_i in zip(q_splits, k_splits, v_splits):
|
||||||
|
q_i = q_i.permute(1, 0, 2).unsqueeze(0)
|
||||||
|
k_i = k_i.permute(1, 0, 2).unsqueeze(0)
|
||||||
|
v_i = v_i.permute(1, 0, 2).unsqueeze(0)
|
||||||
|
out_i = comfy.ops.scaled_dot_product_attention(q_i, k_i, v_i, attn_mask=None, dropout_p=0.0, is_causal=False)
|
||||||
|
out.append(out_i.squeeze(0).permute(1, 0, 2))
|
||||||
|
|
||||||
|
out = torch.cat(out, dim=0)
|
||||||
|
return _var_attention_output(out, heads, head_dim, skip_output_reshape)
|
||||||
|
|
||||||
|
@torch._dynamo.disable
|
||||||
|
def var_attention_sage(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||||
|
if not SAGE_ATTENTION_VARLEN_IS_AVAILABLE:
|
||||||
|
return var_attention_pytorch(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
heads,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
skip_reshape=skip_reshape,
|
||||||
|
skip_output_reshape=skip_output_reshape,
|
||||||
|
)
|
||||||
|
q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)
|
||||||
|
out_dtype = q.dtype
|
||||||
|
if not (q.dtype == k.dtype == v.dtype):
|
||||||
|
k = k.to(q.dtype)
|
||||||
|
v = v.to(q.dtype)
|
||||||
|
fallback_q, fallback_k, fallback_v = q, k, v
|
||||||
|
if q.dtype not in (torch.float16, torch.bfloat16):
|
||||||
|
q = q.to(torch.bfloat16)
|
||||||
|
k = k.to(torch.bfloat16)
|
||||||
|
v = v.to(torch.bfloat16)
|
||||||
|
sm_scale = kwargs.get("softmax_scale")
|
||||||
|
if sm_scale is None:
|
||||||
|
sm_scale = 1.0 / math.sqrt(head_dim)
|
||||||
|
try:
|
||||||
|
out = sageattn_varlen(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
cu_seqlens_q.int(),
|
||||||
|
cu_seqlens_k.int(),
|
||||||
|
_var_attention_max_seqlen(cu_seqlens_q),
|
||||||
|
_var_attention_max_seqlen(cu_seqlens_k),
|
||||||
|
kwargs.get("causal", False),
|
||||||
|
sm_scale,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error("Error running sage var-len attention: %s, using pytorch var-len attention instead.", e)
|
||||||
|
out = var_attention_pytorch(
|
||||||
|
fallback_q,
|
||||||
|
fallback_k,
|
||||||
|
fallback_v,
|
||||||
|
heads,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
skip_reshape=True,
|
||||||
|
skip_output_reshape=skip_output_reshape,
|
||||||
|
)
|
||||||
|
if out.dtype != out_dtype:
|
||||||
|
out = out.to(out_dtype)
|
||||||
|
return out
|
||||||
|
if out.dtype != out_dtype:
|
||||||
|
out = out.to(out_dtype)
|
||||||
|
return _var_attention_output(out, heads, head_dim, skip_output_reshape)
|
||||||
|
|
||||||
|
|
||||||
|
@torch._dynamo.disable
|
||||||
|
def var_attention_sage3(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||||
|
if not SAGE_ATTENTION3_IS_AVAILABLE:
|
||||||
|
if SAGE_ATTENTION_VARLEN_IS_AVAILABLE:
|
||||||
|
return var_attention_sage(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
heads,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
skip_reshape=skip_reshape,
|
||||||
|
skip_output_reshape=skip_output_reshape,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
return var_attention_pytorch(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
heads,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
skip_reshape=skip_reshape,
|
||||||
|
skip_output_reshape=skip_output_reshape,
|
||||||
|
)
|
||||||
|
q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)
|
||||||
|
seq_lens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
|
||||||
|
seq_lens_k = cu_seqlens_k[1:] - cu_seqlens_k[:-1]
|
||||||
|
uniform_q = bool((seq_lens_q == seq_lens_q[0]).all().item())
|
||||||
|
uniform_k = bool((seq_lens_k == seq_lens_k[0]).all().item())
|
||||||
|
if not (uniform_q and uniform_k and seq_lens_q[0] == seq_lens_k[0]):
|
||||||
|
if SAGE_ATTENTION_VARLEN_IS_AVAILABLE:
|
||||||
|
return var_attention_sage(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
heads,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
skip_reshape=True,
|
||||||
|
skip_output_reshape=skip_output_reshape,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
return var_attention_pytorch(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
heads,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
skip_reshape=True,
|
||||||
|
skip_output_reshape=skip_output_reshape,
|
||||||
|
)
|
||||||
|
out_dtype = q.dtype
|
||||||
|
if not (q.dtype == k.dtype == v.dtype):
|
||||||
|
k = k.to(q.dtype)
|
||||||
|
v = v.to(q.dtype)
|
||||||
|
fallback_q, fallback_k, fallback_v = q, k, v
|
||||||
|
if q.dtype not in (torch.float16, torch.bfloat16):
|
||||||
|
q = q.to(torch.bfloat16)
|
||||||
|
k = k.to(torch.bfloat16)
|
||||||
|
v = v.to(torch.bfloat16)
|
||||||
|
batch_size = len(cu_seqlens_q) - 1
|
||||||
|
seq_len = int(seq_lens_q[0].item())
|
||||||
|
q = q.view(batch_size, seq_len, heads, head_dim).transpose(1, 2)
|
||||||
|
k = k.view(batch_size, seq_len, heads, head_dim).transpose(1, 2)
|
||||||
|
v = v.view(batch_size, seq_len, heads, head_dim).transpose(1, 2)
|
||||||
|
try:
|
||||||
|
out = sageattn3_blackwell(q, k, v, is_causal=kwargs.get("causal", False))
|
||||||
|
except Exception as e:
|
||||||
|
logging.error("Error running SageAttention3 var-len attention: %s, using fallback var-len attention instead.", e)
|
||||||
|
if SAGE_ATTENTION_VARLEN_IS_AVAILABLE:
|
||||||
|
return var_attention_sage(
|
||||||
|
fallback_q,
|
||||||
|
fallback_k,
|
||||||
|
fallback_v,
|
||||||
|
heads,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
skip_reshape=True,
|
||||||
|
skip_output_reshape=skip_output_reshape,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
return var_attention_pytorch(
|
||||||
|
fallback_q,
|
||||||
|
fallback_k,
|
||||||
|
fallback_v,
|
||||||
|
heads,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
skip_reshape=True,
|
||||||
|
skip_output_reshape=skip_output_reshape,
|
||||||
|
)
|
||||||
|
out = out.transpose(1, 2).reshape(-1, heads, head_dim).contiguous()
|
||||||
|
if out.dtype != out_dtype:
|
||||||
|
out = out.to(out_dtype)
|
||||||
|
return _var_attention_output(out, heads, head_dim, skip_output_reshape)
|
||||||
|
|
||||||
|
|
||||||
|
@torch._dynamo.disable
|
||||||
|
def var_attention_flash(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||||
|
if not FLASH_ATTENTION_VARLEN_IS_AVAILABLE:
|
||||||
|
return var_attention_pytorch(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
heads,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
skip_reshape=skip_reshape,
|
||||||
|
skip_output_reshape=skip_output_reshape,
|
||||||
|
)
|
||||||
|
q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)
|
||||||
|
max_seqlen_q = _var_attention_max_seqlen(cu_seqlens_q)
|
||||||
|
max_seqlen_k = _var_attention_max_seqlen(cu_seqlens_k)
|
||||||
|
try:
|
||||||
|
out = flash_attn_varlen_func(
|
||||||
|
q=q,
|
||||||
|
k=k,
|
||||||
|
v=v,
|
||||||
|
cu_seqlens_q=cu_seqlens_q.int(),
|
||||||
|
cu_seqlens_k=cu_seqlens_k.int(),
|
||||||
|
max_seqlen_q=max_seqlen_q,
|
||||||
|
max_seqlen_k=max_seqlen_k,
|
||||||
|
dropout_p=kwargs.get("dropout_p", 0.0),
|
||||||
|
causal=kwargs.get("causal", False),
|
||||||
|
deterministic=torch.are_deterministic_algorithms_enabled(),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error("Error running Flash Attention var-len attention: %s, using pytorch var-len attention instead.", e)
|
||||||
|
return var_attention_pytorch(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
heads,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
skip_reshape=True,
|
||||||
|
skip_output_reshape=skip_output_reshape,
|
||||||
|
)
|
||||||
|
return _var_attention_output(out, heads, head_dim, skip_output_reshape)
|
||||||
|
|
||||||
|
|
||||||
|
@torch._dynamo.disable
|
||||||
|
def var_attention_flash3(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||||
|
if not FLASH_ATTENTION3_IS_AVAILABLE:
|
||||||
|
return var_attention_pytorch(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
heads,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
skip_reshape=skip_reshape,
|
||||||
|
skip_output_reshape=skip_output_reshape,
|
||||||
|
)
|
||||||
|
q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)
|
||||||
|
max_seqlen_q = _var_attention_max_seqlen(cu_seqlens_q)
|
||||||
|
max_seqlen_k = _var_attention_max_seqlen(cu_seqlens_k)
|
||||||
|
try:
|
||||||
|
out = flash_attn3_varlen_func(
|
||||||
|
q=q,
|
||||||
|
k=k,
|
||||||
|
v=v,
|
||||||
|
cu_seqlens_q=cu_seqlens_q.int(),
|
||||||
|
cu_seqlens_k=cu_seqlens_k.int(),
|
||||||
|
max_seqlen_q=max_seqlen_q,
|
||||||
|
max_seqlen_k=max_seqlen_k,
|
||||||
|
seqused_q=None,
|
||||||
|
seqused_k=None,
|
||||||
|
softmax_scale=kwargs.get("softmax_scale"),
|
||||||
|
causal=kwargs.get("causal", False),
|
||||||
|
deterministic=torch.are_deterministic_algorithms_enabled(),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error("Error running Flash Attention 3 var-len attention: %s, using pytorch var-len attention instead.", e)
|
||||||
|
return var_attention_pytorch(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
heads,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
skip_reshape=True,
|
||||||
|
skip_output_reshape=skip_output_reshape,
|
||||||
|
)
|
||||||
|
if isinstance(out, tuple):
|
||||||
|
out = out[0]
|
||||||
|
return _var_attention_output(out, heads, head_dim, skip_output_reshape)
|
||||||
|
|
||||||
|
|
||||||
|
@torch._dynamo.disable
|
||||||
|
def var_attention_sub_quad(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||||
|
return var_attention_pytorch(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
heads,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
skip_reshape=skip_reshape,
|
||||||
|
skip_output_reshape=skip_output_reshape,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@torch._dynamo.disable
|
||||||
|
def var_attention_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||||
|
return var_attention_pytorch_split(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
heads,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
skip_reshape=skip_reshape,
|
||||||
|
skip_output_reshape=skip_output_reshape,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
optimized_var_attention = var_attention_pytorch
|
||||||
optimized_attention = attention_basic
|
optimized_attention = attention_basic
|
||||||
|
|
||||||
if model_management.sage_attention_enabled():
|
if model_management.sage_attention_enabled():
|
||||||
logging.info("Using sage attention")
|
logging.info("Using sage attention")
|
||||||
optimized_attention = attention_sage
|
optimized_attention = attention_sage
|
||||||
|
if SAGE_ATTENTION3_IS_AVAILABLE and _use_blackwell_attention():
|
||||||
|
logging.info("Using SageAttention3 for variable-length attention")
|
||||||
|
optimized_var_attention = var_attention_sage3
|
||||||
|
elif SAGE_ATTENTION_VARLEN_IS_AVAILABLE:
|
||||||
|
logging.info("Using SageAttention for variable-length attention")
|
||||||
|
optimized_var_attention = var_attention_sage
|
||||||
|
else:
|
||||||
|
logging.info("Using pytorch attention for variable-length attention")
|
||||||
|
optimized_var_attention = var_attention_pytorch
|
||||||
elif model_management.flash_attention_enabled():
|
elif model_management.flash_attention_enabled():
|
||||||
logging.info("Using Flash Attention")
|
logging.info("Using Flash Attention")
|
||||||
optimized_attention = attention_flash
|
optimized_attention = attention_flash
|
||||||
|
if FLASH_ATTENTION_VARLEN_IS_AVAILABLE and model_management.get_torch_device().type == "cuda":
|
||||||
|
logging.info("Using Flash Attention 2 for variable-length attention")
|
||||||
|
optimized_var_attention = var_attention_flash
|
||||||
|
else:
|
||||||
|
logging.info("Using pytorch attention for variable-length attention")
|
||||||
|
optimized_var_attention = var_attention_pytorch
|
||||||
elif model_management.xformers_enabled():
|
elif model_management.xformers_enabled():
|
||||||
logging.info("Using xformers attention")
|
logging.info("Using xformers attention")
|
||||||
optimized_attention = attention_xformers
|
optimized_attention = attention_xformers
|
||||||
elif model_management.pytorch_attention_enabled():
|
elif model_management.pytorch_attention_enabled():
|
||||||
logging.info("Using pytorch attention")
|
logging.info("Using pytorch attention")
|
||||||
optimized_attention = attention_pytorch
|
optimized_attention = attention_pytorch
|
||||||
|
optimized_var_attention = var_attention_pytorch
|
||||||
else:
|
else:
|
||||||
if args.use_split_cross_attention:
|
if args.use_split_cross_attention:
|
||||||
logging.info("Using split optimization for attention")
|
logging.info("Using split optimization for attention")
|
||||||
optimized_attention = attention_split
|
optimized_attention = attention_split
|
||||||
|
optimized_var_attention = var_attention_split
|
||||||
else:
|
else:
|
||||||
logging.info("Using sub quadratic optimization for attention, if you have memory or speed issues try using: --use-split-cross-attention")
|
logging.info("Using sub quadratic optimization for attention, if you have memory or speed issues try using: --use-split-cross-attention")
|
||||||
optimized_attention = attention_sub_quad
|
optimized_attention = attention_sub_quad
|
||||||
|
optimized_var_attention = var_attention_sub_quad
|
||||||
|
|
||||||
optimized_attention_masked = optimized_attention
|
optimized_attention_masked = optimized_attention
|
||||||
|
|
||||||
@ -764,15 +1193,25 @@ optimized_attention_masked = optimized_attention
|
|||||||
# register core-supported attention functions
|
# register core-supported attention functions
|
||||||
if SAGE_ATTENTION_IS_AVAILABLE:
|
if SAGE_ATTENTION_IS_AVAILABLE:
|
||||||
register_attention_function("sage", attention_sage)
|
register_attention_function("sage", attention_sage)
|
||||||
|
if SAGE_ATTENTION_VARLEN_IS_AVAILABLE:
|
||||||
|
register_attention_function("var_attention_sage", var_attention_sage)
|
||||||
if SAGE_ATTENTION3_IS_AVAILABLE:
|
if SAGE_ATTENTION3_IS_AVAILABLE:
|
||||||
register_attention_function("sage3", attention3_sage)
|
register_attention_function("sage3", attention3_sage)
|
||||||
|
register_attention_function("var_attention_sage3", var_attention_sage3)
|
||||||
if FLASH_ATTENTION_IS_AVAILABLE:
|
if FLASH_ATTENTION_IS_AVAILABLE:
|
||||||
register_attention_function("flash", attention_flash)
|
register_attention_function("flash", attention_flash)
|
||||||
|
if FLASH_ATTENTION_VARLEN_IS_AVAILABLE:
|
||||||
|
register_attention_function("var_attention_flash", var_attention_flash)
|
||||||
|
if FLASH_ATTENTION3_IS_AVAILABLE:
|
||||||
|
register_attention_function("var_attention_flash3", var_attention_flash3)
|
||||||
if model_management.xformers_enabled():
|
if model_management.xformers_enabled():
|
||||||
register_attention_function("xformers", attention_xformers)
|
register_attention_function("xformers", attention_xformers)
|
||||||
register_attention_function("pytorch", attention_pytorch)
|
register_attention_function("pytorch", attention_pytorch)
|
||||||
|
register_attention_function("var_attention_pytorch", var_attention_pytorch)
|
||||||
register_attention_function("sub_quad", attention_sub_quad)
|
register_attention_function("sub_quad", attention_sub_quad)
|
||||||
|
register_attention_function("var_attention_sub_quad", var_attention_sub_quad)
|
||||||
register_attention_function("split", attention_split)
|
register_attention_function("split", attention_split)
|
||||||
|
register_attention_function("var_attention_split", var_attention_split)
|
||||||
|
|
||||||
|
|
||||||
def optimized_attention_for_device(device, mask=False, small_input=False):
|
def optimized_attention_for_device(device, mask=False, small_input=False):
|
||||||
@ -1209,5 +1648,3 @@ class SpatialVideoTransformer(SpatialTransformer):
|
|||||||
x = self.proj_out(x)
|
x = self.proj_out(x)
|
||||||
out = x + x_in
|
out = x + x_in
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -13,6 +13,7 @@ if model_management.xformers_enabled_vae():
|
|||||||
import xformers
|
import xformers
|
||||||
import xformers.ops
|
import xformers.ops
|
||||||
|
|
||||||
|
|
||||||
def torch_cat_if_needed(xl, dim):
|
def torch_cat_if_needed(xl, dim):
|
||||||
xl = [x for x in xl if x is not None and x.shape[dim] > 0]
|
xl = [x for x in xl if x is not None and x.shape[dim] > 0]
|
||||||
if len(xl) > 1:
|
if len(xl) > 1:
|
||||||
@ -22,7 +23,8 @@ def torch_cat_if_needed(xl, dim):
|
|||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_timestep_embedding(timesteps, embedding_dim):
|
|
||||||
|
def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1):
|
||||||
"""
|
"""
|
||||||
This matches the implementation in Denoising Diffusion Probabilistic Models:
|
This matches the implementation in Denoising Diffusion Probabilistic Models:
|
||||||
From Fairseq.
|
From Fairseq.
|
||||||
@ -33,11 +35,13 @@ def get_timestep_embedding(timesteps, embedding_dim):
|
|||||||
assert len(timesteps.shape) == 1
|
assert len(timesteps.shape) == 1
|
||||||
|
|
||||||
half_dim = embedding_dim // 2
|
half_dim = embedding_dim // 2
|
||||||
emb = math.log(10000) / (half_dim - 1)
|
emb = math.log(10000) / (half_dim - downscale_freq_shift)
|
||||||
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
|
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
|
||||||
emb = emb.to(device=timesteps.device)
|
emb = emb.to(device=timesteps.device)
|
||||||
emb = timesteps.float()[:, None] * emb[None, :]
|
emb = timesteps.float()[:, None] * emb[None, :]
|
||||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
||||||
|
if flip_sin_to_cos:
|
||||||
|
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
|
||||||
if embedding_dim % 2 == 1: # zero pad
|
if embedding_dim % 2 == 1: # zero pad
|
||||||
emb = torch.nn.functional.pad(emb, (0,1,0,0))
|
emb = torch.nn.functional.pad(emb, (0,1,0,0))
|
||||||
return emb
|
return emb
|
||||||
|
|||||||
0
comfy/samplers.py
Executable file → Normal file
0
comfy/samplers.py
Executable file → Normal file
Loading…
Reference in New Issue
Block a user