mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-26 09:49:26 +08:00
Replace SeedVR2 custom varlen attention backends and fix 7B RoPE
This commit is contained in:
parent
22078c799b
commit
529b9232f0
@ -32,14 +32,6 @@ except ImportError as e:
|
||||
raise e
|
||||
exit(-1)
|
||||
|
||||
SAGE_ATTENTION_VARLEN_IS_AVAILABLE = False
|
||||
try:
|
||||
from sageattention import sageattn_varlen
|
||||
SAGE_ATTENTION_VARLEN_IS_AVAILABLE = True
|
||||
except ImportError:
|
||||
if model_management.sage_attention_enabled():
|
||||
logging.warning("SageAttention variable-length attention is unavailable, using pytorch var-len attention instead.")
|
||||
|
||||
SAGE_ATTENTION3_IS_AVAILABLE = False
|
||||
try:
|
||||
from sageattn3 import sageattn3_blackwell
|
||||
@ -48,7 +40,6 @@ except ImportError:
|
||||
pass
|
||||
|
||||
FLASH_ATTENTION_IS_AVAILABLE = False
|
||||
FLASH_ATTENTION_VARLEN_IS_AVAILABLE = False
|
||||
try:
|
||||
from flash_attn import flash_attn_func
|
||||
FLASH_ATTENTION_IS_AVAILABLE = True
|
||||
@ -57,20 +48,6 @@ except ImportError:
|
||||
logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
|
||||
exit(-1)
|
||||
|
||||
try:
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
FLASH_ATTENTION_VARLEN_IS_AVAILABLE = True
|
||||
except ImportError:
|
||||
if model_management.flash_attention_enabled() and FLASH_ATTENTION_IS_AVAILABLE:
|
||||
logging.warning("Flash Attention variable-length attention is unavailable, using pytorch var-len attention instead.")
|
||||
|
||||
FLASH_ATTENTION3_IS_AVAILABLE = False
|
||||
try:
|
||||
from flash_attn_interface import flash_attn_varlen_func as flash_attn3_varlen_func
|
||||
FLASH_ATTENTION3_IS_AVAILABLE = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
REGISTERED_ATTENTION_FUNCTIONS = {}
|
||||
def register_attention_function(name: str, func: Callable):
|
||||
# avoid replacing existing functions
|
||||
@ -758,18 +735,6 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
)
|
||||
return out
|
||||
|
||||
_VAR_ATTENTION_NESTED_API_NAME = "nested_tensor_from_jagged"
|
||||
_VAR_ATTENTION_GUARD_MESSAGE = (
|
||||
"SeedVR2 var_attention_pytorch: torch.nested.nested_tensor_from_jagged "
|
||||
"is required by this attention path; the installed PyTorch build "
|
||||
"does not provide it"
|
||||
)
|
||||
|
||||
|
||||
def _var_attention_max_seqlen(cu_seqlens):
    """Return the longest sequence length described by cumulative offsets.

    Adjacent entries of ``cu_seqlens`` are differenced to get per-sequence
    lengths; the largest one is returned as a plain Python ``int``.
    """
    lengths = cu_seqlens[1:] - cu_seqlens[:-1]
    longest = lengths.max()
    # .item() pulls the scalar out of the tensor before the int conversion.
    return int(longest.item())
|
||||
|
||||
|
||||
def _var_attention_qkv(q, k, v, heads, skip_reshape):
|
||||
if skip_reshape:
|
||||
return q, k, v, q.shape[-1]
|
||||
@ -797,34 +762,6 @@ def _use_blackwell_attention():
|
||||
return (major, minor) >= (12, 0)
|
||||
|
||||
|
||||
def var_attention_pytorch(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, skip_reshape=False, skip_output_reshape=False):
    """Variable-length attention via jagged nested tensors and SDPA.

    Args:
        q, k, v: packed token tensors. Unless ``skip_reshape`` is set they are
            2D ``[total_tokens, embed_dim]`` and are viewed as
            ``[tokens, heads, embed_dim // heads]`` first.
        heads: number of attention heads.
        cu_seqlens_q: cumulative offsets used to split ``q`` into sequences.
        cu_seqlens_k: cumulative offsets used to split ``k`` and ``v``.
        skip_reshape: when true, the inputs are used as given (no head split).
        skip_output_reshape: when true, the packed values of the nested
            output are returned without flattening the head axis.

    Returns:
        The packed attention output; flattened to ``[-1, heads * head_dim]``
        unless ``skip_output_reshape`` is set.

    Raises:
        RuntimeError: if ``torch.nested.nested_tensor_from_jagged`` is not
            provided by the installed PyTorch build.
    """
    # Guard first: fail with the SeedVR2-specific message before touching the API.
    nested_ns = getattr(torch, "nested", None)
    if nested_ns is None or not hasattr(nested_ns, _VAR_ATTENTION_NESTED_API_NAME):
        raise RuntimeError(_VAR_ATTENTION_GUARD_MESSAGE)

    if not skip_reshape:
        # 2D packed layout: split the embedding axis into (heads, head_dim).
        token_count, width = q.shape
        per_head = width // heads
        q = q.view(token_count, heads, per_head)
        k = k.view(k.shape[0], heads, per_head)
        v = v.view(v.shape[0], heads, per_head)

    jagged_q = torch.nested.nested_tensor_from_jagged(q, offsets=cu_seqlens_q.long())
    jagged_k = torch.nested.nested_tensor_from_jagged(k, offsets=cu_seqlens_k.long())
    jagged_v = torch.nested.nested_tensor_from_jagged(v, offsets=cu_seqlens_k.long())

    # Swap the ragged sequence axis and the head axis for SDPA.
    q = jagged_q.transpose(1, 2)
    k = jagged_k.transpose(1, 2)
    v = jagged_v.transpose(1, 2)
    attended = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)

    packed = attended.transpose(1, 2).values()
    if skip_output_reshape:
        return packed
    return packed.reshape(-1, heads * (q.shape[-1]))
