Replace SeedVR2 custom varlen attention backends and fix 7B RoPE

This commit is contained in:
John Pollock 2026-06-02 21:12:34 -05:00
parent 22078c799b
commit 529b9232f0
8 changed files with 125 additions and 496 deletions

View File

@ -32,14 +32,6 @@ except ImportError as e:
raise e
exit(-1)
SAGE_ATTENTION_VARLEN_IS_AVAILABLE = False
try:
from sageattention import sageattn_varlen
SAGE_ATTENTION_VARLEN_IS_AVAILABLE = True
except ImportError:
if model_management.sage_attention_enabled():
logging.warning("SageAttention variable-length attention is unavailable, using pytorch var-len attention instead.")
SAGE_ATTENTION3_IS_AVAILABLE = False
try:
from sageattn3 import sageattn3_blackwell
@ -48,7 +40,6 @@ except ImportError:
pass
FLASH_ATTENTION_IS_AVAILABLE = False
FLASH_ATTENTION_VARLEN_IS_AVAILABLE = False
try:
from flash_attn import flash_attn_func
FLASH_ATTENTION_IS_AVAILABLE = True
@ -57,20 +48,6 @@ except ImportError:
logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
exit(-1)
try:
from flash_attn import flash_attn_varlen_func
FLASH_ATTENTION_VARLEN_IS_AVAILABLE = True
except ImportError:
if model_management.flash_attention_enabled() and FLASH_ATTENTION_IS_AVAILABLE:
logging.warning("Flash Attention variable-length attention is unavailable, using pytorch var-len attention instead.")
FLASH_ATTENTION3_IS_AVAILABLE = False
try:
from flash_attn_interface import flash_attn_varlen_func as flash_attn3_varlen_func
FLASH_ATTENTION3_IS_AVAILABLE = True
except ImportError:
pass
REGISTERED_ATTENTION_FUNCTIONS = {}
def register_attention_function(name: str, func: Callable):
# avoid replacing existing functions
@ -758,18 +735,6 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
)
return out
_VAR_ATTENTION_NESTED_API_NAME = "nested_tensor_from_jagged"
_VAR_ATTENTION_GUARD_MESSAGE = (
"SeedVR2 var_attention_pytorch: torch.nested.nested_tensor_from_jagged "
"is required by this attention path; the installed PyTorch build "
"does not provide it"
)
def _var_attention_max_seqlen(cu_seqlens):
return int((cu_seqlens[1:] - cu_seqlens[:-1]).max().item())
def _var_attention_qkv(q, k, v, heads, skip_reshape):
if skip_reshape:
return q, k, v, q.shape[-1]
@ -797,34 +762,6 @@ def _use_blackwell_attention():
return (major, minor) >= (12, 0)
def var_attention_pytorch(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, skip_reshape=False, skip_output_reshape=False):
_nested = getattr(torch, "nested", None)
if _nested is None or not hasattr(_nested, _VAR_ATTENTION_NESTED_API_NAME):
raise RuntimeError(_VAR_ATTENTION_GUARD_MESSAGE)
if not skip_reshape:
# assumes 2D q, k,v [total_tokens, embed_dim]
total_tokens, embed_dim = q.shape
head_dim = embed_dim // heads
q = q.view(total_tokens, heads, head_dim)
k = k.view(k.shape[0], heads, head_dim)
v = v.view(v.shape[0], heads, head_dim)
q = torch.nested.nested_tensor_from_jagged(q, offsets=cu_seqlens_q.long())
k = torch.nested.nested_tensor_from_jagged(k, offsets=cu_seqlens_k.long())
v = torch.nested.nested_tensor_from_jagged(v, offsets=cu_seqlens_k.long())
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out = out.transpose(1, 2)
if not skip_output_reshape:
return out.values().reshape(-1, heads * (q.shape[-1]))
return out.values()
def _validate_split_cu_seqlens(name, cu_seqlens, token_count):
if cu_seqlens.dtype not in (torch.int32, torch.int64):
raise ValueError(f"{name} must use an integer dtype")
@ -842,7 +779,7 @@ def _split_indices(cu_seqlens):
return cu_seqlens[1:-1].to(device="cpu", dtype=torch.long)
def var_attention_pytorch_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, skip_reshape=False, skip_output_reshape=False):
def var_attention_optimized_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)
_validate_split_cu_seqlens("cu_seqlens_q", cu_seqlens_q, q.shape[0])
@ -863,329 +800,44 @@ def var_attention_pytorch_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, skip
q_i = q_i.permute(1, 0, 2).unsqueeze(0)
k_i = k_i.permute(1, 0, 2).unsqueeze(0)
v_i = v_i.permute(1, 0, 2).unsqueeze(0)
out_i = comfy.ops.scaled_dot_product_attention(q_i, k_i, v_i, attn_mask=None, dropout_p=0.0, is_causal=False)
out_dtype = q_i.dtype
if optimized_attention is attention_sage and q_i.dtype not in (torch.float16, torch.bfloat16):
q_i = q_i.to(torch.bfloat16)
k_i = k_i.to(torch.bfloat16)
v_i = v_i.to(torch.bfloat16)
out_i = optimized_attention(q_i, k_i, v_i, heads, skip_reshape=True, skip_output_reshape=True)
if out_i.dtype != out_dtype:
out_i = out_i.to(out_dtype)
out.append(out_i.squeeze(0).permute(1, 0, 2))
out = torch.cat(out, dim=0)
return _var_attention_output(out, heads, head_dim, skip_output_reshape)
@torch._dynamo.disable
def var_attention_sage(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
if not SAGE_ATTENTION_VARLEN_IS_AVAILABLE:
return var_attention_pytorch(
q,
k,
v,
heads,
cu_seqlens_q,
cu_seqlens_k,
skip_reshape=skip_reshape,
skip_output_reshape=skip_output_reshape,
)
q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)
out_dtype = q.dtype
if not (q.dtype == k.dtype == v.dtype):
k = k.to(q.dtype)
v = v.to(q.dtype)
fallback_q, fallback_k, fallback_v = q, k, v
if q.dtype not in (torch.float16, torch.bfloat16):
q = q.to(torch.bfloat16)
k = k.to(torch.bfloat16)
v = v.to(torch.bfloat16)
sm_scale = kwargs.get("softmax_scale")
if sm_scale is None:
sm_scale = 1.0 / math.sqrt(head_dim)
try:
out = sageattn_varlen(
q,
k,
v,
cu_seqlens_q.int(),
cu_seqlens_k.int(),
_var_attention_max_seqlen(cu_seqlens_q),
_var_attention_max_seqlen(cu_seqlens_k),
kwargs.get("causal", False),
sm_scale,
)
except Exception as e:
logging.error("Error running sage var-len attention: %s, using pytorch var-len attention instead.", e)
out = var_attention_pytorch(
fallback_q,
fallback_k,
fallback_v,
heads,
cu_seqlens_q,
cu_seqlens_k,
skip_reshape=True,
skip_output_reshape=skip_output_reshape,
)
if out.dtype != out_dtype:
out = out.to(out_dtype)
return out
if out.dtype != out_dtype:
out = out.to(out_dtype)
return _var_attention_output(out, heads, head_dim, skip_output_reshape)
@torch._dynamo.disable
def var_attention_sage3(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
if not SAGE_ATTENTION3_IS_AVAILABLE:
if SAGE_ATTENTION_VARLEN_IS_AVAILABLE:
return var_attention_sage(
q,
k,
v,
heads,
cu_seqlens_q,
cu_seqlens_k,
skip_reshape=skip_reshape,
skip_output_reshape=skip_output_reshape,
**kwargs,
)
return var_attention_pytorch(
q,
k,
v,
heads,
cu_seqlens_q,
cu_seqlens_k,
skip_reshape=skip_reshape,
skip_output_reshape=skip_output_reshape,
)
q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)
seq_lens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
seq_lens_k = cu_seqlens_k[1:] - cu_seqlens_k[:-1]
uniform_q = bool((seq_lens_q == seq_lens_q[0]).all().item())
uniform_k = bool((seq_lens_k == seq_lens_k[0]).all().item())
if not (uniform_q and uniform_k and seq_lens_q[0] == seq_lens_k[0]):
if SAGE_ATTENTION_VARLEN_IS_AVAILABLE:
return var_attention_sage(
q,
k,
v,
heads,
cu_seqlens_q,
cu_seqlens_k,
skip_reshape=True,
skip_output_reshape=skip_output_reshape,
**kwargs,
)
return var_attention_pytorch(
q,
k,
v,
heads,
cu_seqlens_q,
cu_seqlens_k,
skip_reshape=True,
skip_output_reshape=skip_output_reshape,
)
out_dtype = q.dtype
if not (q.dtype == k.dtype == v.dtype):
k = k.to(q.dtype)
v = v.to(q.dtype)
fallback_q, fallback_k, fallback_v = q, k, v
if q.dtype not in (torch.float16, torch.bfloat16):
q = q.to(torch.bfloat16)
k = k.to(torch.bfloat16)
v = v.to(torch.bfloat16)
batch_size = len(cu_seqlens_q) - 1
seq_len = int(seq_lens_q[0].item())
q = q.view(batch_size, seq_len, heads, head_dim).transpose(1, 2)
k = k.view(batch_size, seq_len, heads, head_dim).transpose(1, 2)
v = v.view(batch_size, seq_len, heads, head_dim).transpose(1, 2)
try:
out = sageattn3_blackwell(q, k, v, is_causal=kwargs.get("causal", False))
except Exception as e:
logging.error("Error running SageAttention3 var-len attention: %s, using fallback var-len attention instead.", e)
if SAGE_ATTENTION_VARLEN_IS_AVAILABLE:
return var_attention_sage(
fallback_q,
fallback_k,
fallback_v,
heads,
cu_seqlens_q,
cu_seqlens_k,
skip_reshape=True,
skip_output_reshape=skip_output_reshape,
**kwargs,
)
return var_attention_pytorch(
fallback_q,
fallback_k,
fallback_v,
heads,
cu_seqlens_q,
cu_seqlens_k,
skip_reshape=True,
skip_output_reshape=skip_output_reshape,
)
out = out.transpose(1, 2).reshape(-1, heads, head_dim).contiguous()
if out.dtype != out_dtype:
out = out.to(out_dtype)
return _var_attention_output(out, heads, head_dim, skip_output_reshape)
@torch._dynamo.disable
def var_attention_flash(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
if not FLASH_ATTENTION_VARLEN_IS_AVAILABLE:
return var_attention_pytorch(
q,
k,
v,
heads,
cu_seqlens_q,
cu_seqlens_k,
skip_reshape=skip_reshape,
skip_output_reshape=skip_output_reshape,
)
q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)
max_seqlen_q = _var_attention_max_seqlen(cu_seqlens_q)
max_seqlen_k = _var_attention_max_seqlen(cu_seqlens_k)
try:
out = flash_attn_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens_q.int(),
cu_seqlens_k=cu_seqlens_k.int(),
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
dropout_p=kwargs.get("dropout_p", 0.0),
causal=kwargs.get("causal", False),
deterministic=torch.are_deterministic_algorithms_enabled(),
)
except Exception as e:
logging.error("Error running Flash Attention var-len attention: %s, using pytorch var-len attention instead.", e)
return var_attention_pytorch(
q,
k,
v,
heads,
cu_seqlens_q,
cu_seqlens_k,
skip_reshape=True,
skip_output_reshape=skip_output_reshape,
)
return _var_attention_output(out, heads, head_dim, skip_output_reshape)
@torch._dynamo.disable
def var_attention_flash3(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
if not FLASH_ATTENTION3_IS_AVAILABLE:
return var_attention_pytorch(
q,
k,
v,
heads,
cu_seqlens_q,
cu_seqlens_k,
skip_reshape=skip_reshape,
skip_output_reshape=skip_output_reshape,
)
q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)
max_seqlen_q = _var_attention_max_seqlen(cu_seqlens_q)
max_seqlen_k = _var_attention_max_seqlen(cu_seqlens_k)
try:
out = flash_attn3_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens_q.int(),
cu_seqlens_k=cu_seqlens_k.int(),
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
seqused_q=None,
seqused_k=None,
softmax_scale=kwargs.get("softmax_scale"),
causal=kwargs.get("causal", False),
deterministic=torch.are_deterministic_algorithms_enabled(),
)
except Exception as e:
logging.error("Error running Flash Attention 3 var-len attention: %s, using pytorch var-len attention instead.", e)
return var_attention_pytorch(
q,
k,
v,
heads,
cu_seqlens_q,
cu_seqlens_k,
skip_reshape=True,
skip_output_reshape=skip_output_reshape,
)
if isinstance(out, tuple):
out = out[0]
return _var_attention_output(out, heads, head_dim, skip_output_reshape)
@torch._dynamo.disable
def var_attention_sub_quad(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
return var_attention_pytorch(
q,
k,
v,
heads,
cu_seqlens_q,
cu_seqlens_k,
skip_reshape=skip_reshape,
skip_output_reshape=skip_output_reshape,
)
@torch._dynamo.disable
def var_attention_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
return var_attention_pytorch_split(
q,
k,
v,
heads,
cu_seqlens_q,
cu_seqlens_k,
skip_reshape=skip_reshape,
skip_output_reshape=skip_output_reshape,
)
optimized_var_attention = var_attention_pytorch
optimized_var_attention = var_attention_optimized_split
optimized_attention = attention_basic
if model_management.sage_attention_enabled():
logging.info("Using sage attention")
optimized_attention = attention_sage
if SAGE_ATTENTION3_IS_AVAILABLE and _use_blackwell_attention():
logging.info("Using SageAttention3 for variable-length attention")
optimized_var_attention = var_attention_sage3
elif SAGE_ATTENTION_VARLEN_IS_AVAILABLE:
logging.info("Using SageAttention for variable-length attention")
optimized_var_attention = var_attention_sage
else:
logging.info("Using pytorch attention for variable-length attention")
optimized_var_attention = var_attention_pytorch
elif model_management.flash_attention_enabled():
logging.info("Using Flash Attention")
optimized_attention = attention_flash
if FLASH_ATTENTION_VARLEN_IS_AVAILABLE and model_management.get_torch_device().type == "cuda":
logging.info("Using Flash Attention 2 for variable-length attention")
optimized_var_attention = var_attention_flash
else:
logging.info("Using pytorch attention for variable-length attention")
optimized_var_attention = var_attention_pytorch
elif model_management.xformers_enabled():
logging.info("Using xformers attention")
optimized_attention = attention_xformers
elif model_management.pytorch_attention_enabled():
logging.info("Using pytorch attention")
optimized_attention = attention_pytorch
optimized_var_attention = var_attention_pytorch
else:
if args.use_split_cross_attention:
logging.info("Using split optimization for attention")
optimized_attention = attention_split
optimized_var_attention = var_attention_split
else:
logging.info("Using sub quadratic optimization for attention, if you have memory or speed issues try using: --use-split-cross-attention")
optimized_attention = attention_sub_quad
optimized_var_attention = var_attention_sub_quad
logging.info("Using optimized_attention split-loop for variable-length attention")
optimized_attention_masked = optimized_attention
@ -1193,25 +845,16 @@ optimized_attention_masked = optimized_attention
# register core-supported attention functions
if SAGE_ATTENTION_IS_AVAILABLE:
register_attention_function("sage", attention_sage)
if SAGE_ATTENTION_VARLEN_IS_AVAILABLE:
register_attention_function("var_attention_sage", var_attention_sage)
if SAGE_ATTENTION3_IS_AVAILABLE:
register_attention_function("sage3", attention3_sage)
register_attention_function("var_attention_sage3", var_attention_sage3)
if FLASH_ATTENTION_IS_AVAILABLE:
register_attention_function("flash", attention_flash)
if FLASH_ATTENTION_VARLEN_IS_AVAILABLE:
register_attention_function("var_attention_flash", var_attention_flash)
if FLASH_ATTENTION3_IS_AVAILABLE:
register_attention_function("var_attention_flash3", var_attention_flash3)
if model_management.xformers_enabled():
register_attention_function("xformers", attention_xformers)
register_attention_function("pytorch", attention_pytorch)
register_attention_function("var_attention_pytorch", var_attention_pytorch)
register_attention_function("sub_quad", attention_sub_quad)
register_attention_function("var_attention_sub_quad", var_attention_sub_quad)
register_attention_function("split", attention_split)
register_attention_function("var_attention_split", var_attention_split)
register_attention_function("var_attention_optimized_split", var_attention_optimized_split)
def optimized_attention_for_device(device, mask=False, small_input=False):

View File

@ -40,7 +40,7 @@ SEEDVR2_WAVELET_SCALE_MULTIPLIER = 10 # per-frame byte multiplier, wavelet path
SEEDVR2_ADAIN_SCALE_MULTIPLIER = 6 # per-frame byte multiplier, AdaIN path.
# --------------------------------------------------------------------------------------
# C. ByteDance config / source (BYTEDANCE - cite myseedvr2/SeedVR)
# C. ByteDance config / source (BYTEDANCE - cite ByteDance-Seed/SeedVR)
# --------------------------------------------------------------------------------------
BYTEDANCE_VAE_SCALING_FACTOR = 0.9152 # configs_3b/main.yaml:57 (scaling_factor); latent denorm.
BYTEDANCE_VAE_SHIFTING_FACTOR = 0.0 # infer.py (shifting_factor default); latent denorm shift.

View File

@ -58,10 +58,11 @@ class CustomRMSNorm(nn.Module):
dims = tuple(range(-len(self.normalized_shape), 0))
variance = input.pow(2).mean(dim=dims, keepdim=True)
normalized = input.float()
variance = normalized.pow(2).mean(dim=dims, keepdim=True)
rms = torch.sqrt(variance + self.eps)
normalized = input / rms
normalized = normalized / rms
if self.elementwise_affine:
return normalized * self.weight.to(input.dtype)
@ -472,8 +473,8 @@ class NaRotaryEmbedding3d(RotaryEmbedding3d):
freqs = freqs.to(device=q.device)
q = rearrange(q, "L h d -> h L d")
k = rearrange(k, "L h d -> h L d")
q = _apply_rope1_partial(q, freqs)
k = _apply_rope1_partial(k, freqs)
q = _apply_seedvr2_rotary_emb(freqs, q.float()).to(q.dtype)
k = _apply_seedvr2_rotary_emb(freqs, k.float()).to(k.dtype)
q = rearrange(q, "h L d -> L h d")
k = rearrange(k, "h L d -> L h d")
return q, k
@ -483,11 +484,20 @@ class NaRotaryEmbedding3d(RotaryEmbedding3d):
self,
shape: torch.LongTensor,
) -> torch.Tensor:
# Primary provenance: ByteDance-Seed/SeedVR models/dit/rope.py builds
# 7B pixel RoPE with the interleaved-angle convention, not Comfy's
# Flux freqs_cis matrix.
plain_rope = RotaryEmbedding(
dim=self.rope.freqs.numel() * 2,
freqs_for="pixel",
max_freq=BYTEDANCE_ROPE_MAX_FREQ,
)
plain_rope = plain_rope.to(self.rope.dummy.device)
freq_list = []
for f, h, w in shape.tolist():
freqs = self.get_axial_freqs(f, h, w)
freqs = plain_rope.get_axial_freqs(f, h, w)
freq_list.append(freqs.view(-1, freqs.size(-1)))
return _to_flux_freqs_cis(torch.cat(freq_list, dim=0))
return torch.cat(freq_list, dim=0)
class MMRotaryEmbeddingBase(RotaryEmbeddingBase):
@ -556,6 +566,36 @@ def apply_rotary_emb(
out = torch.cat((t_left, t_middle_out, t_right), dim=-1)
return out.type(dtype)
def _apply_seedvr2_rotary_emb(
freqs: torch.Tensor,
t: torch.Tensor,
start_index: int = 0,
scale: float = 1.0,
seq_dim: int = -2,
freqs_seq_dim: int | None = None,
) -> torch.Tensor:
dtype = t.dtype
if freqs_seq_dim is None and (freqs.ndim == 2 or t.ndim == 3):
freqs_seq_dim = 0
if t.ndim == 3 or freqs_seq_dim is not None:
seq_len = t.shape[seq_dim]
freqs = slice_at_dim(freqs, slice(-seq_len, None), dim=freqs_seq_dim)
rot_feats = freqs.shape[-1]
end_index = start_index + rot_feats
t_left = t[..., :start_index]
t_middle = t[..., start_index:end_index]
t_right = t[..., end_index:]
freqs = freqs.to(device=t_middle.device, dtype=t_middle.dtype)
cos = freqs.cos() * scale
sin = freqs.sin() * scale
t_middle = (t_middle * cos) + (rotate_half(t_middle) * sin)
return torch.cat((t_left, t_middle, t_right), dim=-1).to(dtype)
def _to_flux_freqs_cis(freqs_interleaved: torch.Tensor) -> torch.Tensor:
"""Convert lucidrains-interleaved freqs to flux-canonical fp32 freqs_cis `[..., d/2, 2, 2]` (cos/-sin/sin/cos), per `comfy/ldm/flux/math.py:rope`."""
angles = freqs_interleaved[..., ::2].float()
@ -1380,6 +1420,8 @@ class NaDiT(nn.Module):
**kwargs,
):
self._7b_version = vid_dim == SEEDVR2_7B_VID_DIM
if self._7b_version:
rope_type = "rope3d"
self.dtype = dtype
factory_kwargs = {"device": device, "dtype": dtype}
window_method = num_layers // 2 * ["720pwin_by_size_bysize","720pswin_by_size_bysize"]

