mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-29 19:29:24 +08:00
Replace SeedVR2 custom varlen attention backends and fix 7B RoPE
This commit is contained in:
parent
22078c799b
commit
529b9232f0
@ -32,14 +32,6 @@ except ImportError as e:
|
|||||||
raise e
|
raise e
|
||||||
exit(-1)
|
exit(-1)
|
||||||
|
|
||||||
SAGE_ATTENTION_VARLEN_IS_AVAILABLE = False
|
|
||||||
try:
|
|
||||||
from sageattention import sageattn_varlen
|
|
||||||
SAGE_ATTENTION_VARLEN_IS_AVAILABLE = True
|
|
||||||
except ImportError:
|
|
||||||
if model_management.sage_attention_enabled():
|
|
||||||
logging.warning("SageAttention variable-length attention is unavailable, using pytorch var-len attention instead.")
|
|
||||||
|
|
||||||
SAGE_ATTENTION3_IS_AVAILABLE = False
|
SAGE_ATTENTION3_IS_AVAILABLE = False
|
||||||
try:
|
try:
|
||||||
from sageattn3 import sageattn3_blackwell
|
from sageattn3 import sageattn3_blackwell
|
||||||
@ -48,7 +40,6 @@ except ImportError:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
FLASH_ATTENTION_IS_AVAILABLE = False
|
FLASH_ATTENTION_IS_AVAILABLE = False
|
||||||
FLASH_ATTENTION_VARLEN_IS_AVAILABLE = False
|
|
||||||
try:
|
try:
|
||||||
from flash_attn import flash_attn_func
|
from flash_attn import flash_attn_func
|
||||||
FLASH_ATTENTION_IS_AVAILABLE = True
|
FLASH_ATTENTION_IS_AVAILABLE = True
|
||||||
@ -57,20 +48,6 @@ except ImportError:
|
|||||||
logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
|
logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
|
||||||
exit(-1)
|
exit(-1)
|
||||||
|
|
||||||
try:
|
|
||||||
from flash_attn import flash_attn_varlen_func
|
|
||||||
FLASH_ATTENTION_VARLEN_IS_AVAILABLE = True
|
|
||||||
except ImportError:
|
|
||||||
if model_management.flash_attention_enabled() and FLASH_ATTENTION_IS_AVAILABLE:
|
|
||||||
logging.warning("Flash Attention variable-length attention is unavailable, using pytorch var-len attention instead.")
|
|
||||||
|
|
||||||
FLASH_ATTENTION3_IS_AVAILABLE = False
|
|
||||||
try:
|
|
||||||
from flash_attn_interface import flash_attn_varlen_func as flash_attn3_varlen_func
|
|
||||||
FLASH_ATTENTION3_IS_AVAILABLE = True
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
REGISTERED_ATTENTION_FUNCTIONS = {}
|
REGISTERED_ATTENTION_FUNCTIONS = {}
|
||||||
def register_attention_function(name: str, func: Callable):
|
def register_attention_function(name: str, func: Callable):
|
||||||
# avoid replacing existing functions
|
# avoid replacing existing functions
|
||||||
@ -758,18 +735,6 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
|||||||
)
|
)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
_VAR_ATTENTION_NESTED_API_NAME = "nested_tensor_from_jagged"
|
|
||||||
_VAR_ATTENTION_GUARD_MESSAGE = (
|
|
||||||
"SeedVR2 var_attention_pytorch: torch.nested.nested_tensor_from_jagged "
|
|
||||||
"is required by this attention path; the installed PyTorch build "
|
|
||||||
"does not provide it"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _var_attention_max_seqlen(cu_seqlens):
|
|
||||||
return int((cu_seqlens[1:] - cu_seqlens[:-1]).max().item())
|
|
||||||
|
|
||||||
|
|
||||||
def _var_attention_qkv(q, k, v, heads, skip_reshape):
|
def _var_attention_qkv(q, k, v, heads, skip_reshape):
|
||||||
if skip_reshape:
|
if skip_reshape:
|
||||||
return q, k, v, q.shape[-1]
|
return q, k, v, q.shape[-1]
|
||||||
@ -797,34 +762,6 @@ def _use_blackwell_attention():
|
|||||||
return (major, minor) >= (12, 0)
|
return (major, minor) >= (12, 0)
|
||||||
|
|
||||||
|
|
||||||
def var_attention_pytorch(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, skip_reshape=False, skip_output_reshape=False):
|
|
||||||
_nested = getattr(torch, "nested", None)
|
|
||||||
if _nested is None or not hasattr(_nested, _VAR_ATTENTION_NESTED_API_NAME):
|
|
||||||
raise RuntimeError(_VAR_ATTENTION_GUARD_MESSAGE)
|
|
||||||
|
|
||||||
if not skip_reshape:
|
|
||||||
# assumes 2D q, k,v [total_tokens, embed_dim]
|
|
||||||
total_tokens, embed_dim = q.shape
|
|
||||||
head_dim = embed_dim // heads
|
|
||||||
q = q.view(total_tokens, heads, head_dim)
|
|
||||||
k = k.view(k.shape[0], heads, head_dim)
|
|
||||||
v = v.view(v.shape[0], heads, head_dim)
|
|
||||||
|
|
||||||
q = torch.nested.nested_tensor_from_jagged(q, offsets=cu_seqlens_q.long())
|
|
||||||
k = torch.nested.nested_tensor_from_jagged(k, offsets=cu_seqlens_k.long())
|
|
||||||
v = torch.nested.nested_tensor_from_jagged(v, offsets=cu_seqlens_k.long())
|
|
||||||
|
|
||||||
q = q.transpose(1, 2)
|
|
||||||
k = k.transpose(1, 2)
|
|
||||||
v = v.transpose(1, 2)
|
|
||||||
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
|
|
||||||
|
|
||||||
out = out.transpose(1, 2)
|
|
||||||
if not skip_output_reshape:
|
|
||||||
return out.values().reshape(-1, heads * (q.shape[-1]))
|
|
||||||
return out.values()
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_split_cu_seqlens(name, cu_seqlens, token_count):
|
def _validate_split_cu_seqlens(name, cu_seqlens, token_count):
|
||||||
if cu_seqlens.dtype not in (torch.int32, torch.int64):
|
if cu_seqlens.dtype not in (torch.int32, torch.int64):
|
||||||
raise ValueError(f"{name} must use an integer dtype")
|
raise ValueError(f"{name} must use an integer dtype")
|
||||||
@ -842,7 +779,7 @@ def _split_indices(cu_seqlens):
|
|||||||
return cu_seqlens[1:-1].to(device="cpu", dtype=torch.long)
|
return cu_seqlens[1:-1].to(device="cpu", dtype=torch.long)
|
||||||
|
|
||||||
|
|
||||||
def var_attention_pytorch_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, skip_reshape=False, skip_output_reshape=False):
|
def var_attention_optimized_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||||
q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)
|
q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)
|
||||||
|
|
||||||
_validate_split_cu_seqlens("cu_seqlens_q", cu_seqlens_q, q.shape[0])
|
_validate_split_cu_seqlens("cu_seqlens_q", cu_seqlens_q, q.shape[0])
|
||||||
@ -863,329 +800,44 @@ def var_attention_pytorch_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, skip
|
|||||||
q_i = q_i.permute(1, 0, 2).unsqueeze(0)
|
q_i = q_i.permute(1, 0, 2).unsqueeze(0)
|
||||||
k_i = k_i.permute(1, 0, 2).unsqueeze(0)
|
k_i = k_i.permute(1, 0, 2).unsqueeze(0)
|
||||||
v_i = v_i.permute(1, 0, 2).unsqueeze(0)
|
v_i = v_i.permute(1, 0, 2).unsqueeze(0)
|
||||||
out_i = comfy.ops.scaled_dot_product_attention(q_i, k_i, v_i, attn_mask=None, dropout_p=0.0, is_causal=False)
|
out_dtype = q_i.dtype
|
||||||
|
if optimized_attention is attention_sage and q_i.dtype not in (torch.float16, torch.bfloat16):
|
||||||
|
q_i = q_i.to(torch.bfloat16)
|
||||||
|
k_i = k_i.to(torch.bfloat16)
|
||||||
|
v_i = v_i.to(torch.bfloat16)
|
||||||
|
out_i = optimized_attention(q_i, k_i, v_i, heads, skip_reshape=True, skip_output_reshape=True)
|
||||||
|
if out_i.dtype != out_dtype:
|
||||||
|
out_i = out_i.to(out_dtype)
|
||||||
out.append(out_i.squeeze(0).permute(1, 0, 2))
|
out.append(out_i.squeeze(0).permute(1, 0, 2))
|
||||||
|
|
||||||
out = torch.cat(out, dim=0)
|
out = torch.cat(out, dim=0)
|
||||||
return _var_attention_output(out, heads, head_dim, skip_output_reshape)
|
return _var_attention_output(out, heads, head_dim, skip_output_reshape)
|
||||||
|
|
||||||
@torch._dynamo.disable
|
|
||||||
def var_attention_sage(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
|
||||||
if not SAGE_ATTENTION_VARLEN_IS_AVAILABLE:
|
|
||||||
return var_attention_pytorch(
|
|
||||||
q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
heads,
|
|
||||||
cu_seqlens_q,
|
|
||||||
cu_seqlens_k,
|
|
||||||
skip_reshape=skip_reshape,
|
|
||||||
skip_output_reshape=skip_output_reshape,
|
|
||||||
)
|
|
||||||
q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)
|
|
||||||
out_dtype = q.dtype
|
|
||||||
if not (q.dtype == k.dtype == v.dtype):
|
|
||||||
k = k.to(q.dtype)
|
|
||||||
v = v.to(q.dtype)
|
|
||||||
fallback_q, fallback_k, fallback_v = q, k, v
|
|
||||||
if q.dtype not in (torch.float16, torch.bfloat16):
|
|
||||||
q = q.to(torch.bfloat16)
|
|
||||||
k = k.to(torch.bfloat16)
|
|
||||||
v = v.to(torch.bfloat16)
|
|
||||||
sm_scale = kwargs.get("softmax_scale")
|
|
||||||
if sm_scale is None:
|
|
||||||
sm_scale = 1.0 / math.sqrt(head_dim)
|
|
||||||
try:
|
|
||||||
out = sageattn_varlen(
|
|
||||||
q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
cu_seqlens_q.int(),
|
|
||||||
cu_seqlens_k.int(),
|
|
||||||
_var_attention_max_seqlen(cu_seqlens_q),
|
|
||||||
_var_attention_max_seqlen(cu_seqlens_k),
|
|
||||||
kwargs.get("causal", False),
|
|
||||||
sm_scale,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logging.error("Error running sage var-len attention: %s, using pytorch var-len attention instead.", e)
|
|
||||||
out = var_attention_pytorch(
|
|
||||||
fallback_q,
|
|
||||||
fallback_k,
|
|
||||||
fallback_v,
|
|
||||||
heads,
|
|
||||||
cu_seqlens_q,
|
|
||||||
cu_seqlens_k,
|
|
||||||
skip_reshape=True,
|
|
||||||
skip_output_reshape=skip_output_reshape,
|
|
||||||
)
|
|
||||||
if out.dtype != out_dtype:
|
|
||||||
out = out.to(out_dtype)
|
|
||||||
return out
|
|
||||||
if out.dtype != out_dtype:
|
|
||||||
out = out.to(out_dtype)
|
|
||||||
return _var_attention_output(out, heads, head_dim, skip_output_reshape)
|
|
||||||
|
|
||||||
|
optimized_var_attention = var_attention_optimized_split
|
||||||
@torch._dynamo.disable
|
|
||||||
def var_attention_sage3(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
|
||||||
if not SAGE_ATTENTION3_IS_AVAILABLE:
|
|
||||||
if SAGE_ATTENTION_VARLEN_IS_AVAILABLE:
|
|
||||||
return var_attention_sage(
|
|
||||||
q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
heads,
|
|
||||||
cu_seqlens_q,
|
|
||||||
cu_seqlens_k,
|
|
||||||
skip_reshape=skip_reshape,
|
|
||||||
skip_output_reshape=skip_output_reshape,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
return var_attention_pytorch(
|
|
||||||
q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
heads,
|
|
||||||
cu_seqlens_q,
|
|
||||||
cu_seqlens_k,
|
|
||||||
skip_reshape=skip_reshape,
|
|
||||||
skip_output_reshape=skip_output_reshape,
|
|
||||||
)
|
|
||||||
q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)
|
|
||||||
seq_lens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
|
|
||||||
seq_lens_k = cu_seqlens_k[1:] - cu_seqlens_k[:-1]
|
|
||||||
uniform_q = bool((seq_lens_q == seq_lens_q[0]).all().item())
|
|
||||||
uniform_k = bool((seq_lens_k == seq_lens_k[0]).all().item())
|
|
||||||
if not (uniform_q and uniform_k and seq_lens_q[0] == seq_lens_k[0]):
|
|
||||||
if SAGE_ATTENTION_VARLEN_IS_AVAILABLE:
|
|
||||||
return var_attention_sage(
|
|
||||||
q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
heads,
|
|
||||||
cu_seqlens_q,
|
|
||||||
cu_seqlens_k,
|
|
||||||
skip_reshape=True,
|
|
||||||
skip_output_reshape=skip_output_reshape,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
return var_attention_pytorch(
|
|
||||||
q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
heads,
|
|
||||||
cu_seqlens_q,
|
|
||||||
cu_seqlens_k,
|
|
||||||
skip_reshape=True,
|
|
||||||
skip_output_reshape=skip_output_reshape,
|
|
||||||
)
|
|
||||||
out_dtype = q.dtype
|
|
||||||
if not (q.dtype == k.dtype == v.dtype):
|
|
||||||
k = k.to(q.dtype)
|
|
||||||
v = v.to(q.dtype)
|
|
||||||
fallback_q, fallback_k, fallback_v = q, k, v
|
|
||||||
if q.dtype not in (torch.float16, torch.bfloat16):
|
|
||||||
q = q.to(torch.bfloat16)
|
|
||||||
k = k.to(torch.bfloat16)
|
|
||||||
v = v.to(torch.bfloat16)
|
|
||||||
batch_size = len(cu_seqlens_q) - 1
|
|
||||||
seq_len = int(seq_lens_q[0].item())
|
|
||||||
q = q.view(batch_size, seq_len, heads, head_dim).transpose(1, 2)
|
|
||||||
k = k.view(batch_size, seq_len, heads, head_dim).transpose(1, 2)
|
|
||||||
v = v.view(batch_size, seq_len, heads, head_dim).transpose(1, 2)
|
|
||||||
try:
|
|
||||||
out = sageattn3_blackwell(q, k, v, is_causal=kwargs.get("causal", False))
|
|
||||||
except Exception as e:
|
|
||||||
logging.error("Error running SageAttention3 var-len attention: %s, using fallback var-len attention instead.", e)
|
|
||||||
if SAGE_ATTENTION_VARLEN_IS_AVAILABLE:
|
|
||||||
return var_attention_sage(
|
|
||||||
fallback_q,
|
|
||||||
fallback_k,
|
|
||||||
fallback_v,
|
|
||||||
heads,
|
|
||||||
cu_seqlens_q,
|
|
||||||
cu_seqlens_k,
|
|
||||||
skip_reshape=True,
|
|
||||||
skip_output_reshape=skip_output_reshape,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
return var_attention_pytorch(
|
|
||||||
fallback_q,
|
|
||||||
fallback_k,
|
|
||||||
fallback_v,
|
|
||||||
heads,
|
|
||||||
cu_seqlens_q,
|
|
||||||
cu_seqlens_k,
|
|
||||||
skip_reshape=True,
|
|
||||||
skip_output_reshape=skip_output_reshape,
|
|
||||||
)
|
|
||||||
out = out.transpose(1, 2).reshape(-1, heads, head_dim).contiguous()
|
|
||||||
if out.dtype != out_dtype:
|
|
||||||
out = out.to(out_dtype)
|
|
||||||
return _var_attention_output(out, heads, head_dim, skip_output_reshape)
|
|
||||||
|
|
||||||
|
|
||||||
@torch._dynamo.disable
|
|
||||||
def var_attention_flash(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
|
||||||
if not FLASH_ATTENTION_VARLEN_IS_AVAILABLE:
|
|
||||||
return var_attention_pytorch(
|
|
||||||
q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
heads,
|
|
||||||
cu_seqlens_q,
|
|
||||||
cu_seqlens_k,
|
|
||||||
skip_reshape=skip_reshape,
|
|
||||||
skip_output_reshape=skip_output_reshape,
|
|
||||||
)
|
|
||||||
q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)
|
|
||||||
max_seqlen_q = _var_attention_max_seqlen(cu_seqlens_q)
|
|
||||||
max_seqlen_k = _var_attention_max_seqlen(cu_seqlens_k)
|
|
||||||
try:
|
|
||||||
out = flash_attn_varlen_func(
|
|
||||||
q=q,
|
|
||||||
k=k,
|
|
||||||
v=v,
|
|
||||||
cu_seqlens_q=cu_seqlens_q.int(),
|
|
||||||
cu_seqlens_k=cu_seqlens_k.int(),
|
|
||||||
max_seqlen_q=max_seqlen_q,
|
|
||||||
max_seqlen_k=max_seqlen_k,
|
|
||||||
dropout_p=kwargs.get("dropout_p", 0.0),
|
|
||||||
causal=kwargs.get("causal", False),
|
|
||||||
deterministic=torch.are_deterministic_algorithms_enabled(),
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logging.error("Error running Flash Attention var-len attention: %s, using pytorch var-len attention instead.", e)
|
|
||||||
return var_attention_pytorch(
|
|
||||||
q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
heads,
|
|
||||||
cu_seqlens_q,
|
|
||||||
cu_seqlens_k,
|
|
||||||
skip_reshape=True,
|
|
||||||
skip_output_reshape=skip_output_reshape,
|
|
||||||
)
|
|
||||||
return _var_attention_output(out, heads, head_dim, skip_output_reshape)
|
|
||||||
|
|
||||||
|
|
||||||
@torch._dynamo.disable
|
|
||||||
def var_attention_flash3(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
|
||||||
if not FLASH_ATTENTION3_IS_AVAILABLE:
|
|
||||||
return var_attention_pytorch(
|
|
||||||
q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
heads,
|
|
||||||
cu_seqlens_q,
|
|
||||||
cu_seqlens_k,
|
|
||||||
skip_reshape=skip_reshape,
|
|
||||||
skip_output_reshape=skip_output_reshape,
|
|
||||||
)
|
|
||||||
q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)
|
|
||||||
max_seqlen_q = _var_attention_max_seqlen(cu_seqlens_q)
|
|
||||||
max_seqlen_k = _var_attention_max_seqlen(cu_seqlens_k)
|
|
||||||
try:
|
|
||||||
out = flash_attn3_varlen_func(
|
|
||||||
q=q,
|
|
||||||
k=k,
|
|
||||||
v=v,
|
|
||||||
cu_seqlens_q=cu_seqlens_q.int(),
|
|
||||||
cu_seqlens_k=cu_seqlens_k.int(),
|
|
||||||
max_seqlen_q=max_seqlen_q,
|
|
||||||
max_seqlen_k=max_seqlen_k,
|
|
||||||
seqused_q=None,
|
|
||||||
seqused_k=None,
|
|
||||||
softmax_scale=kwargs.get("softmax_scale"),
|
|
||||||
causal=kwargs.get("causal", False),
|
|
||||||
deterministic=torch.are_deterministic_algorithms_enabled(),
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logging.error("Error running Flash Attention 3 var-len attention: %s, using pytorch var-len attention instead.", e)
|
|
||||||
return var_attention_pytorch(
|
|
||||||
q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
heads,
|
|
||||||
cu_seqlens_q,
|
|
||||||
cu_seqlens_k,
|
|
||||||
skip_reshape=True,
|
|
||||||
skip_output_reshape=skip_output_reshape,
|
|
||||||
)
|
|
||||||
if isinstance(out, tuple):
|
|
||||||
out = out[0]
|
|
||||||
return _var_attention_output(out, heads, head_dim, skip_output_reshape)
|
|
||||||
|
|
||||||
|
|
||||||
@torch._dynamo.disable
|
|
||||||
def var_attention_sub_quad(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
|
||||||
return var_attention_pytorch(
|
|
||||||
q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
heads,
|
|
||||||
cu_seqlens_q,
|
|
||||||
cu_seqlens_k,
|
|
||||||
skip_reshape=skip_reshape,
|
|
||||||
skip_output_reshape=skip_output_reshape,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@torch._dynamo.disable
|
|
||||||
def var_attention_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
|
||||||
return var_attention_pytorch_split(
|
|
||||||
q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
heads,
|
|
||||||
cu_seqlens_q,
|
|
||||||
cu_seqlens_k,
|
|
||||||
skip_reshape=skip_reshape,
|
|
||||||
skip_output_reshape=skip_output_reshape,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
optimized_var_attention = var_attention_pytorch
|
|
||||||
optimized_attention = attention_basic
|
optimized_attention = attention_basic
|
||||||
|
|
||||||
if model_management.sage_attention_enabled():
|
if model_management.sage_attention_enabled():
|
||||||
logging.info("Using sage attention")
|
logging.info("Using sage attention")
|
||||||
optimized_attention = attention_sage
|
optimized_attention = attention_sage
|
||||||
if SAGE_ATTENTION3_IS_AVAILABLE and _use_blackwell_attention():
|
|
||||||
logging.info("Using SageAttention3 for variable-length attention")
|
|
||||||
optimized_var_attention = var_attention_sage3
|
|
||||||
elif SAGE_ATTENTION_VARLEN_IS_AVAILABLE:
|
|
||||||
logging.info("Using SageAttention for variable-length attention")
|
|
||||||
optimized_var_attention = var_attention_sage
|
|
||||||
else:
|
|
||||||
logging.info("Using pytorch attention for variable-length attention")
|
|
||||||
optimized_var_attention = var_attention_pytorch
|
|
||||||
elif model_management.flash_attention_enabled():
|
elif model_management.flash_attention_enabled():
|
||||||
logging.info("Using Flash Attention")
|
logging.info("Using Flash Attention")
|
||||||
optimized_attention = attention_flash
|
optimized_attention = attention_flash
|
||||||
if FLASH_ATTENTION_VARLEN_IS_AVAILABLE and model_management.get_torch_device().type == "cuda":
|
|
||||||
logging.info("Using Flash Attention 2 for variable-length attention")
|
|
||||||
optimized_var_attention = var_attention_flash
|
|
||||||
else:
|
|
||||||
logging.info("Using pytorch attention for variable-length attention")
|
|
||||||
optimized_var_attention = var_attention_pytorch
|
|
||||||
elif model_management.xformers_enabled():
|
elif model_management.xformers_enabled():
|
||||||
logging.info("Using xformers attention")
|
logging.info("Using xformers attention")
|
||||||
optimized_attention = attention_xformers
|
optimized_attention = attention_xformers
|
||||||
elif model_management.pytorch_attention_enabled():
|
elif model_management.pytorch_attention_enabled():
|
||||||
logging.info("Using pytorch attention")
|
logging.info("Using pytorch attention")
|
||||||
optimized_attention = attention_pytorch
|
optimized_attention = attention_pytorch
|
||||||
optimized_var_attention = var_attention_pytorch
|
|
||||||
else:
|
else:
|
||||||
if args.use_split_cross_attention:
|
if args.use_split_cross_attention:
|
||||||
logging.info("Using split optimization for attention")
|
logging.info("Using split optimization for attention")
|
||||||
optimized_attention = attention_split
|
optimized_attention = attention_split
|
||||||
optimized_var_attention = var_attention_split
|
|
||||||
else:
|
else:
|
||||||
logging.info("Using sub quadratic optimization for attention, if you have memory or speed issues try using: --use-split-cross-attention")
|
logging.info("Using sub quadratic optimization for attention, if you have memory or speed issues try using: --use-split-cross-attention")
|
||||||
optimized_attention = attention_sub_quad
|
optimized_attention = attention_sub_quad
|
||||||
optimized_var_attention = var_attention_sub_quad
|
|
||||||
|
logging.info("Using optimized_attention split-loop for variable-length attention")
|
||||||
|
|
||||||
optimized_attention_masked = optimized_attention
|
optimized_attention_masked = optimized_attention
|
||||||
|
|
||||||
@ -1193,25 +845,16 @@ optimized_attention_masked = optimized_attention
|
|||||||
# register core-supported attention functions
|
# register core-supported attention functions
|
||||||
if SAGE_ATTENTION_IS_AVAILABLE:
|
if SAGE_ATTENTION_IS_AVAILABLE:
|
||||||
register_attention_function("sage", attention_sage)
|
register_attention_function("sage", attention_sage)
|
||||||
if SAGE_ATTENTION_VARLEN_IS_AVAILABLE:
|
|
||||||
register_attention_function("var_attention_sage", var_attention_sage)
|
|
||||||
if SAGE_ATTENTION3_IS_AVAILABLE:
|
if SAGE_ATTENTION3_IS_AVAILABLE:
|
||||||
register_attention_function("sage3", attention3_sage)
|
register_attention_function("sage3", attention3_sage)
|
||||||
register_attention_function("var_attention_sage3", var_attention_sage3)
|
|
||||||
if FLASH_ATTENTION_IS_AVAILABLE:
|
if FLASH_ATTENTION_IS_AVAILABLE:
|
||||||
register_attention_function("flash", attention_flash)
|
register_attention_function("flash", attention_flash)
|
||||||
if FLASH_ATTENTION_VARLEN_IS_AVAILABLE:
|
|
||||||
register_attention_function("var_attention_flash", var_attention_flash)
|
|
||||||
if FLASH_ATTENTION3_IS_AVAILABLE:
|
|
||||||
register_attention_function("var_attention_flash3", var_attention_flash3)
|
|
||||||
if model_management.xformers_enabled():
|
if model_management.xformers_enabled():
|
||||||
register_attention_function("xformers", attention_xformers)
|
register_attention_function("xformers", attention_xformers)
|
||||||
register_attention_function("pytorch", attention_pytorch)
|
register_attention_function("pytorch", attention_pytorch)
|
||||||
register_attention_function("var_attention_pytorch", var_attention_pytorch)
|
|
||||||
register_attention_function("sub_quad", attention_sub_quad)
|
register_attention_function("sub_quad", attention_sub_quad)
|
||||||
register_attention_function("var_attention_sub_quad", var_attention_sub_quad)
|
|
||||||
register_attention_function("split", attention_split)
|
register_attention_function("split", attention_split)
|
||||||
register_attention_function("var_attention_split", var_attention_split)
|
register_attention_function("var_attention_optimized_split", var_attention_optimized_split)
|
||||||
|
|
||||||
|
|
||||||
def optimized_attention_for_device(device, mask=False, small_input=False):
|
def optimized_attention_for_device(device, mask=False, small_input=False):
|
||||||
|
|||||||
@ -40,7 +40,7 @@ SEEDVR2_WAVELET_SCALE_MULTIPLIER = 10 # per-frame byte multiplier, wavelet path
|
|||||||
SEEDVR2_ADAIN_SCALE_MULTIPLIER = 6 # per-frame byte multiplier, AdaIN path.
|
SEEDVR2_ADAIN_SCALE_MULTIPLIER = 6 # per-frame byte multiplier, AdaIN path.
|
||||||
|
|
||||||
# --------------------------------------------------------------------------------------
|
# --------------------------------------------------------------------------------------
|
||||||
# C. ByteDance config / source (BYTEDANCE - cite myseedvr2/SeedVR)
|
# C. ByteDance config / source (BYTEDANCE - cite ByteDance-Seed/SeedVR)
|
||||||
# --------------------------------------------------------------------------------------
|
# --------------------------------------------------------------------------------------
|
||||||
BYTEDANCE_VAE_SCALING_FACTOR = 0.9152 # configs_3b/main.yaml:57 (scaling_factor); latent denorm.
|
BYTEDANCE_VAE_SCALING_FACTOR = 0.9152 # configs_3b/main.yaml:57 (scaling_factor); latent denorm.
|
||||||
BYTEDANCE_VAE_SHIFTING_FACTOR = 0.0 # infer.py (shifting_factor default); latent denorm shift.
|
BYTEDANCE_VAE_SHIFTING_FACTOR = 0.0 # infer.py (shifting_factor default); latent denorm shift.
|
||||||
|
|||||||
@ -58,10 +58,11 @@ class CustomRMSNorm(nn.Module):
|
|||||||
|
|
||||||
dims = tuple(range(-len(self.normalized_shape), 0))
|
dims = tuple(range(-len(self.normalized_shape), 0))
|
||||||
|
|
||||||
variance = input.pow(2).mean(dim=dims, keepdim=True)
|
normalized = input.float()
|
||||||
|
variance = normalized.pow(2).mean(dim=dims, keepdim=True)
|
||||||
rms = torch.sqrt(variance + self.eps)
|
rms = torch.sqrt(variance + self.eps)
|
||||||
|
|
||||||
normalized = input / rms
|
normalized = normalized / rms
|
||||||
|
|
||||||
if self.elementwise_affine:
|
if self.elementwise_affine:
|
||||||
return normalized * self.weight.to(input.dtype)
|
return normalized * self.weight.to(input.dtype)
|
||||||
@ -472,8 +473,8 @@ class NaRotaryEmbedding3d(RotaryEmbedding3d):
|
|||||||
freqs = freqs.to(device=q.device)
|
freqs = freqs.to(device=q.device)
|
||||||
q = rearrange(q, "L h d -> h L d")
|
q = rearrange(q, "L h d -> h L d")
|
||||||
k = rearrange(k, "L h d -> h L d")
|
k = rearrange(k, "L h d -> h L d")
|
||||||
q = _apply_rope1_partial(q, freqs)
|
q = _apply_seedvr2_rotary_emb(freqs, q.float()).to(q.dtype)
|
||||||
k = _apply_rope1_partial(k, freqs)
|
k = _apply_seedvr2_rotary_emb(freqs, k.float()).to(k.dtype)
|
||||||
q = rearrange(q, "h L d -> L h d")
|
q = rearrange(q, "h L d -> L h d")
|
||||||
k = rearrange(k, "h L d -> L h d")
|
k = rearrange(k, "h L d -> L h d")
|
||||||
return q, k
|
return q, k
|
||||||
@ -483,11 +484,20 @@ class NaRotaryEmbedding3d(RotaryEmbedding3d):
|
|||||||
self,
|
self,
|
||||||
shape: torch.LongTensor,
|
shape: torch.LongTensor,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
# Primary provenance: ByteDance-Seed/SeedVR models/dit/rope.py builds
|
||||||
|
# 7B pixel RoPE with the interleaved-angle convention, not Comfy's
|
||||||
|
# Flux freqs_cis matrix.
|
||||||
|
plain_rope = RotaryEmbedding(
|
||||||
|
dim=self.rope.freqs.numel() * 2,
|
||||||
|
freqs_for="pixel",
|
||||||
|
max_freq=BYTEDANCE_ROPE_MAX_FREQ,
|
||||||
|
)
|
||||||
|
plain_rope = plain_rope.to(self.rope.dummy.device)
|
||||||
freq_list = []
|
freq_list = []
|
||||||
for f, h, w in shape.tolist():
|
for f, h, w in shape.tolist():
|
||||||
freqs = self.get_axial_freqs(f, h, w)
|
freqs = plain_rope.get_axial_freqs(f, h, w)
|
||||||
freq_list.append(freqs.view(-1, freqs.size(-1)))
|
freq_list.append(freqs.view(-1, freqs.size(-1)))
|
||||||
return _to_flux_freqs_cis(torch.cat(freq_list, dim=0))
|
return torch.cat(freq_list, dim=0)
|
||||||
|
|
||||||
|
|
||||||
class MMRotaryEmbeddingBase(RotaryEmbeddingBase):
|
class MMRotaryEmbeddingBase(RotaryEmbeddingBase):
|
||||||
@ -556,6 +566,36 @@ def apply_rotary_emb(
|
|||||||
out = torch.cat((t_left, t_middle_out, t_right), dim=-1)
|
out = torch.cat((t_left, t_middle_out, t_right), dim=-1)
|
||||||
return out.type(dtype)
|
return out.type(dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_seedvr2_rotary_emb(
|
||||||
|
freqs: torch.Tensor,
|
||||||
|
t: torch.Tensor,
|
||||||
|
start_index: int = 0,
|
||||||
|
scale: float = 1.0,
|
||||||
|
seq_dim: int = -2,
|
||||||
|
freqs_seq_dim: int | None = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
dtype = t.dtype
|
||||||
|
if freqs_seq_dim is None and (freqs.ndim == 2 or t.ndim == 3):
|
||||||
|
freqs_seq_dim = 0
|
||||||
|
|
||||||
|
if t.ndim == 3 or freqs_seq_dim is not None:
|
||||||
|
seq_len = t.shape[seq_dim]
|
||||||
|
freqs = slice_at_dim(freqs, slice(-seq_len, None), dim=freqs_seq_dim)
|
||||||
|
|
||||||
|
rot_feats = freqs.shape[-1]
|
||||||
|
end_index = start_index + rot_feats
|
||||||
|
|
||||||
|
t_left = t[..., :start_index]
|
||||||
|
t_middle = t[..., start_index:end_index]
|
||||||
|
t_right = t[..., end_index:]
|
||||||
|
|
||||||
|
freqs = freqs.to(device=t_middle.device, dtype=t_middle.dtype)
|
||||||
|
cos = freqs.cos() * scale
|
||||||
|
sin = freqs.sin() * scale
|
||||||
|
t_middle = (t_middle * cos) + (rotate_half(t_middle) * sin)
|
||||||
|
return torch.cat((t_left, t_middle, t_right), dim=-1).to(dtype)
|
||||||
|
|
||||||
def _to_flux_freqs_cis(freqs_interleaved: torch.Tensor) -> torch.Tensor:
|
def _to_flux_freqs_cis(freqs_interleaved: torch.Tensor) -> torch.Tensor:
|
||||||
"""Convert lucidrains-interleaved freqs to flux-canonical fp32 freqs_cis `[..., d/2, 2, 2]` (cos/-sin/sin/cos), per `comfy/ldm/flux/math.py:rope`."""
|
"""Convert lucidrains-interleaved freqs to flux-canonical fp32 freqs_cis `[..., d/2, 2, 2]` (cos/-sin/sin/cos), per `comfy/ldm/flux/math.py:rope`."""
|
||||||
angles = freqs_interleaved[..., ::2].float()
|
angles = freqs_interleaved[..., ::2].float()
|
||||||
@ -1380,6 +1420,8 @@ class NaDiT(nn.Module):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
self._7b_version = vid_dim == SEEDVR2_7B_VID_DIM
|
self._7b_version = vid_dim == SEEDVR2_7B_VID_DIM
|
||||||
|
if self._7b_version:
|
||||||
|
rope_type = "rope3d"
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
factory_kwargs = {"device": device, "dtype": dtype}
|
factory_kwargs = {"device": device, "dtype": dtype}
|
||||||
window_method = num_layers // 2 * ["720pwin_by_size_bysize","720pswin_by_size_bysize"]
|
window_method = num_layers // 2 * ["720pwin_by_size_bysize","720pswin_by_size_bysize"]
|
||||||
|
|||||||
@ -24,12 +24,7 @@ from comfy.ldm.seedvr.constants import (
|
|||||||
BYTEDANCE_VAE_SHIFTING_FACTOR,
|
BYTEDANCE_VAE_SHIFTING_FACTOR,
|
||||||
BYTEDANCE_VAE_SPATIAL_DOWNSAMPLE,
|
BYTEDANCE_VAE_SPATIAL_DOWNSAMPLE,
|
||||||
BYTEDANCE_VAE_TEMPORAL_DOWNSAMPLE,
|
BYTEDANCE_VAE_TEMPORAL_DOWNSAMPLE,
|
||||||
CIELAB_DELTA,
|
|
||||||
CIELAB_KAPPA,
|
|
||||||
D65_WHITE_X,
|
|
||||||
D65_WHITE_Z,
|
|
||||||
SEEDVR2_LATENT_CHANNELS,
|
SEEDVR2_LATENT_CHANNELS,
|
||||||
WAVELET_DECOMP_LEVELS,
|
|
||||||
)
|
)
|
||||||
from comfy.ldm.modules.attention import optimized_attention
|
from comfy.ldm.modules.attention import optimized_attention
|
||||||
from comfy.ldm.modules.diffusionmodules.model import vae_attention
|
from comfy.ldm.modules.diffusionmodules.model import vae_attention
|
||||||
|
|||||||
@ -575,7 +575,7 @@ class SeedVR2Conditioning(io.ComfyNode):
|
|||||||
description="Build SeedVR2 positive/negative conditioning from a VAE latent.",
|
description="Build SeedVR2 positive/negative conditioning from a VAE latent.",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Model.Input("model", tooltip="The SeedVR2 model."),
|
io.Model.Input("model", tooltip="The SeedVR2 model."),
|
||||||
io.Latent.Input("vae_conditioning", tooltip="The VAE-encoded latent to condition on."),
|
io.Latent.Input("vae_conditioning", display_name="LATENT", tooltip="The VAE-encoded latent to condition on."),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
io.Model.Output(display_name = "model"),
|
io.Model.Output(display_name = "model"),
|
||||||
|
|||||||
@ -6,9 +6,7 @@ Sources (all merged verbatim, helper names disambiguated where colliding):
|
|||||||
apply_rotary_emb wrapper oracle at fp32.
|
apply_rotary_emb wrapper oracle at fp32.
|
||||||
* GroupNorm limit gate — causal_norm_wrapper at vae.py:509 must compare
|
* GroupNorm limit gate — causal_norm_wrapper at vae.py:509 must compare
|
||||||
memory_occupy against get_norm_limit(), not float('inf').
|
memory_occupy against get_norm_limit(), not float('inf').
|
||||||
* var_attention backend registry.
|
* SeedVR2 variable-length attention split-loop contract.
|
||||||
* var_attention_pytorch SeedVR2-named guard — present-API shape contract
|
|
||||||
with AST-level pinning of the guard ordering.
|
|
||||||
|
|
||||||
Pre-import CPU-only guard is required because comfy.ldm.seedvr.model and
|
Pre-import CPU-only guard is required because comfy.ldm.seedvr.model and
|
||||||
comfy.ldm.modules.attention transitively pull in comfy.model_management,
|
comfy.ldm.modules.attention transitively pull in comfy.model_management,
|
||||||
@ -18,11 +16,6 @@ set first.
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import ast
|
|
||||||
import inspect
|
|
||||||
import logging
|
|
||||||
import textwrap
|
|
||||||
import warnings
|
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -45,7 +38,7 @@ from comfy.ldm.seedvr.vae import ( # noqa: E402
|
|||||||
causal_norm_wrapper,
|
causal_norm_wrapper,
|
||||||
set_norm_limit,
|
set_norm_limit,
|
||||||
)
|
)
|
||||||
from comfy.ldm.modules.attention import var_attention_pytorch # noqa: E402
|
from comfy.ldm.modules.attention import var_attention_optimized_split # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@ -215,13 +208,14 @@ def test_seedvr_groupnorm_low_limit_uses_chunked_groupnorm_path(groupnorm_cls):
|
|||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# var_attention backend tests (test_seedvr_var_attention_backends.py)
|
# SeedVR2 var_attention split-loop tests
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
def test_var_attention_registry_contains_always_available_entries():
|
def test_var_attention_registry_contains_always_available_entries():
|
||||||
assert attention.REGISTERED_ATTENTION_FUNCTIONS["var_attention_pytorch"] is attention.var_attention_pytorch
|
assert (
|
||||||
assert attention.REGISTERED_ATTENTION_FUNCTIONS["var_attention_sub_quad"] is attention.var_attention_sub_quad
|
attention.REGISTERED_ATTENTION_FUNCTIONS["var_attention_optimized_split"]
|
||||||
assert attention.REGISTERED_ATTENTION_FUNCTIONS["var_attention_split"] is attention.var_attention_split
|
is attention.var_attention_optimized_split
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_seedvr2_7b_swin_attention_forward_uses_optimized_var_attention(monkeypatch):
|
def test_seedvr2_7b_swin_attention_forward_uses_optimized_var_attention(monkeypatch):
|
||||||
@ -285,105 +279,63 @@ def test_seedvr2_7b_swin_attention_forward_uses_optimized_var_attention(monkeypa
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
def test_var_attention_optimized_split_calls_dense_backend_per_window(monkeypatch):
|
||||||
# var_attention_pytorch SeedVR2 guard tests
|
heads = 2
|
||||||
# (test_var_attention_pytorch_seedvr2_guard.py)
|
head_dim = 3
|
||||||
# ---------------------------------------------------------------------------
|
q = torch.arange(30, dtype=torch.float32).reshape(5, heads, head_dim)
|
||||||
|
k = q + 100
|
||||||
|
v = q + 200
|
||||||
|
cu = torch.tensor([0, 2, 5], dtype=torch.int32)
|
||||||
|
calls = []
|
||||||
|
|
||||||
def _pytorch_guard_inputs():
|
def fake_optimized_attention(q_arg, k_arg, v_arg, heads_arg, **kwargs):
|
||||||
heads, head_dim, total_tokens = 2, 8, 6
|
calls.append(
|
||||||
embed_dim = heads * head_dim
|
{
|
||||||
q = torch.randn(total_tokens, embed_dim)
|
"q_shape": tuple(q_arg.shape),
|
||||||
k = torch.randn(total_tokens, embed_dim)
|
"k_shape": tuple(k_arg.shape),
|
||||||
v = torch.randn(total_tokens, embed_dim)
|
"v_shape": tuple(v_arg.shape),
|
||||||
cu = torch.tensor([0, 3, 6], dtype=torch.int32)
|
"heads": heads_arg,
|
||||||
return q, k, v, heads, cu, cu, total_tokens, embed_dim
|
"kwargs": kwargs,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return q_arg + v_arg
|
||||||
|
|
||||||
|
monkeypatch.setattr(attention, "optimized_attention", fake_optimized_attention)
|
||||||
|
|
||||||
def _assert_guard_source_pin():
|
out = var_attention_optimized_split(
|
||||||
src = textwrap.dedent(inspect.getsource(var_attention_pytorch))
|
q,
|
||||||
tree = ast.parse(src)
|
k,
|
||||||
raise_lines = []
|
v,
|
||||||
nested_lines = []
|
heads,
|
||||||
for node in ast.walk(tree):
|
cu,
|
||||||
if isinstance(node, ast.Raise) and isinstance(node.exc, ast.Call):
|
cu,
|
||||||
func = node.exc.func
|
skip_reshape=True,
|
||||||
if isinstance(func, ast.Name) and func.id == "RuntimeError":
|
skip_output_reshape=True,
|
||||||
raise_lines.append(node.lineno)
|
|
||||||
if isinstance(node, ast.Attribute) and node.attr == "nested_tensor_from_jagged":
|
|
||||||
nested_lines.append(node.lineno)
|
|
||||||
assert raise_lines, (
|
|
||||||
"var_attention_pytorch has no `raise RuntimeError(...)` AST node; "
|
|
||||||
f"the SeedVR2-named guard is missing.\n--- source ---\n{src}"
|
|
||||||
)
|
|
||||||
assert nested_lines, (
|
|
||||||
"var_attention_pytorch source has no `nested_tensor_from_jagged` "
|
|
||||||
f"attribute access; cannot pin guard ordering.\n"
|
|
||||||
f"--- source ---\n{src}"
|
|
||||||
)
|
|
||||||
first_raise = min(raise_lines)
|
|
||||||
first_nested = min(nested_lines)
|
|
||||||
assert first_raise < first_nested, (
|
|
||||||
f"`raise RuntimeError(...)` first appears at line {first_raise}, "
|
|
||||||
f"but `torch.nested.nested_tensor_from_jagged` is referenced first "
|
|
||||||
f"at line {first_nested}; the guard must precede the lookup.\n"
|
|
||||||
f"--- source ---\n{src}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
assert tuple(out.shape) == (5, heads, head_dim)
|
||||||
def test_missing_api_raises_seedvr2_runtime_error(monkeypatch):
|
assert len(calls) == 2
|
||||||
monkeypatch.delattr(torch.nested, "nested_tensor_from_jagged", raising=False)
|
assert calls[0]["q_shape"] == (1, heads, 2, head_dim)
|
||||||
q, k, v, heads, cu_q, cu_k, _, _ = _pytorch_guard_inputs()
|
assert calls[1]["q_shape"] == (1, heads, 3, head_dim)
|
||||||
|
assert all(call["heads"] == heads for call in calls)
|
||||||
with pytest.raises(RuntimeError, match=r"SeedVR2.*nested_tensor_from_jagged"):
|
assert all(call["kwargs"]["skip_reshape"] is True for call in calls)
|
||||||
var_attention_pytorch(q, k, v, heads, cu_q, cu_k)
|
assert all(call["kwargs"]["skip_output_reshape"] is True for call in calls)
|
||||||
|
torch.testing.assert_close(out, q + v, rtol=0, atol=0)
|
||||||
_assert_guard_source_pin()
|
|
||||||
|
|
||||||
|
|
||||||
def test_missing_namespace_raises_seedvr2_runtime_error(monkeypatch):
|
def test_var_attention_optimized_split_rejects_bad_offsets():
|
||||||
monkeypatch.delattr(torch, "nested", raising=False)
|
q = torch.randn(5, 2, 3)
|
||||||
q, k, v, heads, cu_q, cu_k, _, _ = _pytorch_guard_inputs()
|
cu_bad = torch.tensor([0, 2, 6], dtype=torch.int32)
|
||||||
|
cu_ok = torch.tensor([0, 2, 5], dtype=torch.int32)
|
||||||
|
|
||||||
with pytest.raises(RuntimeError, match=r"SeedVR2.*nested_tensor_from_jagged"):
|
with pytest.raises(ValueError, match="cu_seqlens_q does not match token count"):
|
||||||
var_attention_pytorch(q, k, v, heads, cu_q, cu_k)
|
var_attention_optimized_split(
|
||||||
|
q,
|
||||||
_assert_guard_source_pin()
|
q,
|
||||||
|
q,
|
||||||
|
2,
|
||||||
def test_present_api_returns_expected_shape():
|
cu_bad,
|
||||||
q, k, v, heads, cu_q, cu_k, total_tokens, embed_dim = _pytorch_guard_inputs()
|
cu_ok,
|
||||||
|
skip_reshape=True,
|
||||||
torch_fx_logger = logging.getLogger("torch.fx._symbolic_trace")
|
skip_output_reshape=True,
|
||||||
old_torch_fx_level = torch_fx_logger.level
|
)
|
||||||
torch_fx_logger.setLevel(logging.ERROR)
|
|
||||||
try:
|
|
||||||
with warnings.catch_warnings():
|
|
||||||
warnings.filterwarnings(
|
|
||||||
"ignore",
|
|
||||||
message="The PyTorch API of nested tensors is in prototype stage.*",
|
|
||||||
category=UserWarning,
|
|
||||||
)
|
|
||||||
out = var_attention_pytorch(q, k, v, heads, cu_q, cu_k)
|
|
||||||
finally:
|
|
||||||
torch_fx_logger.setLevel(old_torch_fx_level)
|
|
||||||
|
|
||||||
assert tuple(out.shape) == (total_tokens, embed_dim), (
|
|
||||||
f"expected ({total_tokens}, {embed_dim}); got {tuple(out.shape)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
_assert_guard_source_pin()
|
|
||||||
|
|
||||||
|
|
||||||
def test_malformed_offsets_propagates_torch_runtime_error():
|
|
||||||
q, k, v, heads, _, _, _, _ = _pytorch_guard_inputs()
|
|
||||||
cu_q_bad = torch.tensor([0, 3, 7], dtype=torch.int32)
|
|
||||||
cu_k_ok = torch.tensor([0, 3, 6], dtype=torch.int32)
|
|
||||||
|
|
||||||
with pytest.raises(RuntimeError) as exc_info:
|
|
||||||
var_attention_pytorch(q, k, v, heads, cu_q_bad, cu_k_ok)
|
|
||||||
|
|
||||||
msg = str(exc_info.value)
|
|
||||||
assert "SeedVR2" not in msg
|
|
||||||
|
|
||||||
_assert_guard_source_pin()
|
|
||||||
|
|||||||
@ -213,11 +213,11 @@ def test_seedvr2_7b_rope3d_matches_wrapper_oracle():
|
|||||||
shape = torch.tensor([[1, 2, 2]], dtype=torch.long)
|
shape = torch.tensor([[1, 2, 2]], dtype=torch.long)
|
||||||
freqs = rope.get_axial_freqs(1, 2, 2).reshape(4, -1)
|
freqs = rope.get_axial_freqs(1, 2, 2).reshape(4, -1)
|
||||||
|
|
||||||
expected_q = seedvr_model.apply_rotary_emb(
|
expected_q = seedvr_model._apply_seedvr2_rotary_emb(
|
||||||
freqs,
|
freqs,
|
||||||
q.permute(1, 0, 2).float(),
|
q.permute(1, 0, 2).float(),
|
||||||
).to(q.dtype).permute(1, 0, 2)
|
).to(q.dtype).permute(1, 0, 2)
|
||||||
expected_k = seedvr_model.apply_rotary_emb(
|
expected_k = seedvr_model._apply_seedvr2_rotary_emb(
|
||||||
freqs,
|
freqs,
|
||||||
k.permute(1, 0, 2).float(),
|
k.permute(1, 0, 2).float(),
|
||||||
).to(k.dtype).permute(1, 0, 2)
|
).to(k.dtype).permute(1, 0, 2)
|
||||||
|
|||||||
@ -174,10 +174,7 @@ def test_decode_tiled_vae_maps_temporal_args_to_latent_slicing_min_size():
|
|||||||
captured.update(kwargs)
|
captured.update(kwargs)
|
||||||
return torch.zeros(1, 3, 1, 16, 16)
|
return torch.zeros(1, 3, 1, 16, 16)
|
||||||
|
|
||||||
with (
|
with patch.object(vae_mod, "tiled_vae", side_effect=_fake_tiled_vae):
|
||||||
patch.object(vae_mod, "tiled_vae", side_effect=_fake_tiled_vae),
|
|
||||||
patch.object(vae_mod, "lab_color_transfer", side_effect=lambda content, style: content),
|
|
||||||
):
|
|
||||||
wrapper.decode(torch.zeros(1, 16, 2, 2), seedvr2_tiling=seedvr2_tiling)
|
wrapper.decode(torch.zeros(1, 16, 2, 2), seedvr2_tiling=seedvr2_tiling)
|
||||||
|
|
||||||
assert captured["temporal_overlap"] == 7
|
assert captured["temporal_overlap"] == 7
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user