|
||||
|
||||
|
||||
def _validate_split_cu_seqlens(name, cu_seqlens, token_count):
|
||||
if cu_seqlens.dtype not in (torch.int32, torch.int64):
|
||||
raise ValueError(f"{name} must use an integer dtype")
|
||||
@ -842,7 +779,7 @@ def _split_indices(cu_seqlens):
|
||||
return cu_seqlens[1:-1].to(device="cpu", dtype=torch.long)
|
||||
|
||||
|
||||
def var_attention_pytorch_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, skip_reshape=False, skip_output_reshape=False):
|
||||
def var_attention_optimized_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)
|
||||
|
||||
_validate_split_cu_seqlens("cu_seqlens_q", cu_seqlens_q, q.shape[0])
|
||||
@ -863,329 +800,44 @@ def var_attention_pytorch_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, skip
|
||||
q_i = q_i.permute(1, 0, 2).unsqueeze(0)
|
||||
k_i = k_i.permute(1, 0, 2).unsqueeze(0)
|
||||
v_i = v_i.permute(1, 0, 2).unsqueeze(0)
|
||||
out_i = comfy.ops.scaled_dot_product_attention(q_i, k_i, v_i, attn_mask=None, dropout_p=0.0, is_causal=False)
|
||||
out_dtype = q_i.dtype
|
||||
if optimized_attention is attention_sage and q_i.dtype not in (torch.float16, torch.bfloat16):
|
||||
q_i = q_i.to(torch.bfloat16)
|
||||
k_i = k_i.to(torch.bfloat16)
|
||||
v_i = v_i.to(torch.bfloat16)
|
||||
out_i = optimized_attention(q_i, k_i, v_i, heads, skip_reshape=True, skip_output_reshape=True)
|
||||
if out_i.dtype != out_dtype:
|
||||
out_i = out_i.to(out_dtype)
|
||||
out.append(out_i.squeeze(0).permute(1, 0, 2))
|
||||
|
||||
out = torch.cat(out, dim=0)
|
||||
return _var_attention_output(out, heads, head_dim, skip_output_reshape)
|
||||
|
||||
@torch._dynamo.disable
def var_attention_sage(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
    """Variable-length attention through ``sageattn_varlen``.

    Falls back to ``var_attention_pytorch`` when the SageAttention var-len
    kernel is not importable, or when the kernel call raises.

    Keyword arguments read from ``kwargs``:
        softmax_scale: softmax scale; defaults to ``1 / sqrt(head_dim)``.
        causal: forwarded to the kernel, default ``False``.

    The result is cast back to the dtype ``q`` had after the head split.
    """
    if not SAGE_ATTENTION_VARLEN_IS_AVAILABLE:
        return var_attention_pytorch(
            q, k, v, heads, cu_seqlens_q, cu_seqlens_k,
            skip_reshape=skip_reshape,
            skip_output_reshape=skip_output_reshape,
        )

    q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)
    result_dtype = q.dtype

    # Bring k/v onto q's dtype when the three disagree.
    if q.dtype != k.dtype or k.dtype != v.dtype:
        k = k.to(q.dtype)
        v = v.to(q.dtype)

    # Keep the pre-cast tensors so the fallback runs in the original dtype.
    saved_qkv = (q, k, v)
    if q.dtype not in (torch.float16, torch.bfloat16):
        q, k, v = (tensor.to(torch.bfloat16) for tensor in (q, k, v))

    softmax_scale = kwargs.get("softmax_scale")
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(head_dim)

    try:
        out = sageattn_varlen(
            q,
            k,
            v,
            cu_seqlens_q.int(),
            cu_seqlens_k.int(),
            _var_attention_max_seqlen(cu_seqlens_q),
            _var_attention_max_seqlen(cu_seqlens_k),
            kwargs.get("causal", False),
            softmax_scale,
        )
    except Exception as e:
        logging.error("Error running sage var-len attention: %s, using pytorch var-len attention instead.", e)
        rescued = var_attention_pytorch(
            *saved_qkv, heads, cu_seqlens_q, cu_seqlens_k,
            skip_reshape=True,
            skip_output_reshape=skip_output_reshape,
        )
        return rescued if rescued.dtype == result_dtype else rescued.to(result_dtype)

    if out.dtype != result_dtype:
        out = out.to(result_dtype)
    return _var_attention_output(out, heads, head_dim, skip_output_reshape)
