Add SeedVR2 attention and sampler support

This commit is contained in:
John Pollock 2026-05-25 22:11:32 -05:00
parent 15a500ff6b
commit 2ebacd019d
3 changed files with 445 additions and 4 deletions

View File

@ -32,6 +32,14 @@ except ImportError as e:
raise e
exit(-1)
SAGE_ATTENTION_VARLEN_IS_AVAILABLE = False
try:
from sageattention import sageattn_varlen
SAGE_ATTENTION_VARLEN_IS_AVAILABLE = True
except ImportError:
if model_management.sage_attention_enabled():
logging.warning("SageAttention variable-length attention is unavailable, using pytorch var-len attention instead.")
SAGE_ATTENTION3_IS_AVAILABLE = False
try:
from sageattn3 import sageattn3_blackwell
@ -40,6 +48,7 @@ except ImportError:
pass
FLASH_ATTENTION_IS_AVAILABLE = False
FLASH_ATTENTION_VARLEN_IS_AVAILABLE = False
try:
from flash_attn import flash_attn_func
FLASH_ATTENTION_IS_AVAILABLE = True
@ -48,6 +57,20 @@ except ImportError:
logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
exit(-1)
try:
from flash_attn import flash_attn_varlen_func
FLASH_ATTENTION_VARLEN_IS_AVAILABLE = True
except ImportError:
if model_management.flash_attention_enabled() and FLASH_ATTENTION_IS_AVAILABLE:
logging.warning("Flash Attention variable-length attention is unavailable, using pytorch var-len attention instead.")
FLASH_ATTENTION3_IS_AVAILABLE = False
try:
from flash_attn_interface import flash_attn_varlen_func as flash_attn3_varlen_func
FLASH_ATTENTION3_IS_AVAILABLE = True
except ImportError:
pass
REGISTERED_ATTENTION_FUNCTIONS = {}
def register_attention_function(name: str, func: Callable):
# avoid replacing existing functions
@ -735,28 +758,434 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
)
return out
_VAR_ATTENTION_NESTED_API_NAME = "nested_tensor_from_jagged"
_VAR_ATTENTION_GUARD_MESSAGE = (
"SeedVR2 var_attention_pytorch: torch.nested.nested_tensor_from_jagged "
"is required by this attention path; the installed PyTorch build "
"does not provide it"
)
def _var_attention_max_seqlen(cu_seqlens):
return int((cu_seqlens[1:] - cu_seqlens[:-1]).max().item())
def _var_attention_qkv(q, k, v, heads, skip_reshape):
if skip_reshape:
return q, k, v, q.shape[-1]
total_tokens, embed_dim = q.shape
head_dim = embed_dim // heads
return (
q.view(total_tokens, heads, head_dim),
k.view(k.shape[0], heads, head_dim),
v.view(v.shape[0], heads, head_dim),
head_dim,
)
def _var_attention_output(out, heads, head_dim, skip_output_reshape):
if skip_output_reshape:
return out
return out.reshape(-1, heads * head_dim)
def _use_blackwell_attention():
device = model_management.get_torch_device()
if device.type != "cuda":
return False
major, minor = torch.cuda.get_device_capability(device)
return (major, minor) >= (12, 0)
def var_attention_pytorch(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, skip_reshape=False, skip_output_reshape=False):
_nested = getattr(torch, "nested", None)
if _nested is None or not hasattr(_nested, _VAR_ATTENTION_NESTED_API_NAME):
raise RuntimeError(_VAR_ATTENTION_GUARD_MESSAGE)
if not skip_reshape:
# assumes 2D q, k,v [total_tokens, embed_dim]
total_tokens, embed_dim = q.shape
head_dim = embed_dim // heads
q = q.view(total_tokens, heads, head_dim)
k = k.view(k.shape[0], heads, head_dim)
v = v.view(v.shape[0], heads, head_dim)
q = torch.nested.nested_tensor_from_jagged(q, offsets=cu_seqlens_q.long())
k = torch.nested.nested_tensor_from_jagged(k, offsets=cu_seqlens_k.long())
v = torch.nested.nested_tensor_from_jagged(v, offsets=cu_seqlens_k.long())
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out = out.transpose(1, 2)
if not skip_output_reshape:
return out.values().reshape(-1, heads * (q.shape[-1]))
return out.values()
def _validate_split_cu_seqlens(name, cu_seqlens, token_count):
if cu_seqlens.dtype not in (torch.int32, torch.int64):
raise ValueError(f"{name} must use an integer dtype")
if cu_seqlens.ndim != 1 or cu_seqlens.numel() < 2:
raise ValueError(f"{name} must be a 1D tensor with at least two offsets")
if cu_seqlens[0].item() != 0:
raise ValueError(f"{name} must start at 0")
if (cu_seqlens[1:] <= cu_seqlens[:-1]).any().item():
raise ValueError(f"{name} must be strictly increasing")
if cu_seqlens[-1].item() != token_count:
raise ValueError(f"{name} does not match token count")
def _split_indices(cu_seqlens):
return cu_seqlens[1:-1].to(device="cpu", dtype=torch.long)
def var_attention_pytorch_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, skip_reshape=False, skip_output_reshape=False):
q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)
_validate_split_cu_seqlens("cu_seqlens_q", cu_seqlens_q, q.shape[0])
_validate_split_cu_seqlens("cu_seqlens_k", cu_seqlens_k, k.shape[0])
if cu_seqlens_k[-1].item() != v.shape[0]:
raise ValueError("cu_seqlens_k does not match v token count")
q_split_indices = _split_indices(cu_seqlens_q)
k_split_indices = _split_indices(cu_seqlens_k)
q_splits = torch.tensor_split(q, q_split_indices, dim=0)
k_splits = torch.tensor_split(k, k_split_indices, dim=0)
v_splits = torch.tensor_split(v, k_split_indices, dim=0)
if len(q_splits) != len(k_splits) or len(q_splits) != len(v_splits):
raise ValueError("cu_seqlens_q and cu_seqlens_k must describe the same sequence count")
out = []
for q_i, k_i, v_i in zip(q_splits, k_splits, v_splits):
q_i = q_i.permute(1, 0, 2).unsqueeze(0)
k_i = k_i.permute(1, 0, 2).unsqueeze(0)
v_i = v_i.permute(1, 0, 2).unsqueeze(0)
out_i = comfy.ops.scaled_dot_product_attention(q_i, k_i, v_i, attn_mask=None, dropout_p=0.0, is_causal=False)
out.append(out_i.squeeze(0).permute(1, 0, 2))
out = torch.cat(out, dim=0)
return _var_attention_output(out, heads, head_dim, skip_output_reshape)
@torch._dynamo.disable
def var_attention_sage(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
if not SAGE_ATTENTION_VARLEN_IS_AVAILABLE:
return var_attention_pytorch(
q,
k,
v,
heads,
cu_seqlens_q,
cu_seqlens_k,
skip_reshape=skip_reshape,
skip_output_reshape=skip_output_reshape,
)
q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)
out_dtype = q.dtype
if not (q.dtype == k.dtype == v.dtype):
k = k.to(q.dtype)
v = v.to(q.dtype)
fallback_q, fallback_k, fallback_v = q, k, v
if q.dtype not in (torch.float16, torch.bfloat16):
q = q.to(torch.bfloat16)
k = k.to(torch.bfloat16)
v = v.to(torch.bfloat16)
sm_scale = kwargs.get("softmax_scale")
if sm_scale is None:
sm_scale = 1.0 / math.sqrt(head_dim)
try:
out = sageattn_varlen(
q,
k,
v,
cu_seqlens_q.int(),
cu_seqlens_k.int(),
_var_attention_max_seqlen(cu_seqlens_q),
_var_attention_max_seqlen(cu_seqlens_k),
kwargs.get("causal", False),
sm_scale,
)
except Exception as e:
logging.error("Error running sage var-len attention: %s, using pytorch var-len attention instead.", e)
out = var_attention_pytorch(
fallback_q,
fallback_k,
fallback_v,
heads,
cu_seqlens_q,
cu_seqlens_k,
skip_reshape=True,
skip_output_reshape=skip_output_reshape,
)
if out.dtype != out_dtype:
out = out.to(out_dtype)
return out
if out.dtype != out_dtype:
out = out.to(out_dtype)
return _var_attention_output(out, heads, head_dim, skip_output_reshape)
@torch._dynamo.disable
def var_attention_sage3(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
if not SAGE_ATTENTION3_IS_AVAILABLE:
if SAGE_ATTENTION_VARLEN_IS_AVAILABLE:
return var_attention_sage(
q,
k,
v,
heads,
cu_seqlens_q,
cu_seqlens_k,
skip_reshape=skip_reshape,
skip_output_reshape=skip_output_reshape,
**kwargs,
)
return var_attention_pytorch(
q,
k,
v,
heads,
cu_seqlens_q,
cu_seqlens_k,
skip_reshape=skip_reshape,
skip_output_reshape=skip_output_reshape,
)
q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)
seq_lens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
seq_lens_k = cu_seqlens_k[1:] - cu_seqlens_k[:-1]
uniform_q = bool((seq_lens_q == seq_lens_q[0]).all().item())
uniform_k = bool((seq_lens_k == seq_lens_k[0]).all().item())
if not (uniform_q and uniform_k and seq_lens_q[0] == seq_lens_k[0]):
if SAGE_ATTENTION_VARLEN_IS_AVAILABLE:
return var_attention_sage(
q,
k,
v,
heads,
cu_seqlens_q,
cu_seqlens_k,
skip_reshape=True,
skip_output_reshape=skip_output_reshape,
**kwargs,
)
return var_attention_pytorch(
q,
k,
v,
heads,
cu_seqlens_q,
cu_seqlens_k,
skip_reshape=True,
skip_output_reshape=skip_output_reshape,
)
out_dtype = q.dtype
if not (q.dtype == k.dtype == v.dtype):
k = k.to(q.dtype)
v = v.to(q.dtype)
fallback_q, fallback_k, fallback_v = q, k, v
if q.dtype not in (torch.float16, torch.bfloat16):
q = q.to(torch.bfloat16)
k = k.to(torch.bfloat16)
v = v.to(torch.bfloat16)
batch_size = len(cu_seqlens_q) - 1
seq_len = int(seq_lens_q[0].item())
q = q.view(batch_size, seq_len, heads, head_dim).transpose(1, 2)
k = k.view(batch_size, seq_len, heads, head_dim).transpose(1, 2)
v = v.view(batch_size, seq_len, heads, head_dim).transpose(1, 2)
try:
out = sageattn3_blackwell(q, k, v, is_causal=kwargs.get("causal", False))
except Exception as e:
logging.error("Error running SageAttention3 var-len attention: %s, using fallback var-len attention instead.", e)
if SAGE_ATTENTION_VARLEN_IS_AVAILABLE:
return var_attention_sage(
fallback_q,
fallback_k,
fallback_v,
heads,
cu_seqlens_q,
cu_seqlens_k,
skip_reshape=True,
skip_output_reshape=skip_output_reshape,
**kwargs,
)
return var_attention_pytorch(
fallback_q,
fallback_k,
fallback_v,
heads,
cu_seqlens_q,
cu_seqlens_k,
skip_reshape=True,
skip_output_reshape=skip_output_reshape,
)
out = out.transpose(1, 2).reshape(-1, heads, head_dim).contiguous()
if out.dtype != out_dtype:
out = out.to(out_dtype)
return _var_attention_output(out, heads, head_dim, skip_output_reshape)
@torch._dynamo.disable
def var_attention_flash(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
if not FLASH_ATTENTION_VARLEN_IS_AVAILABLE:
return var_attention_pytorch(
q,
k,
v,
heads,
cu_seqlens_q,
cu_seqlens_k,
skip_reshape=skip_reshape,
skip_output_reshape=skip_output_reshape,
)
q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)
max_seqlen_q = _var_attention_max_seqlen(cu_seqlens_q)
max_seqlen_k = _var_attention_max_seqlen(cu_seqlens_k)
try:
out = flash_attn_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens_q.int(),
cu_seqlens_k=cu_seqlens_k.int(),
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
dropout_p=kwargs.get("dropout_p", 0.0),
causal=kwargs.get("causal", False),
deterministic=torch.are_deterministic_algorithms_enabled(),
)
except Exception as e:
logging.error("Error running Flash Attention var-len attention: %s, using pytorch var-len attention instead.", e)
return var_attention_pytorch(
q,
k,
v,
heads,
cu_seqlens_q,
cu_seqlens_k,
skip_reshape=True,
skip_output_reshape=skip_output_reshape,
)
return _var_attention_output(out, heads, head_dim, skip_output_reshape)
@torch._dynamo.disable
def var_attention_flash3(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
if not FLASH_ATTENTION3_IS_AVAILABLE:
return var_attention_pytorch(
q,
k,
v,
heads,
cu_seqlens_q,
cu_seqlens_k,
skip_reshape=skip_reshape,
skip_output_reshape=skip_output_reshape,
)
q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)
max_seqlen_q = _var_attention_max_seqlen(cu_seqlens_q)
max_seqlen_k = _var_attention_max_seqlen(cu_seqlens_k)
try:
out = flash_attn3_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens_q.int(),
cu_seqlens_k=cu_seqlens_k.int(),
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
seqused_q=None,
seqused_k=None,
softmax_scale=kwargs.get("softmax_scale"),
causal=kwargs.get("causal", False),
deterministic=torch.are_deterministic_algorithms_enabled(),
)
except Exception as e:
logging.error("Error running Flash Attention 3 var-len attention: %s, using pytorch var-len attention instead.", e)
return var_attention_pytorch(
q,
k,
v,
heads,
cu_seqlens_q,
cu_seqlens_k,
skip_reshape=True,
skip_output_reshape=skip_output_reshape,
)
if isinstance(out, tuple):
out = out[0]
return _var_attention_output(out, heads, head_dim, skip_output_reshape)
@torch._dynamo.disable
def var_attention_sub_quad(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
return var_attention_pytorch(
q,
k,
v,
heads,
cu_seqlens_q,
cu_seqlens_k,
skip_reshape=skip_reshape,
skip_output_reshape=skip_output_reshape,
)
@torch._dynamo.disable
def var_attention_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
return var_attention_pytorch_split(
q,
k,
v,
heads,
cu_seqlens_q,
cu_seqlens_k,
skip_reshape=skip_reshape,
skip_output_reshape=skip_output_reshape,
)
optimized_var_attention = var_attention_pytorch
optimized_attention = attention_basic
if model_management.sage_attention_enabled():
logging.info("Using sage attention")
optimized_attention = attention_sage
if SAGE_ATTENTION3_IS_AVAILABLE and _use_blackwell_attention():
logging.info("Using SageAttention3 for variable-length attention")
optimized_var_attention = var_attention_sage3
elif SAGE_ATTENTION_VARLEN_IS_AVAILABLE:
logging.info("Using SageAttention for variable-length attention")
optimized_var_attention = var_attention_sage
else:
logging.info("Using pytorch attention for variable-length attention")
optimized_var_attention = var_attention_pytorch
elif model_management.flash_attention_enabled():
logging.info("Using Flash Attention")
optimized_attention = attention_flash
if FLASH_ATTENTION_VARLEN_IS_AVAILABLE and model_management.get_torch_device().type == "cuda":
logging.info("Using Flash Attention 2 for variable-length attention")
optimized_var_attention = var_attention_flash
else:
logging.info("Using pytorch attention for variable-length attention")
optimized_var_attention = var_attention_pytorch
elif model_management.xformers_enabled():
logging.info("Using xformers attention")
optimized_attention = attention_xformers
elif model_management.pytorch_attention_enabled():
logging.info("Using pytorch attention")
optimized_attention = attention_pytorch
optimized_var_attention = var_attention_pytorch
else:
if args.use_split_cross_attention:
logging.info("Using split optimization for attention")
optimized_attention = attention_split
optimized_var_attention = var_attention_split
else:
logging.info("Using sub quadratic optimization for attention, if you have memory or speed issues try using: --use-split-cross-attention")
optimized_attention = attention_sub_quad
optimized_var_attention = var_attention_sub_quad
optimized_attention_masked = optimized_attention
@ -764,15 +1193,25 @@ optimized_attention_masked = optimized_attention
# register core-supported attention functions
if SAGE_ATTENTION_IS_AVAILABLE:
register_attention_function("sage", attention_sage)
if SAGE_ATTENTION_VARLEN_IS_AVAILABLE:
register_attention_function("var_attention_sage", var_attention_sage)
if SAGE_ATTENTION3_IS_AVAILABLE:
register_attention_function("sage3", attention3_sage)
register_attention_function("var_attention_sage3", var_attention_sage3)
if FLASH_ATTENTION_IS_AVAILABLE:
register_attention_function("flash", attention_flash)
if FLASH_ATTENTION_VARLEN_IS_AVAILABLE:
register_attention_function("var_attention_flash", var_attention_flash)
if FLASH_ATTENTION3_IS_AVAILABLE:
register_attention_function("var_attention_flash3", var_attention_flash3)
if model_management.xformers_enabled():
register_attention_function("xformers", attention_xformers)
register_attention_function("pytorch", attention_pytorch)
register_attention_function("var_attention_pytorch", var_attention_pytorch)
register_attention_function("sub_quad", attention_sub_quad)
register_attention_function("var_attention_sub_quad", var_attention_sub_quad)
register_attention_function("split", attention_split)
register_attention_function("var_attention_split", var_attention_split)
def optimized_attention_for_device(device, mask=False, small_input=False):
@ -1209,5 +1648,3 @@ class SpatialVideoTransformer(SpatialTransformer):
x = self.proj_out(x)
out = x + x_in
return out

View File

@ -13,6 +13,7 @@ if model_management.xformers_enabled_vae():
import xformers
import xformers.ops
def torch_cat_if_needed(xl, dim):
xl = [x for x in xl if x is not None and x.shape[dim] > 0]
if len(xl) > 1:
@ -22,7 +23,8 @@ def torch_cat_if_needed(xl, dim):
else:
return None
def get_timestep_embedding(timesteps, embedding_dim):
def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
From Fairseq.
@ -33,11 +35,13 @@ def get_timestep_embedding(timesteps, embedding_dim):
assert len(timesteps.shape) == 1
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = math.log(10000) / (half_dim - downscale_freq_shift)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
emb = emb.to(device=timesteps.device)
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if flip_sin_to_cos:
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0,1,0,0))
return emb

0
comfy/samplers.py Executable file → Normal file
View File