View File

@ -24,12 +24,7 @@ from comfy.ldm.seedvr.constants import (
BYTEDANCE_VAE_SHIFTING_FACTOR,
BYTEDANCE_VAE_SPATIAL_DOWNSAMPLE,
BYTEDANCE_VAE_TEMPORAL_DOWNSAMPLE,
CIELAB_DELTA,
CIELAB_KAPPA,
D65_WHITE_X,
D65_WHITE_Z,
SEEDVR2_LATENT_CHANNELS,
WAVELET_DECOMP_LEVELS,
)
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.modules.diffusionmodules.model import vae_attention

View File

@ -575,7 +575,7 @@ class SeedVR2Conditioning(io.ComfyNode):
description="Build SeedVR2 positive/negative conditioning from a VAE latent.",
inputs=[
io.Model.Input("model", tooltip="The SeedVR2 model."),
io.Latent.Input("vae_conditioning", tooltip="The VAE-encoded latent to condition on."),
io.Latent.Input("vae_conditioning", display_name="LATENT", tooltip="The VAE-encoded latent to condition on."),
],
outputs=[
io.Model.Output(display_name = "model"),

View File

@ -6,9 +6,7 @@ Sources (all merged verbatim, helper names disambiguated where colliding):
apply_rotary_emb wrapper oracle at fp32.
* GroupNorm limit gate causal_norm_wrapper at vae.py:509 must compare
memory_occupy against get_norm_limit(), not float('inf').
* var_attention backend registry.
* var_attention_pytorch SeedVR2-named guard present-API shape contract
with AST-level pinning of the guard ordering.
* SeedVR2 variable-length attention split-loop contract.
Pre-import CPU-only guard is required because comfy.ldm.seedvr.model and
comfy.ldm.modules.attention transitively pull in comfy.model_management,
@ -18,11 +16,6 @@ set first.
from __future__ import annotations
import ast
import inspect
import logging
import textwrap
import warnings
from unittest.mock import patch
import pytest
@ -45,7 +38,7 @@ from comfy.ldm.seedvr.vae import ( # noqa: E402
causal_norm_wrapper,
set_norm_limit,
)
from comfy.ldm.modules.attention import var_attention_pytorch # noqa: E402
from comfy.ldm.modules.attention import var_attention_optimized_split # noqa: E402
# ---------------------------------------------------------------------------
@ -215,13 +208,14 @@ def test_seedvr_groupnorm_low_limit_uses_chunked_groupnorm_path(groupnorm_cls):
# ---------------------------------------------------------------------------
# var_attention backend tests (test_seedvr_var_attention_backends.py)
# SeedVR2 var_attention split-loop tests
# ---------------------------------------------------------------------------
def test_var_attention_registry_contains_always_available_entries():
assert attention.REGISTERED_ATTENTION_FUNCTIONS["var_attention_pytorch"] is attention.var_attention_pytorch
assert attention.REGISTERED_ATTENTION_FUNCTIONS["var_attention_sub_quad"] is attention.var_attention_sub_quad
assert attention.REGISTERED_ATTENTION_FUNCTIONS["var_attention_split"] is attention.var_attention_split
assert (
attention.REGISTERED_ATTENTION_FUNCTIONS["var_attention_optimized_split"]
is attention.var_attention_optimized_split
)
def test_seedvr2_7b_swin_attention_forward_uses_optimized_var_attention(monkeypatch):
@ -285,105 +279,63 @@ def test_seedvr2_7b_swin_attention_forward_uses_optimized_var_attention(monkeypa
)
# ---------------------------------------------------------------------------
# var_attention_pytorch SeedVR2 guard tests
# (test_var_attention_pytorch_seedvr2_guard.py)
# ---------------------------------------------------------------------------
def test_var_attention_optimized_split_calls_dense_backend_per_window(monkeypatch):
heads = 2
head_dim = 3
q = torch.arange(30, dtype=torch.float32).reshape(5, heads, head_dim)
k = q + 100
v = q + 200
cu = torch.tensor([0, 2, 5], dtype=torch.int32)
calls = []
def _pytorch_guard_inputs():
heads, head_dim, total_tokens = 2, 8, 6
embed_dim = heads * head_dim
q = torch.randn(total_tokens, embed_dim)
k = torch.randn(total_tokens, embed_dim)
v = torch.randn(total_tokens, embed_dim)
cu = torch.tensor([0, 3, 6], dtype=torch.int32)
return q, k, v, heads, cu, cu, total_tokens, embed_dim
def fake_optimized_attention(q_arg, k_arg, v_arg, heads_arg, **kwargs):
calls.append(
{
"q_shape": tuple(q_arg.shape),
"k_shape": tuple(k_arg.shape),
"v_shape": tuple(v_arg.shape),
"heads": heads_arg,
"kwargs": kwargs,
}
)
return q_arg + v_arg
monkeypatch.setattr(attention, "optimized_attention", fake_optimized_attention)
def _assert_guard_source_pin():
src = textwrap.dedent(inspect.getsource(var_attention_pytorch))
tree = ast.parse(src)
raise_lines = []
nested_lines = []
for node in ast.walk(tree):
if isinstance(node, ast.Raise) and isinstance(node.exc, ast.Call):
func = node.exc.func
if isinstance(func, ast.Name) and func.id == "RuntimeError":
raise_lines.append(node.lineno)
if isinstance(node, ast.Attribute) and node.attr == "nested_tensor_from_jagged":
nested_lines.append(node.lineno)
assert raise_lines, (
"var_attention_pytorch has no `raise RuntimeError(...)` AST node; "
f"the SeedVR2-named guard is missing.\n--- source ---\n{src}"
)
assert nested_lines, (
"var_attention_pytorch source has no `nested_tensor_from_jagged` "
f"attribute access; cannot pin guard ordering.\n"
f"--- source ---\n{src}"
)
first_raise = min(raise_lines)
first_nested = min(nested_lines)
assert first_raise < first_nested, (
f"`raise RuntimeError(...)` first appears at line {first_raise}, "
f"but `torch.nested.nested_tensor_from_jagged` is referenced first "
f"at line {first_nested}; the guard must precede the lookup.\n"
f"--- source ---\n{src}"
out = var_attention_optimized_split(
q,
k,
v,
heads,
cu,
cu,
skip_reshape=True,
skip_output_reshape=True,
)
def test_missing_api_raises_seedvr2_runtime_error(monkeypatch):
monkeypatch.delattr(torch.nested, "nested_tensor_from_jagged", raising=False)
q, k, v, heads, cu_q, cu_k, _, _ = _pytorch_guard_inputs()
with pytest.raises(RuntimeError, match=r"SeedVR2.*nested_tensor_from_jagged"):
var_attention_pytorch(q, k, v, heads, cu_q, cu_k)
_assert_guard_source_pin()
assert tuple(out.shape) == (5, heads, head_dim)
assert len(calls) == 2
assert calls[0]["q_shape"] == (1, heads, 2, head_dim)
assert calls[1]["q_shape"] == (1, heads, 3, head_dim)
assert all(call["heads"] == heads for call in calls)
assert all(call["kwargs"]["skip_reshape"] is True for call in calls)
assert all(call["kwargs"]["skip_output_reshape"] is True for call in calls)
torch.testing.assert_close(out, q + v, rtol=0, atol=0)
def test_missing_namespace_raises_seedvr2_runtime_error(monkeypatch):
monkeypatch.delattr(torch, "nested", raising=False)
q, k, v, heads, cu_q, cu_k, _, _ = _pytorch_guard_inputs()
def test_var_attention_optimized_split_rejects_bad_offsets():
q = torch.randn(5, 2, 3)
cu_bad = torch.tensor([0, 2, 6], dtype=torch.int32)
cu_ok = torch.tensor([0, 2, 5], dtype=torch.int32)
with pytest.raises(RuntimeError, match=r"SeedVR2.*nested_tensor_from_jagged"):
var_attention_pytorch(q, k, v, heads, cu_q, cu_k)
_assert_guard_source_pin()
def test_present_api_returns_expected_shape():
q, k, v, heads, cu_q, cu_k, total_tokens, embed_dim = _pytorch_guard_inputs()
torch_fx_logger = logging.getLogger("torch.fx._symbolic_trace")
old_torch_fx_level = torch_fx_logger.level
torch_fx_logger.setLevel(logging.ERROR)
try:
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message="The PyTorch API of nested tensors is in prototype stage.*",
category=UserWarning,
)
out = var_attention_pytorch(q, k, v, heads, cu_q, cu_k)
finally:
torch_fx_logger.setLevel(old_torch_fx_level)
assert tuple(out.shape) == (total_tokens, embed_dim), (
f"expected ({total_tokens}, {embed_dim}); got {tuple(out.shape)}"
)
_assert_guard_source_pin()
def test_malformed_offsets_propagates_torch_runtime_error():
q, k, v, heads, _, _, _, _ = _pytorch_guard_inputs()
cu_q_bad = torch.tensor([0, 3, 7], dtype=torch.int32)
cu_k_ok = torch.tensor([0, 3, 6], dtype=torch.int32)
with pytest.raises(RuntimeError) as exc_info:
var_attention_pytorch(q, k, v, heads, cu_q_bad, cu_k_ok)
msg = str(exc_info.value)
assert "SeedVR2" not in msg
_assert_guard_source_pin()
with pytest.raises(ValueError, match="cu_seqlens_q does not match token count"):
var_attention_optimized_split(
q,
q,
q,
2,
cu_bad,
cu_ok,
skip_reshape=True,
skip_output_reshape=True,
)

View File

@ -213,11 +213,11 @@ def test_seedvr2_7b_rope3d_matches_wrapper_oracle():
shape = torch.tensor([[1, 2, 2]], dtype=torch.long)
freqs = rope.get_axial_freqs(1, 2, 2).reshape(4, -1)
expected_q = seedvr_model.apply_rotary_emb(
expected_q = seedvr_model._apply_seedvr2_rotary_emb(
freqs,
q.permute(1, 0, 2).float(),
).to(q.dtype).permute(1, 0, 2)
expected_k = seedvr_model.apply_rotary_emb(
expected_k = seedvr_model._apply_seedvr2_rotary_emb(
freqs,
k.permute(1, 0, 2).float(),
).to(k.dtype).permute(1, 0, 2)

View File

@ -174,10 +174,7 @@ def test_decode_tiled_vae_maps_temporal_args_to_latent_slicing_min_size():
captured.update(kwargs)
return torch.zeros(1, 3, 1, 16, 16)
with (
patch.object(vae_mod, "tiled_vae", side_effect=_fake_tiled_vae),
patch.object(vae_mod, "lab_color_transfer", side_effect=lambda content, style: content),
):
with patch.object(vae_mod, "tiled_vae", side_effect=_fake_tiled_vae):
wrapper.decode(torch.zeros(1, 16, 2, 2), seedvr2_tiling=seedvr2_tiling)
assert captured["temporal_overlap"] == 7