|
||||
|
||||
|
||||
@torch._dynamo.disable
def var_attention_sage3(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
    """Variable-length attention through ``sageattn3_blackwell``.

    The SageAttention3 kernel is called on a dense ``[batch, heads, seq, dim]``
    view, so it is only used when every query and key sequence has the same
    length. In every other situation (kernel unavailable, ragged lengths,
    kernel raising) the call is delegated to ``var_attention_sage`` if the
    var-len Sage kernel is available, otherwise to ``var_attention_pytorch``.

    ``kwargs["causal"]`` (default ``False``) is forwarded to the kernel.
    """

    def _delegate(dq, dk, dv, already_split):
        # Shared fallback chain: sage var-len first, then the pytorch path.
        if SAGE_ATTENTION_VARLEN_IS_AVAILABLE:
            return var_attention_sage(
                dq, dk, dv, heads, cu_seqlens_q, cu_seqlens_k,
                skip_reshape=already_split,
                skip_output_reshape=skip_output_reshape,
                **kwargs,
            )
        return var_attention_pytorch(
            dq, dk, dv, heads, cu_seqlens_q, cu_seqlens_k,
            skip_reshape=already_split,
            skip_output_reshape=skip_output_reshape,
        )

    if not SAGE_ATTENTION3_IS_AVAILABLE:
        return _delegate(q, k, v, skip_reshape)

    q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)

    seq_lens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
    seq_lens_k = cu_seqlens_k[1:] - cu_seqlens_k[:-1]
    uniform_q = bool((seq_lens_q == seq_lens_q[0]).all().item())
    uniform_k = bool((seq_lens_k == seq_lens_k[0]).all().item())
    # The dense view below needs one shared length for every q and k sequence.
    if not uniform_q or not uniform_k or seq_lens_q[0] != seq_lens_k[0]:
        return _delegate(q, k, v, True)

    result_dtype = q.dtype
    if q.dtype != k.dtype or k.dtype != v.dtype:
        k = k.to(q.dtype)
        v = v.to(q.dtype)

    # Pre-cast, pre-view tensors are what the fallback receives on failure.
    saved_qkv = (q, k, v)
    if q.dtype not in (torch.float16, torch.bfloat16):
        q = q.to(torch.bfloat16)
        k = k.to(torch.bfloat16)
        v = v.to(torch.bfloat16)

    batch_size = len(cu_seqlens_q) - 1
    seq_len = int(seq_lens_q[0].item())
    dense_shape = (batch_size, seq_len, heads, head_dim)
    q = q.view(*dense_shape).transpose(1, 2)
    k = k.view(*dense_shape).transpose(1, 2)
    v = v.view(*dense_shape).transpose(1, 2)

    try:
        out = sageattn3_blackwell(q, k, v, is_causal=kwargs.get("causal", False))
    except Exception as e:
        logging.error("Error running SageAttention3 var-len attention: %s, using fallback var-len attention instead.", e)
        return _delegate(*saved_qkv, True)

    # Back to the packed [tokens, heads, head_dim] layout.
    out = out.transpose(1, 2).reshape(-1, heads, head_dim).contiguous()
    if out.dtype != result_dtype:
        out = out.to(result_dtype)
    return _var_attention_output(out, heads, head_dim, skip_output_reshape)
|
||||
|
||||
|
||||
@torch._dynamo.disable
def var_attention_flash(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
    """Variable-length attention through ``flash_attn_varlen_func``.

    Falls back to ``var_attention_pytorch`` when the Flash Attention var-len
    kernel is not importable, or when the kernel call raises.

    Keyword arguments read from ``kwargs``:
        dropout_p: forwarded to the kernel, default ``0.0``.
        softmax_scale: forwarded to the kernel; ``None`` (the default) lets
            the kernel pick its own default scale. This matches what
            ``var_attention_sage`` and ``var_attention_flash3`` do with the
            same keyword instead of silently ignoring it.
        causal: forwarded to the kernel, default ``False``.

    The pytorch fallback does not take these keywords, so they only affect
    the Flash Attention path.
    """
    if not FLASH_ATTENTION_VARLEN_IS_AVAILABLE:
        return var_attention_pytorch(
            q,
            k,
            v,
            heads,
            cu_seqlens_q,
            cu_seqlens_k,
            skip_reshape=skip_reshape,
            skip_output_reshape=skip_output_reshape,
        )

    q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)
    max_seqlen_q = _var_attention_max_seqlen(cu_seqlens_q)
    max_seqlen_k = _var_attention_max_seqlen(cu_seqlens_k)

    try:
        out = flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=cu_seqlens_q.int(),
            cu_seqlens_k=cu_seqlens_k.int(),
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_k,
            dropout_p=kwargs.get("dropout_p", 0.0),
            softmax_scale=kwargs.get("softmax_scale"),
            causal=kwargs.get("causal", False),
            deterministic=torch.are_deterministic_algorithms_enabled(),
        )
    except Exception as e:
        # Best-effort: any kernel failure is logged and routed to the pytorch path.
        logging.error("Error running Flash Attention var-len attention: %s, using pytorch var-len attention instead.", e)
        return var_attention_pytorch(
            q,
            k,
            v,
            heads,
            cu_seqlens_q,
            cu_seqlens_k,
            skip_reshape=True,  # q/k/v were already split into heads above
            skip_output_reshape=skip_output_reshape,
        )

    return _var_attention_output(out, heads, head_dim, skip_output_reshape)
