mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-28 09:57:24 +08:00
Add SeedVR2 attention and sampler support
This commit is contained in:
parent
15a500ff6b
commit
2ebacd019d
@ -32,6 +32,14 @@ except ImportError as e:
|
||||
raise e
|
||||
exit(-1)
|
||||
|
||||
SAGE_ATTENTION_VARLEN_IS_AVAILABLE = False
|
||||
try:
|
||||
from sageattention import sageattn_varlen
|
||||
SAGE_ATTENTION_VARLEN_IS_AVAILABLE = True
|
||||
except ImportError:
|
||||
if model_management.sage_attention_enabled():
|
||||
logging.warning("SageAttention variable-length attention is unavailable, using pytorch var-len attention instead.")
|
||||
|
||||
SAGE_ATTENTION3_IS_AVAILABLE = False
|
||||
try:
|
||||
from sageattn3 import sageattn3_blackwell
|
||||
@ -40,6 +48,7 @@ except ImportError:
|
||||
pass
|
||||
|
||||
FLASH_ATTENTION_IS_AVAILABLE = False
|
||||
FLASH_ATTENTION_VARLEN_IS_AVAILABLE = False
|
||||
try:
|
||||
from flash_attn import flash_attn_func
|
||||
FLASH_ATTENTION_IS_AVAILABLE = True
|
||||
@ -48,6 +57,20 @@ except ImportError:
|
||||
logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
|
||||
exit(-1)
|
||||
|
||||
try:
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
FLASH_ATTENTION_VARLEN_IS_AVAILABLE = True
|
||||
except ImportError:
|
||||
if model_management.flash_attention_enabled() and FLASH_ATTENTION_IS_AVAILABLE:
|
||||
logging.warning("Flash Attention variable-length attention is unavailable, using pytorch var-len attention instead.")
|
||||
|
||||
FLASH_ATTENTION3_IS_AVAILABLE = False
|
||||
try:
|
||||
from flash_attn_interface import flash_attn_varlen_func as flash_attn3_varlen_func
|
||||
FLASH_ATTENTION3_IS_AVAILABLE = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
REGISTERED_ATTENTION_FUNCTIONS = {}
|
||||
def register_attention_function(name: str, func: Callable):
|
||||
# avoid replacing existing functions
|
||||
@ -735,28 +758,434 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
)
|
||||
return out
|
||||
|
||||
_VAR_ATTENTION_NESTED_API_NAME = "nested_tensor_from_jagged"
|
||||
_VAR_ATTENTION_GUARD_MESSAGE = (
|
||||
"SeedVR2 var_attention_pytorch: torch.nested.nested_tensor_from_jagged "
|
||||
"is required by this attention path; the installed PyTorch build "
|
||||
"does not provide it"
|
||||
)
|
||||
|
||||
|
||||
def _var_attention_max_seqlen(cu_seqlens):
|
||||
return int((cu_seqlens[1:] - cu_seqlens[:-1]).max().item())
|
||||
|
||||
|
||||
def _var_attention_qkv(q, k, v, heads, skip_reshape):
|
||||
if skip_reshape:
|
||||
return q, k, v, q.shape[-1]
|
||||
total_tokens, embed_dim = q.shape
|
||||
head_dim = embed_dim // heads
|
||||
return (
|
||||
q.view(total_tokens, heads, head_dim),
|
||||
k.view(k.shape[0], heads, head_dim),
|
||||
v.view(v.shape[0], heads, head_dim),
|
||||
head_dim,
|
||||
)
|
||||
|
||||
|
||||
def _var_attention_output(out, heads, head_dim, skip_output_reshape):
|
||||
if skip_output_reshape:
|
||||
return out
|
||||
return out.reshape(-1, heads * head_dim)
|
||||
|
||||
|
||||
def _use_blackwell_attention():
|
||||
device = model_management.get_torch_device()
|
||||
if device.type != "cuda":
|
||||
return False
|
||||
major, minor = torch.cuda.get_device_capability(device)
|
||||
return (major, minor) >= (12, 0)
|
||||
|
||||
|
||||
def var_attention_pytorch(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, skip_reshape=False, skip_output_reshape=False):
|
||||
_nested = getattr(torch, "nested", None)
|
||||
if _nested is None or not hasattr(_nested, _VAR_ATTENTION_NESTED_API_NAME):
|
||||
raise RuntimeError(_VAR_ATTENTION_GUARD_MESSAGE)
|
||||
|
||||
if not skip_reshape:
|
||||
# assumes 2D q, k,v [total_tokens, embed_dim]
|
||||
total_tokens, embed_dim = q.shape
|
||||
head_dim = embed_dim // heads
|
||||
q = q.view(total_tokens, heads, head_dim)
|
||||
k = k.view(k.shape[0], heads, head_dim)
|
||||
v = v.view(v.shape[0], heads, head_dim)
|
||||
|
||||
q = torch.nested.nested_tensor_from_jagged(q, offsets=cu_seqlens_q.long())
|
||||
k = torch.nested.nested_tensor_from_jagged(k, offsets=cu_seqlens_k.long())
|
||||
v = torch.nested.nested_tensor_from_jagged(v, offsets=cu_seqlens_k.long())
|
||||
|
||||
q = q.transpose(1, 2)
|
||||
k = k.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
|
||||
|
||||
out = out.transpose(1, 2)
|
||||
if not skip_output_reshape:
|
||||
return out.values().reshape(-1, heads * (q.shape[-1]))
|
||||
return out.values()
|
||||
|
||||
|
||||
def _validate_split_cu_seqlens(name, cu_seqlens, token_count):
|
||||
if cu_seqlens.dtype not in (torch.int32, torch.int64):
|
||||
raise ValueError(f"{name} must use an integer dtype")
|
||||
if cu_seqlens.ndim != 1 or cu_seqlens.numel() < 2:
|
||||
raise ValueError(f"{name} must be a 1D tensor with at least two offsets")
|
||||
if cu_seqlens[0].item() != 0:
|
||||
raise ValueError(f"{name} must start at 0")
|
||||
if (cu_seqlens[1:] <= cu_seqlens[:-1]).any().item():
|
||||
raise ValueError(f"{name} must be strictly increasing")
|
||||
if cu_seqlens[-1].item() != token_count:
|
||||
raise ValueError(f"{name} does not match token count")
|
||||
|
||||
|
||||
def _split_indices(cu_seqlens):
|
||||
return cu_seqlens[1:-1].to(device="cpu", dtype=torch.long)
|
||||
|
||||
|
||||
def var_attention_pytorch_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, skip_reshape=False, skip_output_reshape=False):
|
||||
q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)
|
||||
|
||||
_validate_split_cu_seqlens("cu_seqlens_q", cu_seqlens_q, q.shape[0])
|
||||
_validate_split_cu_seqlens("cu_seqlens_k", cu_seqlens_k, k.shape[0])
|
||||
if cu_seqlens_k[-1].item() != v.shape[0]:
|
||||
raise ValueError("cu_seqlens_k does not match v token count")
|
||||
|
||||
q_split_indices = _split_indices(cu_seqlens_q)
|
||||
k_split_indices = _split_indices(cu_seqlens_k)
|
||||
q_splits = torch.tensor_split(q, q_split_indices, dim=0)
|
||||
k_splits = torch.tensor_split(k, k_split_indices, dim=0)
|
||||
v_splits = torch.tensor_split(v, k_split_indices, dim=0)
|
||||
if len(q_splits) != len(k_splits) or len(q_splits) != len(v_splits):
|
||||
raise ValueError("cu_seqlens_q and cu_seqlens_k must describe the same sequence count")
|
||||
|
||||
out = []
|
||||
for q_i, k_i, v_i in zip(q_splits, k_splits, v_splits):
|
||||
q_i = q_i.permute(1, 0, 2).unsqueeze(0)
|
||||
k_i = k_i.permute(1, 0, 2).unsqueeze(0)
|
||||
v_i = v_i.permute(1, 0, 2).unsqueeze(0)
|
||||
out_i = comfy.ops.scaled_dot_product_attention(q_i, k_i, v_i, attn_mask=None, dropout_p=0.0, is_causal=False)
|
||||
out.append(out_i.squeeze(0).permute(1, 0, 2))
|
||||
|
||||
out = torch.cat(out, dim=0)
|
||||
return _var_attention_output(out, heads, head_dim, skip_output_reshape)
|
||||
|
||||
@torch._dynamo.disable
|
||||
def var_attention_sage(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
if not SAGE_ATTENTION_VARLEN_IS_AVAILABLE:
|
||||
return var_attention_pytorch(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
heads,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
skip_reshape=skip_reshape,
|
||||
skip_output_reshape=skip_output_reshape,
|
||||
)
|
||||
q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)
|
||||
out_dtype = q.dtype
|
||||
if not (q.dtype == k.dtype == v.dtype):
|
||||
k = k.to(q.dtype)
|
||||
v = v.to(q.dtype)
|
||||
fallback_q, fallback_k, fallback_v = q, k, v
|
||||
if q.dtype not in (torch.float16, torch.bfloat16):
|
||||
q = q.to(torch.bfloat16)
|
||||
k = k.to(torch.bfloat16)
|
||||
v = v.to(torch.bfloat16)
|
||||
sm_scale = kwargs.get("softmax_scale")
|
||||
if sm_scale is None:
|
||||
sm_scale = 1.0 / math.sqrt(head_dim)
|
||||
try:
|
||||
out = sageattn_varlen(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
cu_seqlens_q.int(),
|
||||
cu_seqlens_k.int(),
|
||||
_var_attention_max_seqlen(cu_seqlens_q),
|
||||
_var_attention_max_seqlen(cu_seqlens_k),
|
||||
kwargs.get("causal", False),
|
||||
sm_scale,
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error("Error running sage var-len attention: %s, using pytorch var-len attention instead.", e)
|
||||
out = var_attention_pytorch(
|
||||
fallback_q,
|
||||
fallback_k,
|
||||
fallback_v,
|
||||
heads,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
skip_reshape=True,
|
||||
skip_output_reshape=skip_output_reshape,
|
||||
)
|
||||
if out.dtype != out_dtype:
|
||||
out = out.to(out_dtype)
|
||||
return out
|
||||
if out.dtype != out_dtype:
|
||||
out = out.to(out_dtype)
|
||||
return _var_attention_output(out, heads, head_dim, skip_output_reshape)
|
||||
|
||||
|
||||
@torch._dynamo.disable
|
||||
def var_attention_sage3(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
if not SAGE_ATTENTION3_IS_AVAILABLE:
|
||||
if SAGE_ATTENTION_VARLEN_IS_AVAILABLE:
|
||||
return var_attention_sage(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
heads,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
skip_reshape=skip_reshape,
|
||||
skip_output_reshape=skip_output_reshape,
|
||||
**kwargs,
|
||||
)
|
||||
return var_attention_pytorch(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
heads,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
skip_reshape=skip_reshape,
|
||||
skip_output_reshape=skip_output_reshape,
|
||||
)
|
||||
q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)
|
||||
seq_lens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
|
||||
seq_lens_k = cu_seqlens_k[1:] - cu_seqlens_k[:-1]
|
||||
uniform_q = bool((seq_lens_q == seq_lens_q[0]).all().item())
|
||||
uniform_k = bool((seq_lens_k == seq_lens_k[0]).all().item())
|
||||
if not (uniform_q and uniform_k and seq_lens_q[0] == seq_lens_k[0]):
|
||||
if SAGE_ATTENTION_VARLEN_IS_AVAILABLE:
|
||||
return var_attention_sage(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
heads,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
skip_reshape=True,
|
||||
skip_output_reshape=skip_output_reshape,
|
||||
**kwargs,
|
||||
)
|
||||
return var_attention_pytorch(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
heads,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
skip_reshape=True,
|
||||
skip_output_reshape=skip_output_reshape,
|
||||
)
|
||||
out_dtype = q.dtype
|
||||
if not (q.dtype == k.dtype == v.dtype):
|
||||
k = k.to(q.dtype)
|
||||
v = v.to(q.dtype)
|
||||
fallback_q, fallback_k, fallback_v = q, k, v
|
||||
if q.dtype not in (torch.float16, torch.bfloat16):
|
||||
q = q.to(torch.bfloat16)
|
||||
k = k.to(torch.bfloat16)
|
||||
v = v.to(torch.bfloat16)
|
||||
batch_size = len(cu_seqlens_q) - 1
|
||||
seq_len = int(seq_lens_q[0].item())
|
||||
q = q.view(batch_size, seq_len, heads, head_dim).transpose(1, 2)
|
||||
k = k.view(batch_size, seq_len, heads, head_dim).transpose(1, 2)
|
||||
v = v.view(batch_size, seq_len, heads, head_dim).transpose(1, 2)
|
||||
try:
|
||||
out = sageattn3_blackwell(q, k, v, is_causal=kwargs.get("causal", False))
|
||||
except Exception as e:
|
||||
logging.error("Error running SageAttention3 var-len attention: %s, using fallback var-len attention instead.", e)
|
||||
if SAGE_ATTENTION_VARLEN_IS_AVAILABLE:
|
||||
return var_attention_sage(
|
||||
fallback_q,
|
||||
fallback_k,
|
||||
fallback_v,
|
||||
heads,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
skip_reshape=True,
|
||||
skip_output_reshape=skip_output_reshape,
|
||||
**kwargs,
|
||||
)
|
||||
return var_attention_pytorch(
|
||||
fallback_q,
|
||||
fallback_k,
|
||||
fallback_v,
|
||||
heads,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
skip_reshape=True,
|
||||
skip_output_reshape=skip_output_reshape,
|
||||
)
|
||||
out = out.transpose(1, 2).reshape(-1, heads, head_dim).contiguous()
|
||||
if out.dtype != out_dtype:
|
||||
out = out.to(out_dtype)
|
||||
return _var_attention_output(out, heads, head_dim, skip_output_reshape)
|
||||
|
||||
|
||||
@torch._dynamo.disable
|
||||
def var_attention_flash(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
if not FLASH_ATTENTION_VARLEN_IS_AVAILABLE:
|
||||
return var_attention_pytorch(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
heads,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
skip_reshape=skip_reshape,
|
||||
skip_output_reshape=skip_output_reshape,
|
||||
)
|
||||
q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)
|
||||
max_seqlen_q = _var_attention_max_seqlen(cu_seqlens_q)
|
||||
max_seqlen_k = _var_attention_max_seqlen(cu_seqlens_k)
|
||||
try:
|
||||
out = flash_attn_varlen_func(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
cu_seqlens_q=cu_seqlens_q.int(),
|
||||
cu_seqlens_k=cu_seqlens_k.int(),
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
max_seqlen_k=max_seqlen_k,
|
||||
dropout_p=kwargs.get("dropout_p", 0.0),
|
||||
causal=kwargs.get("causal", False),
|
||||
deterministic=torch.are_deterministic_algorithms_enabled(),
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error("Error running Flash Attention var-len attention: %s, using pytorch var-len attention instead.", e)
|
||||
return var_attention_pytorch(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
heads,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
skip_reshape=True,
|
||||
skip_output_reshape=skip_output_reshape,
|
||||
)
|
||||
return _var_attention_output(out, heads, head_dim, skip_output_reshape)
|
||||
|
||||
|
||||
@torch._dynamo.disable
|
||||
def var_attention_flash3(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
if not FLASH_ATTENTION3_IS_AVAILABLE:
|
||||
return var_attention_pytorch(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
heads,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
skip_reshape=skip_reshape,
|
||||
skip_output_reshape=skip_output_reshape,
|
||||
)
|
||||
q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)
|
||||
max_seqlen_q = _var_attention_max_seqlen(cu_seqlens_q)
|
||||
max_seqlen_k = _var_attention_max_seqlen(cu_seqlens_k)
|
||||
try:
|
||||
out = flash_attn3_varlen_func(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
cu_seqlens_q=cu_seqlens_q.int(),
|
||||
cu_seqlens_k=cu_seqlens_k.int(),
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
max_seqlen_k=max_seqlen_k,
|
||||
seqused_q=None,
|
||||
seqused_k=None,
|
||||
softmax_scale=kwargs.get("softmax_scale"),
|
||||
causal=kwargs.get("causal", False),
|
||||
deterministic=torch.are_deterministic_algorithms_enabled(),
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error("Error running Flash Attention 3 var-len attention: %s, using pytorch var-len attention instead.", e)
|
||||
return var_attention_pytorch(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
heads,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
skip_reshape=True,
|
||||
skip_output_reshape=skip_output_reshape,
|
||||
)
|
||||
if isinstance(out, tuple):
|
||||
out = out[0]
|
||||
return _var_attention_output(out, heads, head_dim, skip_output_reshape)
|
||||
|
||||
|
||||
@torch._dynamo.disable
|
||||
def var_attention_sub_quad(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
return var_attention_pytorch(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
heads,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
skip_reshape=skip_reshape,
|
||||
skip_output_reshape=skip_output_reshape,
|
||||
)
|
||||
|
||||
|
||||
@torch._dynamo.disable
|
||||
def var_attention_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
return var_attention_pytorch_split(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
heads,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
skip_reshape=skip_reshape,
|
||||
skip_output_reshape=skip_output_reshape,
|
||||
)
|
||||
|
||||
|
||||
optimized_var_attention = var_attention_pytorch
|
||||
optimized_attention = attention_basic
|
||||
|
||||
if model_management.sage_attention_enabled():
|
||||
logging.info("Using sage attention")
|
||||
optimized_attention = attention_sage
|
||||
if SAGE_ATTENTION3_IS_AVAILABLE and _use_blackwell_attention():
|
||||
logging.info("Using SageAttention3 for variable-length attention")
|
||||
optimized_var_attention = var_attention_sage3
|
||||
elif SAGE_ATTENTION_VARLEN_IS_AVAILABLE:
|
||||
logging.info("Using SageAttention for variable-length attention")
|
||||
optimized_var_attention = var_attention_sage
|
||||
else:
|
||||
logging.info("Using pytorch attention for variable-length attention")
|
||||
optimized_var_attention = var_attention_pytorch
|
||||
elif model_management.flash_attention_enabled():
|
||||
logging.info("Using Flash Attention")
|
||||
optimized_attention = attention_flash
|
||||
if FLASH_ATTENTION_VARLEN_IS_AVAILABLE and model_management.get_torch_device().type == "cuda":
|
||||
logging.info("Using Flash Attention 2 for variable-length attention")
|
||||
optimized_var_attention = var_attention_flash
|
||||
else:
|
||||
logging.info("Using pytorch attention for variable-length attention")
|
||||
optimized_var_attention = var_attention_pytorch
|
||||
elif model_management.xformers_enabled():
|
||||
logging.info("Using xformers attention")
|
||||
optimized_attention = attention_xformers
|
||||
elif model_management.pytorch_attention_enabled():
|
||||
logging.info("Using pytorch attention")
|
||||
optimized_attention = attention_pytorch
|
||||
optimized_var_attention = var_attention_pytorch
|
||||
else:
|
||||
if args.use_split_cross_attention:
|
||||
logging.info("Using split optimization for attention")
|
||||
optimized_attention = attention_split
|
||||
optimized_var_attention = var_attention_split
|
||||
else:
|
||||
logging.info("Using sub quadratic optimization for attention, if you have memory or speed issues try using: --use-split-cross-attention")
|
||||
optimized_attention = attention_sub_quad
|
||||
optimized_var_attention = var_attention_sub_quad
|
||||
|
||||
optimized_attention_masked = optimized_attention
|
||||
|
||||
@ -764,15 +1193,25 @@ optimized_attention_masked = optimized_attention
|
||||
# register core-supported attention functions
|
||||
if SAGE_ATTENTION_IS_AVAILABLE:
|
||||
register_attention_function("sage", attention_sage)
|
||||
if SAGE_ATTENTION_VARLEN_IS_AVAILABLE:
|
||||
register_attention_function("var_attention_sage", var_attention_sage)
|
||||
if SAGE_ATTENTION3_IS_AVAILABLE:
|
||||
register_attention_function("sage3", attention3_sage)
|
||||
register_attention_function("var_attention_sage3", var_attention_sage3)
|
||||
if FLASH_ATTENTION_IS_AVAILABLE:
|
||||
register_attention_function("flash", attention_flash)
|
||||
if FLASH_ATTENTION_VARLEN_IS_AVAILABLE:
|
||||
register_attention_function("var_attention_flash", var_attention_flash)
|
||||
if FLASH_ATTENTION3_IS_AVAILABLE:
|
||||
register_attention_function("var_attention_flash3", var_attention_flash3)
|
||||
if model_management.xformers_enabled():
|
||||
register_attention_function("xformers", attention_xformers)
|
||||
register_attention_function("pytorch", attention_pytorch)
|
||||
register_attention_function("var_attention_pytorch", var_attention_pytorch)
|
||||
register_attention_function("sub_quad", attention_sub_quad)
|
||||
register_attention_function("var_attention_sub_quad", var_attention_sub_quad)
|
||||
register_attention_function("split", attention_split)
|
||||
register_attention_function("var_attention_split", var_attention_split)
|
||||
|
||||
|
||||
def optimized_attention_for_device(device, mask=False, small_input=False):
|
||||
@ -1209,5 +1648,3 @@ class SpatialVideoTransformer(SpatialTransformer):
|
||||
x = self.proj_out(x)
|
||||
out = x + x_in
|
||||
return out
|
||||
|
||||
|
||||
|
||||
@ -13,6 +13,7 @@ if model_management.xformers_enabled_vae():
|
||||
import xformers
|
||||
import xformers.ops
|
||||
|
||||
|
||||
def torch_cat_if_needed(xl, dim):
|
||||
xl = [x for x in xl if x is not None and x.shape[dim] > 0]
|
||||
if len(xl) > 1:
|
||||
@ -22,7 +23,8 @@ def torch_cat_if_needed(xl, dim):
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_timestep_embedding(timesteps, embedding_dim):
|
||||
|
||||
def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1):
|
||||
"""
|
||||
This matches the implementation in Denoising Diffusion Probabilistic Models:
|
||||
From Fairseq.
|
||||
@ -33,11 +35,13 @@ def get_timestep_embedding(timesteps, embedding_dim):
|
||||
assert len(timesteps.shape) == 1
|
||||
|
||||
half_dim = embedding_dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = math.log(10000) / (half_dim - downscale_freq_shift)
|
||||
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
|
||||
emb = emb.to(device=timesteps.device)
|
||||
emb = timesteps.float()[:, None] * emb[None, :]
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
||||
if flip_sin_to_cos:
|
||||
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
|
||||
if embedding_dim % 2 == 1: # zero pad
|
||||
emb = torch.nn.functional.pad(emb, (0,1,0,0))
|
||||
return emb
|
||||
|
||||
0
comfy/samplers.py
Executable file → Normal file
0
comfy/samplers.py
Executable file → Normal file
Loading…
Reference in New Issue
Block a user