|
||||
|
||||
|
||||
@torch._dynamo.disable
def var_attention_flash3(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
    """Variable-length attention through the Flash Attention 3 var-len kernel.

    Falls back to ``var_attention_pytorch`` when the kernel is not importable
    or when the kernel call raises. ``kwargs["softmax_scale"]`` (default
    ``None``) and ``kwargs["causal"]`` (default ``False``) are forwarded.
    """
    if not FLASH_ATTENTION3_IS_AVAILABLE:
        return var_attention_pytorch(
            q, k, v, heads, cu_seqlens_q, cu_seqlens_k,
            skip_reshape=skip_reshape,
            skip_output_reshape=skip_output_reshape,
        )

    q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)
    longest_q = _var_attention_max_seqlen(cu_seqlens_q)
    longest_k = _var_attention_max_seqlen(cu_seqlens_k)

    try:
        result = flash_attn3_varlen_func(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=cu_seqlens_q.int(),
            cu_seqlens_k=cu_seqlens_k.int(),
            max_seqlen_q=longest_q,
            max_seqlen_k=longest_k,
            seqused_q=None,
            seqused_k=None,
            softmax_scale=kwargs.get("softmax_scale"),
            causal=kwargs.get("causal", False),
            deterministic=torch.are_deterministic_algorithms_enabled(),
        )
    except Exception as e:
        logging.error("Error running Flash Attention 3 var-len attention: %s, using pytorch var-len attention instead.", e)
        return var_attention_pytorch(
            q, k, v, heads, cu_seqlens_q, cu_seqlens_k,
            skip_reshape=True,
            skip_output_reshape=skip_output_reshape,
        )

    # When the kernel hands back a tuple, the attention output is its first item.
    attended = result[0] if isinstance(result, tuple) else result
    return _var_attention_output(attended, heads, head_dim, skip_output_reshape)
|
||||
|
||||
|
||||
@torch._dynamo.disable
def var_attention_sub_quad(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
    """Var-len entry for the sub-quadratic backend.

    Delegates straight to ``var_attention_pytorch``; extra positional and
    keyword arguments are accepted but not forwarded.
    """
    reshape_flags = {
        "skip_reshape": skip_reshape,
        "skip_output_reshape": skip_output_reshape,
    }
    return var_attention_pytorch(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, **reshape_flags)
|
||||
|
||||
|
||||
@torch._dynamo.disable
def var_attention_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
    """Var-len entry for the split backend.

    Delegates straight to ``var_attention_pytorch_split``; extra positional
    and keyword arguments are accepted but not forwarded.
    """
    reshape_flags = {
        "skip_reshape": skip_reshape,
        "skip_output_reshape": skip_output_reshape,
    }
    return var_attention_pytorch_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, **reshape_flags)
|
||||
|
||||
|
||||
optimized_var_attention = var_attention_pytorch
|
||||
optimized_var_attention = var_attention_optimized_split
|
||||
optimized_attention = attention_basic
|
||||
|
||||
if model_management.sage_attention_enabled():
|
||||
logging.info("Using sage attention")
|
||||
optimized_attention = attention_sage
|
||||
if SAGE_ATTENTION3_IS_AVAILABLE and _use_blackwell_attention():
|
||||
logging.info("Using SageAttention3 for variable-length attention")
|
||||
optimized_var_attention = var_attention_sage3
|
||||
elif SAGE_ATTENTION_VARLEN_IS_AVAILABLE:
|
||||
logging.info("Using SageAttention for variable-length attention")
|
||||
optimized_var_attention = var_attention_sage
|
||||
else:
|
||||
logging.info("Using pytorch attention for variable-length attention")
|
||||
optimized_var_attention = var_attention_pytorch
|
||||
elif model_management.flash_attention_enabled():
|
||||
logging.info("Using Flash Attention")
|
||||
optimized_attention = attention_flash
|
||||
if FLASH_ATTENTION_VARLEN_IS_AVAILABLE and model_management.get_torch_device().type == "cuda":
|
||||
logging.info("Using Flash Attention 2 for variable-length attention")
|
||||
optimized_var_attention = var_attention_flash
|
||||
else:
|
||||
logging.info("Using pytorch attention for variable-length attention")
|
||||
optimized_var_attention = var_attention_pytorch
|
||||
elif model_management.xformers_enabled():
|
||||
logging.info("Using xformers attention")
|
||||
optimized_attention = attention_xformers
|
||||
elif model_management.pytorch_attention_enabled():
|
||||
logging.info("Using pytorch attention")
|
||||
optimized_attention = attention_pytorch
|
||||
optimized_var_attention = var_attention_pytorch
|
||||
else:
|
||||
if args.use_split_cross_attention:
|
||||
logging.info("Using split optimization for attention")
|
||||
optimized_attention = attention_split
|
||||
optimized_var_attention = var_attention_split
|
||||
else:
|
||||
logging.info("Using sub quadratic optimization for attention, if you have memory or speed issues try using: --use-split-cross-attention")
|
||||
optimized_attention = attention_sub_quad
|
||||
optimized_var_attention = var_attention_sub_quad
|
||||
|
||||
logging.info("Using optimized_attention split-loop for variable-length attention")
|
||||
|
||||
optimized_attention_masked = optimized_attention
|
||||
|
||||
@ -1193,25 +845,16 @@ optimized_attention_masked = optimized_attention
|
||||
# register core-supported attention functions
|
||||
if SAGE_ATTENTION_IS_AVAILABLE:
|
||||
register_attention_function("sage", attention_sage)
|
||||
if SAGE_ATTENTION_VARLEN_IS_AVAILABLE:
|
||||
register_attention_function("var_attention_sage", var_attention_sage)
|
||||
if SAGE_ATTENTION3_IS_AVAILABLE:
|
||||
register_attention_function("sage3", attention3_sage)
|
||||
register_attention_function("var_attention_sage3", var_attention_sage3)
|
||||
if FLASH_ATTENTION_IS_AVAILABLE:
|
||||
register_attention_function("flash", attention_flash)
|
||||
if FLASH_ATTENTION_VARLEN_IS_AVAILABLE:
|
||||
register_attention_function("var_attention_flash", var_attention_flash)
|
||||
if FLASH_ATTENTION3_IS_AVAILABLE:
|
||||
register_attention_function("var_attention_flash3", var_attention_flash3)
|
||||
if model_management.xformers_enabled():
|
||||
register_attention_function("xformers", attention_xformers)
|
||||
register_attention_function("pytorch", attention_pytorch)
|
||||
register_attention_function("var_attention_pytorch", var_attention_pytorch)
|
||||
register_attention_function("sub_quad", attention_sub_quad)
|
||||
register_attention_function("var_attention_sub_quad", var_attention_sub_quad)
|
||||
register_attention_function("split", attention_split)
|
||||
register_attention_function("var_attention_split", var_attention_split)
|
||||
register_attention_function("var_attention_optimized_split", var_attention_optimized_split)
|
||||
|
||||
|
||||
def optimized_attention_for_device(device, mask=False, small_input=False):
|
||||
|
||||
@ -40,7 +40,7 @@ SEEDVR2_WAVELET_SCALE_MULTIPLIER = 10 # per-frame byte multiplier, wavelet path
|
||||
SEEDVR2_ADAIN_SCALE_MULTIPLIER = 6 # per-frame byte multiplier, AdaIN path.
|
||||
|
||||
# --------------------------------------------------------------------------------------
|
||||
# C. ByteDance config / source (BYTEDANCE - cite myseedvr2/SeedVR)
|
||||
# C. ByteDance config / source (BYTEDANCE - cite ByteDance-Seed/SeedVR)
|
||||
# --------------------------------------------------------------------------------------
|
||||
BYTEDANCE_VAE_SCALING_FACTOR = 0.9152 # configs_3b/main.yaml:57 (scaling_factor); latent denorm.
|
||||
BYTEDANCE_VAE_SHIFTING_FACTOR = 0.0 # infer.py (shifting_factor default); latent denorm shift.
|
||||
|
||||
@ -58,10 +58,11 @@ class CustomRMSNorm(nn.Module):
|
||||
|
||||
dims = tuple(range(-len(self.normalized_shape), 0))
|
||||
|
||||
variance = input.pow(2).mean(dim=dims, keepdim=True)
|
||||
normalized = input.float()
|
||||
variance = normalized.pow(2).mean(dim=dims, keepdim=True)
|
||||
rms = torch.sqrt(variance + self.eps)
|
||||
|
||||
normalized = input / rms
|
||||
normalized = normalized / rms
|
||||
|
||||
if self.elementwise_affine:
|
||||
return normalized * self.weight.to(input.dtype)
|
||||
@ -472,8 +473,8 @@ class NaRotaryEmbedding3d(RotaryEmbedding3d):
|
||||
freqs = freqs.to(device=q.device)
|
||||
q = rearrange(q, "L h d -> h L d")
|
||||
k = rearrange(k, "L h d -> h L d")
|
||||
q = _apply_rope1_partial(q, freqs)
|
||||
k = _apply_rope1_partial(k, freqs)
|
||||
q = _apply_seedvr2_rotary_emb(freqs, q.float()).to(q.dtype)
|
||||
k = _apply_seedvr2_rotary_emb(freqs, k.float()).to(k.dtype)
|
||||
q = rearrange(q, "h L d -> L h d")
|
||||
k = rearrange(k, "h L d -> L h d")
|
||||
return q, k
|
||||
@ -483,11 +484,20 @@ class NaRotaryEmbedding3d(RotaryEmbedding3d):
|
||||
self,
|
||||
shape: torch.LongTensor,
|
||||
) -> torch.Tensor:
|
||||
# Primary provenance: ByteDance-Seed/SeedVR models/dit/rope.py builds
|
||||
# 7B pixel RoPE with the interleaved-angle convention, not Comfy's
|
||||
# Flux freqs_cis matrix.
|
||||
plain_rope = RotaryEmbedding(
|
||||
dim=self.rope.freqs.numel() * 2,
|
||||
freqs_for="pixel",
|
||||
max_freq=BYTEDANCE_ROPE_MAX_FREQ,
|
||||
)
|
||||
plain_rope = plain_rope.to(self.rope.dummy.device)
|
||||
freq_list = []
|
||||
for f, h, w in shape.tolist():
|
||||
freqs = self.get_axial_freqs(f, h, w)
|
||||
freqs = plain_rope.get_axial_freqs(f, h, w)
|
||||
freq_list.append(freqs.view(-1, freqs.size(-1)))
|
||||
return _to_flux_freqs_cis(torch.cat(freq_list, dim=0))
|
||||
return torch.cat(freq_list, dim=0)
|
||||
|
||||
|
||||
class MMRotaryEmbeddingBase(RotaryEmbeddingBase):
|
||||
@ -556,6 +566,36 @@ def apply_rotary_emb(
|
||||
out = torch.cat((t_left, t_middle_out, t_right), dim=-1)
|
||||
return out.type(dtype)
|
||||
|
||||
|
||||
def _apply_seedvr2_rotary_emb(
    freqs: torch.Tensor,
    t: torch.Tensor,
    start_index: int = 0,
    scale: float = 1.0,
    seq_dim: int = -2,
    freqs_seq_dim: int | None = None,
) -> torch.Tensor:
    """Rotate a slice of ``t``'s last axis by the angles in ``freqs``.

    The features ``t[..., start_index:start_index + freqs.shape[-1]]`` are
    replaced by ``x * cos(freqs) + rotate_half(x) * sin(freqs)`` (both terms
    multiplied by ``scale``); features outside that slice pass through
    unchanged.

    Args:
        freqs: rotation angles; the last axis is the number of rotated features.
        t: tensor to rotate.
        start_index: first feature index of the rotated slice.
        scale: multiplier applied to both the cos and sin terms.
        seq_dim: axis of ``t`` holding the sequence length.
        freqs_seq_dim: axis of ``freqs`` holding the sequence length. When
            ``None`` it defaults to ``0`` if ``freqs`` is 2D or ``t`` is 3D.

    Returns:
        A tensor shaped like ``t`` with ``t``'s dtype.

    Raises:
        ValueError: if the rotated slice does not fit in ``t``'s last axis.
    """
    dtype = t.dtype
    if freqs_seq_dim is None and (freqs.ndim == 2 or t.ndim == 3):
        freqs_seq_dim = 0

    # A 3D ``t`` always has freqs_seq_dim resolved by now, so one check suffices.
    if freqs_seq_dim is not None:
        seq_len = t.shape[seq_dim]
        # Keep only the trailing ``seq_len`` positions of the angle table.
        freqs = slice_at_dim(freqs, slice(-seq_len, None), dim=freqs_seq_dim)

    rot_feats = freqs.shape[-1]
    end_index = start_index + rot_feats
    if end_index > t.shape[-1]:
        # Without this check a too-short slice either fails inside the
        # broadcast below with an unrelated message or broadcasts silently.
        raise ValueError(
            f"feature dimension {t.shape[-1]} is too small to rotate "
            f"features {start_index}:{end_index}"
        )

    t_left = t[..., :start_index]
    t_middle = t[..., start_index:end_index]
    t_right = t[..., end_index:]

    freqs = freqs.to(device=t_middle.device, dtype=t_middle.dtype)
    cos = freqs.cos() * scale
    sin = freqs.sin() * scale
    t_middle = (t_middle * cos) + (rotate_half(t_middle) * sin)
    return torch.cat((t_left, t_middle, t_right), dim=-1).to(dtype)
|
||||
|
||||
def _to_flux_freqs_cis(freqs_interleaved: torch.Tensor) -> torch.Tensor:
|
||||
"""Convert lucidrains-interleaved freqs to flux-canonical fp32 freqs_cis `[..., d/2, 2, 2]` (cos/-sin/sin/cos), per `comfy/ldm/flux/math.py:rope`."""
|
||||
angles = freqs_interleaved[..., ::2].float()
|
||||
@ -1380,6 +1420,8 @@ class NaDiT(nn.Module):
|
||||
**kwargs,
|
||||
):
|
||||
self._7b_version = vid_dim == SEEDVR2_7B_VID_DIM
|
||||
if self._7b_version:
|
||||
rope_type = "rope3d"
|
||||
self.dtype = dtype
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
window_method = num_layers // 2 * ["720pwin_by_size_bysize","720pswin_by_size_bysize"]
|
||||
|
||||
@ -24,12 +24,7 @@ from comfy.ldm.seedvr.constants import (
|
||||
BYTEDANCE_VAE_SHIFTING_FACTOR,
|
||||
BYTEDANCE_VAE_SPATIAL_DOWNSAMPLE,
|
||||
BYTEDANCE_VAE_TEMPORAL_DOWNSAMPLE,
|
||||
CIELAB_DELTA,
|
||||
CIELAB_KAPPA,
|
||||
D65_WHITE_X,
|
||||
D65_WHITE_Z,
|
||||
SEEDVR2_LATENT_CHANNELS,
|
||||
WAVELET_DECOMP_LEVELS,
|
||||
)
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from comfy.ldm.modules.diffusionmodules.model import vae_attention
|
||||
|
||||
@ -575,7 +575,7 @@ class SeedVR2Conditioning(io.ComfyNode):
|
||||
description="Build SeedVR2 positive/negative conditioning from a VAE latent.",
|
||||
inputs=[
|
||||
io.Model.Input("model", tooltip="The SeedVR2 model."),
|
||||
io.Latent.Input("vae_conditioning", tooltip="The VAE-encoded latent to condition on."),
|
||||
io.Latent.Input("vae_conditioning", display_name="LATENT", tooltip="The VAE-encoded latent to condition on."),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(display_name = "model"),
|
||||
|
||||
@ -6,9 +6,7 @@ Sources (all merged verbatim, helper names disambiguated where colliding):
|
||||
apply_rotary_emb wrapper oracle at fp32.
|
||||
* GroupNorm limit gate — causal_norm_wrapper at vae.py:509 must compare
|
||||
memory_occupy against get_norm_limit(), not float('inf').
|
||||
* var_attention backend registry.
|
||||
* var_attention_pytorch SeedVR2-named guard — present-API shape contract
|
||||
with AST-level pinning of the guard ordering.
|
||||
* SeedVR2 variable-length attention split-loop contract.
|
||||
|
||||
Pre-import CPU-only guard is required because comfy.ldm.seedvr.model and
|
||||
comfy.ldm.modules.attention transitively pull in comfy.model_management,
|
||||
@ -18,11 +16,6 @@ set first.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import inspect
|
||||
import logging
|
||||
import textwrap
|
||||
import warnings
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@ -45,7 +38,7 @@ from comfy.ldm.seedvr.vae import ( # noqa: E402
|
||||
causal_norm_wrapper,
|
||||
set_norm_limit,
|
||||
)
|
||||
from comfy.ldm.modules.attention import var_attention_pytorch # noqa: E402
|
||||
from comfy.ldm.modules.attention import var_attention_optimized_split # noqa: E402
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@ -215,13 +208,14 @@ def test_seedvr_groupnorm_low_limit_uses_chunked_groupnorm_path(groupnorm_cls):
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# var_attention backend tests (test_seedvr_var_attention_backends.py)
|
||||
# SeedVR2 var_attention split-loop tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_var_attention_registry_contains_always_available_entries():
|
||||
assert attention.REGISTERED_ATTENTION_FUNCTIONS["var_attention_pytorch"] is attention.var_attention_pytorch
|
||||
assert attention.REGISTERED_ATTENTION_FUNCTIONS["var_attention_sub_quad"] is attention.var_attention_sub_quad
|
||||
assert attention.REGISTERED_ATTENTION_FUNCTIONS["var_attention_split"] is attention.var_attention_split
|
||||
assert (
|
||||
attention.REGISTERED_ATTENTION_FUNCTIONS["var_attention_optimized_split"]
|
||||
is attention.var_attention_optimized_split
|
||||
)
|
||||
|
||||
|
||||
def test_seedvr2_7b_swin_attention_forward_uses_optimized_var_attention(monkeypatch):
|
||||
@ -285,105 +279,63 @@ def test_seedvr2_7b_swin_attention_forward_uses_optimized_var_attention(monkeypa
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# var_attention_pytorch SeedVR2 guard tests
|
||||
# (test_var_attention_pytorch_seedvr2_guard.py)
|
||||
# ---------------------------------------------------------------------------
|
||||
def test_var_attention_optimized_split_calls_dense_backend_per_window(monkeypatch):
|
||||
heads = 2
|
||||
head_dim = 3
|
||||
q = torch.arange(30, dtype=torch.float32).reshape(5, heads, head_dim)
|
||||
k = q + 100
|
||||
v = q + 200
|
||||
cu = torch.tensor([0, 2, 5], dtype=torch.int32)
|
||||
calls = []
|
||||
|
||||
def _pytorch_guard_inputs():
    """Build a small packed q/k/v fixture for the guard tests.

    Returns ``(q, k, v, heads, cu_q, cu_k, total_tokens, embed_dim)`` where
    q/k/v are random ``[6, 16]`` tensors, ``heads`` is 2 and both offset
    entries are the same ``[0, 3, 6]`` int32 tensor (two sequences of three
    tokens).
    """
    heads = 2
    head_dim = 8
    total_tokens = 6
    embed_dim = heads * head_dim
    packed_shape = (total_tokens, embed_dim)
    # Draw order q, k, v is kept so seeded runs see the same values.
    q = torch.randn(*packed_shape)
    k = torch.randn(*packed_shape)
    v = torch.randn(*packed_shape)
    offsets = torch.tensor([0, 3, 6], dtype=torch.int32)
    return q, k, v, heads, offsets, offsets, total_tokens, embed_dim
|
||||
def fake_optimized_attention(q_arg, k_arg, v_arg, heads_arg, **kwargs):
|
||||
calls.append(
|
||||
{
|
||||
"q_shape": tuple(q_arg.shape),
|
||||
"k_shape": tuple(k_arg.shape),
|
||||
"v_shape": tuple(v_arg.shape),
|
||||
"heads": heads_arg,
|
||||
"kwargs": kwargs,
|
||||
}
|
||||
)
|
||||
return q_arg + v_arg
|
||||
|
||||
monkeypatch.setattr(attention, "optimized_attention", fake_optimized_attention)
|
||||
|
||||
def _assert_guard_source_pin():
|
||||
src = textwrap.dedent(inspect.getsource(var_attention_pytorch))
|
||||
tree = ast.parse(src)
|
||||
raise_lines = []
|
||||
nested_lines = []
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.Raise) and isinstance(node.exc, ast.Call):
|
||||
func = node.exc.func
|
||||
if isinstance(func, ast.Name) and func.id == "RuntimeError":
|
||||
raise_lines.append(node.lineno)
|
||||
if isinstance(node, ast.Attribute) and node.attr == "nested_tensor_from_jagged":
|
||||
nested_lines.append(node.lineno)
|
||||
assert raise_lines, (
|
||||
"var_attention_pytorch has no `raise RuntimeError(...)` AST node; "
|
||||
f"the SeedVR2-named guard is missing.\n--- source ---\n{src}"
|
||||
)
|
||||
assert nested_lines, (
|
||||
"var_attention_pytorch source has no `nested_tensor_from_jagged` "
|
||||
f"attribute access; cannot pin guard ordering.\n"
|
||||
f"--- source ---\n{src}"
|
||||
)
|
||||
first_raise = min(raise_lines)
|
||||
first_nested = min(nested_lines)
|
||||
assert first_raise < first_nested, (
|
||||
f"`raise RuntimeError(...)` first appears at line {first_raise}, "
|
||||
f"but `torch.nested.nested_tensor_from_jagged` is referenced first "
|
||||
f"at line {first_nested}; the guard must precede the lookup.\n"
|
||||
f"--- source ---\n{src}"
|
||||
out = var_attention_optimized_split(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
heads,
|
||||
cu,
|
||||
cu,
|
||||
skip_reshape=True,
|
||||
skip_output_reshape=True,
|
||||
)
|
||||
|
||||
|
||||
def test_missing_api_raises_seedvr2_runtime_error(monkeypatch):
    """With the jagged-tensor API removed, the SeedVR2-named guard must fire."""
    monkeypatch.delattr(torch.nested, "nested_tensor_from_jagged", raising=False)
    fixture = _pytorch_guard_inputs()
    # First six entries are q, k, v, heads, cu_q, cu_k, in call order.
    call_args = fixture[:6]

    with pytest.raises(RuntimeError, match=r"SeedVR2.*nested_tensor_from_jagged"):
        var_attention_pytorch(*call_args)

    _assert_guard_source_pin()
|
||||
assert tuple(out.shape) == (5, heads, head_dim)
|
||||
assert len(calls) == 2
|
||||
assert calls[0]["q_shape"] == (1, heads, 2, head_dim)
|
||||
assert calls[1]["q_shape"] == (1, heads, 3, head_dim)
|
||||
assert all(call["heads"] == heads for call in calls)
|
||||
assert all(call["kwargs"]["skip_reshape"] is True for call in calls)
|
||||
assert all(call["kwargs"]["skip_output_reshape"] is True for call in calls)
|
||||
torch.testing.assert_close(out, q + v, rtol=0, atol=0)
|
||||
|
||||
|
||||
def test_missing_namespace_raises_seedvr2_runtime_error(monkeypatch):
|
||||
monkeypatch.delattr(torch, "nested", raising=False)
|
||||
q, k, v, heads, cu_q, cu_k, _, _ = _pytorch_guard_inputs()
|
||||
def test_var_attention_optimized_split_rejects_bad_offsets():
|
||||
q = torch.randn(5, 2, 3)
|
||||
cu_bad = torch.tensor([0, 2, 6], dtype=torch.int32)
|
||||
cu_ok = torch.tensor([0, 2, 5], dtype=torch.int32)
|
||||
|
||||
with pytest.raises(RuntimeError, match=r"SeedVR2.*nested_tensor_from_jagged"):
|
||||
var_attention_pytorch(q, k, v, heads, cu_q, cu_k)
|
||||
|
||||
_assert_guard_source_pin()
|
||||
|
||||
|
||||
def test_present_api_returns_expected_shape():
    """With the nested API present, the output is ``(total_tokens, embed_dim)``."""
    q, k, v, heads, cu_q, cu_k, total_tokens, embed_dim = _pytorch_guard_inputs()

    # Raise the torch.fx tracing logger to ERROR for the call; restored below.
    fx_log = logging.getLogger("torch.fx._symbolic_trace")
    saved_level = fx_log.level
    fx_log.setLevel(logging.ERROR)
    try:
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                message="The PyTorch API of nested tensors is in prototype stage.*",
                category=UserWarning,
            )
            out = var_attention_pytorch(q, k, v, heads, cu_q, cu_k)
    finally:
        fx_log.setLevel(saved_level)

    actual_shape = tuple(out.shape)
    assert actual_shape == (total_tokens, embed_dim), (
        f"expected ({total_tokens}, {embed_dim}); got {tuple(out.shape)}"
    )

    _assert_guard_source_pin()
|
||||
|
||||
|
||||
def test_malformed_offsets_propagates_torch_runtime_error():
    """Bad offsets must raise a RuntimeError that is not the SeedVR2 guard message."""
    q, k, v, heads, *_unused = _pytorch_guard_inputs()
    # Query offsets end at 7 while the fixture only holds 6 tokens.
    bad_q_offsets = torch.tensor([0, 3, 7], dtype=torch.int32)
    good_k_offsets = torch.tensor([0, 3, 6], dtype=torch.int32)

    with pytest.raises(RuntimeError) as raised:
        var_attention_pytorch(q, k, v, heads, bad_q_offsets, good_k_offsets)

    assert "SeedVR2" not in str(raised.value)

    _assert_guard_source_pin()
|
||||
with pytest.raises(ValueError, match="cu_seqlens_q does not match token count"):
|
||||
var_attention_optimized_split(
|
||||
q,
|
||||
q,
|
||||
q,
|
||||
2,
|
||||
cu_bad,
|
||||
cu_ok,
|
||||
skip_reshape=True,
|
||||
skip_output_reshape=True,
|
||||
)
|
||||
|
||||
@ -213,11 +213,11 @@ def test_seedvr2_7b_rope3d_matches_wrapper_oracle():
|
||||
shape = torch.tensor([[1, 2, 2]], dtype=torch.long)
|
||||
freqs = rope.get_axial_freqs(1, 2, 2).reshape(4, -1)
|
||||
|
||||
expected_q = seedvr_model.apply_rotary_emb(
|
||||
expected_q = seedvr_model._apply_seedvr2_rotary_emb(
|
||||
freqs,
|
||||
q.permute(1, 0, 2).float(),
|
||||
).to(q.dtype).permute(1, 0, 2)
|
||||
expected_k = seedvr_model.apply_rotary_emb(
|
||||
expected_k = seedvr_model._apply_seedvr2_rotary_emb(
|
||||
freqs,
|
||||
k.permute(1, 0, 2).float(),
|
||||
).to(k.dtype).permute(1, 0, 2)
|
||||
|
||||
@ -174,10 +174,7 @@ def test_decode_tiled_vae_maps_temporal_args_to_latent_slicing_min_size():
|
||||
captured.update(kwargs)
|
||||
return torch.zeros(1, 3, 1, 16, 16)
|
||||
|
||||
with (
|
||||
patch.object(vae_mod, "tiled_vae", side_effect=_fake_tiled_vae),
|
||||
patch.object(vae_mod, "lab_color_transfer", side_effect=lambda content, style: content),
|
||||
):
|
||||
with patch.object(vae_mod, "tiled_vae", side_effect=_fake_tiled_vae):
|
||||
wrapper.decode(torch.zeros(1, 16, 2, 2), seedvr2_tiling=seedvr2_tiling)
|
||||
|
||||
assert captured["temporal_overlap"] == 7
|
||||
|
||||
Loading…
Reference in New Issue
Block a user