diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 8507557d5..b78e764c7 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -32,14 +32,6 @@ except ImportError as e: raise e exit(-1) -SAGE_ATTENTION_VARLEN_IS_AVAILABLE = False -try: - from sageattention import sageattn_varlen - SAGE_ATTENTION_VARLEN_IS_AVAILABLE = True -except ImportError: - if model_management.sage_attention_enabled(): - logging.warning("SageAttention variable-length attention is unavailable, using pytorch var-len attention instead.") - SAGE_ATTENTION3_IS_AVAILABLE = False try: from sageattn3 import sageattn3_blackwell @@ -48,7 +40,6 @@ except ImportError: pass FLASH_ATTENTION_IS_AVAILABLE = False -FLASH_ATTENTION_VARLEN_IS_AVAILABLE = False try: from flash_attn import flash_attn_func FLASH_ATTENTION_IS_AVAILABLE = True @@ -57,20 +48,6 @@ except ImportError: logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn") exit(-1) -try: - from flash_attn import flash_attn_varlen_func - FLASH_ATTENTION_VARLEN_IS_AVAILABLE = True -except ImportError: - if model_management.flash_attention_enabled() and FLASH_ATTENTION_IS_AVAILABLE: - logging.warning("Flash Attention variable-length attention is unavailable, using pytorch var-len attention instead.") - -FLASH_ATTENTION3_IS_AVAILABLE = False -try: - from flash_attn_interface import flash_attn_varlen_func as flash_attn3_varlen_func - FLASH_ATTENTION3_IS_AVAILABLE = True -except ImportError: - pass - REGISTERED_ATTENTION_FUNCTIONS = {} def register_attention_function(name: str, func: Callable): # avoid replacing existing functions @@ -758,18 +735,6 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape ) return out -_VAR_ATTENTION_NESTED_API_NAME = "nested_tensor_from_jagged" -_VAR_ATTENTION_GUARD_MESSAGE = ( - "SeedVR2 var_attention_pytorch: torch.nested.nested_tensor_from_jagged " - "is required by this attention path; the installed PyTorch build " - "does not provide it" -) - - -def _var_attention_max_seqlen(cu_seqlens): - return int((cu_seqlens[1:] - cu_seqlens[:-1]).max().item()) - - def _var_attention_qkv(q, k, v, heads, skip_reshape): if skip_reshape: return q, k, v, q.shape[-1] @@ -797,34 +762,6 @@ def _use_blackwell_attention(): return (major, minor) >= (12, 0) -def var_attention_pytorch(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, skip_reshape=False, skip_output_reshape=False): - _nested = getattr(torch, "nested", None) - if _nested is None or not hasattr(_nested, _VAR_ATTENTION_NESTED_API_NAME): - raise RuntimeError(_VAR_ATTENTION_GUARD_MESSAGE) - - if not skip_reshape: - # assumes 2D q, k,v [total_tokens, embed_dim] - total_tokens, embed_dim = q.shape - head_dim = embed_dim // heads - q = q.view(total_tokens, heads, head_dim) - k = k.view(k.shape[0], heads, head_dim) - v = v.view(v.shape[0], heads, head_dim) - - q = torch.nested.nested_tensor_from_jagged(q, offsets=cu_seqlens_q.long()) - k = torch.nested.nested_tensor_from_jagged(k, offsets=cu_seqlens_k.long()) - v = torch.nested.nested_tensor_from_jagged(v, offsets=cu_seqlens_k.long()) - - q = q.transpose(1, 2) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) - - out = out.transpose(1, 2) - if not skip_output_reshape: - return out.values().reshape(-1, heads * (q.shape[-1])) - return out.values() - - def _validate_split_cu_seqlens(name, cu_seqlens, token_count): if cu_seqlens.dtype not in (torch.int32, torch.int64): raise ValueError(f"{name} must use an integer dtype") @@ -842,7 +779,7 @@ def _split_indices(cu_seqlens): return cu_seqlens[1:-1].to(device="cpu", dtype=torch.long) -def var_attention_pytorch_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, skip_reshape=False, skip_output_reshape=False): +def var_attention_optimized_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs): q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape) _validate_split_cu_seqlens("cu_seqlens_q", cu_seqlens_q, q.shape[0]) @@ -863,329 +800,44 @@ def var_attention_pytorch_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, skip q_i = q_i.permute(1, 0, 2).unsqueeze(0) k_i = k_i.permute(1, 0, 2).unsqueeze(0) v_i = v_i.permute(1, 0, 2).unsqueeze(0) - out_i = comfy.ops.scaled_dot_product_attention(q_i, k_i, v_i, attn_mask=None, dropout_p=0.0, is_causal=False) + out_dtype = q_i.dtype + if optimized_attention is attention_sage and q_i.dtype not in (torch.float16, torch.bfloat16): + q_i = q_i.to(torch.bfloat16) + k_i = k_i.to(torch.bfloat16) + v_i = v_i.to(torch.bfloat16) + out_i = optimized_attention(q_i, k_i, v_i, heads, skip_reshape=True, skip_output_reshape=True) + if out_i.dtype != out_dtype: + out_i = out_i.to(out_dtype) out.append(out_i.squeeze(0).permute(1, 0, 2)) out = torch.cat(out, dim=0) return _var_attention_output(out, heads, head_dim, skip_output_reshape) -@torch._dynamo.disable -def var_attention_sage(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs): - if not SAGE_ATTENTION_VARLEN_IS_AVAILABLE: - return var_attention_pytorch( - q, - k, - v, - heads, - cu_seqlens_q, - cu_seqlens_k, - skip_reshape=skip_reshape, - skip_output_reshape=skip_output_reshape, - ) - q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape) - out_dtype = q.dtype - if not (q.dtype == k.dtype == v.dtype): - k = k.to(q.dtype) - v = v.to(q.dtype) - fallback_q, fallback_k, fallback_v = q, k, v - if q.dtype not in (torch.float16, torch.bfloat16): - q = q.to(torch.bfloat16) - k = k.to(torch.bfloat16) - v = v.to(torch.bfloat16) - sm_scale = kwargs.get("softmax_scale") - if sm_scale is None: - sm_scale = 1.0 / math.sqrt(head_dim) - try: - out = sageattn_varlen( - q, - k, - v, - cu_seqlens_q.int(), - cu_seqlens_k.int(), - _var_attention_max_seqlen(cu_seqlens_q), - _var_attention_max_seqlen(cu_seqlens_k), - kwargs.get("causal", False), - sm_scale, - ) - except Exception as e: - logging.error("Error running sage var-len attention: %s, using pytorch var-len attention instead.", e) - out = var_attention_pytorch( - fallback_q, - fallback_k, - fallback_v, - heads, - cu_seqlens_q, - cu_seqlens_k, - skip_reshape=True, - skip_output_reshape=skip_output_reshape, - ) - if out.dtype != out_dtype: - out = out.to(out_dtype) - return out - if out.dtype != out_dtype: - out = out.to(out_dtype) - return _var_attention_output(out, heads, head_dim, skip_output_reshape) - -@torch._dynamo.disable -def var_attention_sage3(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs): - if not SAGE_ATTENTION3_IS_AVAILABLE: - if SAGE_ATTENTION_VARLEN_IS_AVAILABLE: - return var_attention_sage( - q, - k, - v, - heads, - cu_seqlens_q, - cu_seqlens_k, - skip_reshape=skip_reshape, - skip_output_reshape=skip_output_reshape, - **kwargs, - ) - return var_attention_pytorch( - q, - k, - v, - heads, - cu_seqlens_q, - cu_seqlens_k, - skip_reshape=skip_reshape, - skip_output_reshape=skip_output_reshape, - ) - q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape) - seq_lens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] - seq_lens_k = cu_seqlens_k[1:] - cu_seqlens_k[:-1] - uniform_q = bool((seq_lens_q == seq_lens_q[0]).all().item()) - uniform_k = bool((seq_lens_k == seq_lens_k[0]).all().item()) - if not (uniform_q and uniform_k and seq_lens_q[0] == seq_lens_k[0]): - if SAGE_ATTENTION_VARLEN_IS_AVAILABLE: - return var_attention_sage( - q, - k, - v, - heads, - cu_seqlens_q, - cu_seqlens_k, - skip_reshape=True, - skip_output_reshape=skip_output_reshape, - **kwargs, - ) - return var_attention_pytorch( - q, - k, - v, - heads, - cu_seqlens_q, - cu_seqlens_k, - skip_reshape=True, - skip_output_reshape=skip_output_reshape, - ) - out_dtype = q.dtype - if not (q.dtype == k.dtype == v.dtype): - k = k.to(q.dtype) - v = v.to(q.dtype) - fallback_q, fallback_k, fallback_v = q, k, v - if q.dtype not in (torch.float16, torch.bfloat16): - q = q.to(torch.bfloat16) - k = k.to(torch.bfloat16) - v = v.to(torch.bfloat16) - batch_size = len(cu_seqlens_q) - 1 - seq_len = int(seq_lens_q[0].item()) - q = q.view(batch_size, seq_len, heads, head_dim).transpose(1, 2) - k = k.view(batch_size, seq_len, heads, head_dim).transpose(1, 2) - v = v.view(batch_size, seq_len, heads, head_dim).transpose(1, 2) - try: - out = sageattn3_blackwell(q, k, v, is_causal=kwargs.get("causal", False)) - except Exception as e: - logging.error("Error running SageAttention3 var-len attention: %s, using fallback var-len attention instead.", e) - if SAGE_ATTENTION_VARLEN_IS_AVAILABLE: - return var_attention_sage( - fallback_q, - fallback_k, - fallback_v, - heads, - cu_seqlens_q, - cu_seqlens_k, - skip_reshape=True, - skip_output_reshape=skip_output_reshape, - **kwargs, - ) - return var_attention_pytorch( - fallback_q, - fallback_k, - fallback_v, - heads, - cu_seqlens_q, - cu_seqlens_k, - skip_reshape=True, - skip_output_reshape=skip_output_reshape, - ) - out = out.transpose(1, 2).reshape(-1, heads, head_dim).contiguous() - if out.dtype != out_dtype: - out = out.to(out_dtype) - return _var_attention_output(out, heads, head_dim, skip_output_reshape) - - -@torch._dynamo.disable -def var_attention_flash(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs): - if not FLASH_ATTENTION_VARLEN_IS_AVAILABLE: - return var_attention_pytorch( - q, - k, - v, - heads, - cu_seqlens_q, - cu_seqlens_k, - skip_reshape=skip_reshape, - skip_output_reshape=skip_output_reshape, - ) - q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape) - max_seqlen_q = _var_attention_max_seqlen(cu_seqlens_q) - max_seqlen_k = _var_attention_max_seqlen(cu_seqlens_k) - try: - out = flash_attn_varlen_func( - q=q, - k=k, - v=v, - cu_seqlens_q=cu_seqlens_q.int(), - cu_seqlens_k=cu_seqlens_k.int(), - max_seqlen_q=max_seqlen_q, - max_seqlen_k=max_seqlen_k, - dropout_p=kwargs.get("dropout_p", 0.0), - causal=kwargs.get("causal", False), - deterministic=torch.are_deterministic_algorithms_enabled(), - ) - except Exception as e: - logging.error("Error running Flash Attention var-len attention: %s, using pytorch var-len attention instead.", e) - return var_attention_pytorch( - q, - k, - v, - heads, - cu_seqlens_q, - cu_seqlens_k, - skip_reshape=True, - skip_output_reshape=skip_output_reshape, - ) - return _var_attention_output(out, heads, head_dim, skip_output_reshape) - - -@torch._dynamo.disable -def var_attention_flash3(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs): - if not FLASH_ATTENTION3_IS_AVAILABLE: - return var_attention_pytorch( - q, - k, - v, - heads, - cu_seqlens_q, - cu_seqlens_k, - skip_reshape=skip_reshape, - skip_output_reshape=skip_output_reshape, - ) - q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape) - max_seqlen_q = _var_attention_max_seqlen(cu_seqlens_q) - max_seqlen_k = _var_attention_max_seqlen(cu_seqlens_k) - try: - out = flash_attn3_varlen_func( - q=q, - k=k, - v=v, - cu_seqlens_q=cu_seqlens_q.int(), - cu_seqlens_k=cu_seqlens_k.int(), - max_seqlen_q=max_seqlen_q, - max_seqlen_k=max_seqlen_k, - seqused_q=None, - seqused_k=None, - softmax_scale=kwargs.get("softmax_scale"), - causal=kwargs.get("causal", False), - deterministic=torch.are_deterministic_algorithms_enabled(), - ) - except Exception as e: - logging.error("Error running Flash Attention 3 var-len attention: %s, using pytorch var-len attention instead.", e) - return var_attention_pytorch( - q, - k, - v, - heads, - cu_seqlens_q, - cu_seqlens_k, - skip_reshape=True, - skip_output_reshape=skip_output_reshape, - ) - if isinstance(out, tuple): - out = out[0] - return _var_attention_output(out, heads, head_dim, skip_output_reshape) - - -@torch._dynamo.disable -def var_attention_sub_quad(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs): - return var_attention_pytorch( - q, - k, - v, - heads, - cu_seqlens_q, - cu_seqlens_k, - skip_reshape=skip_reshape, - skip_output_reshape=skip_output_reshape, - ) - - -@torch._dynamo.disable -def var_attention_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs): - return var_attention_pytorch_split( - q, - k, - v, - heads, - cu_seqlens_q, - cu_seqlens_k, - skip_reshape=skip_reshape, - skip_output_reshape=skip_output_reshape, - ) - - -optimized_var_attention = var_attention_pytorch +optimized_var_attention = var_attention_optimized_split optimized_attention = attention_basic if model_management.sage_attention_enabled(): logging.info("Using sage attention") optimized_attention = attention_sage - if SAGE_ATTENTION3_IS_AVAILABLE and _use_blackwell_attention(): - logging.info("Using SageAttention3 for variable-length attention") - optimized_var_attention = var_attention_sage3 - elif SAGE_ATTENTION_VARLEN_IS_AVAILABLE: - logging.info("Using SageAttention for variable-length attention") - optimized_var_attention = var_attention_sage - else: - logging.info("Using pytorch attention for variable-length attention") - optimized_var_attention = var_attention_pytorch elif model_management.flash_attention_enabled(): logging.info("Using Flash Attention") optimized_attention = attention_flash - if FLASH_ATTENTION_VARLEN_IS_AVAILABLE and model_management.get_torch_device().type == "cuda": - logging.info("Using Flash Attention 2 for variable-length attention") - optimized_var_attention = var_attention_flash - else: - logging.info("Using pytorch attention for variable-length attention") - optimized_var_attention = var_attention_pytorch elif model_management.xformers_enabled(): logging.info("Using xformers attention") optimized_attention = attention_xformers elif model_management.pytorch_attention_enabled(): logging.info("Using pytorch attention") optimized_attention = attention_pytorch - optimized_var_attention = var_attention_pytorch else: if args.use_split_cross_attention: logging.info("Using split optimization for attention") optimized_attention = attention_split - optimized_var_attention = var_attention_split else: logging.info("Using sub quadratic optimization for attention, if you have memory or speed issues try using: --use-split-cross-attention") optimized_attention = attention_sub_quad - optimized_var_attention = var_attention_sub_quad + +logging.info("Using optimized_attention split-loop for variable-length attention") optimized_attention_masked = optimized_attention @@ -1193,25 +845,16 @@ optimized_attention_masked = optimized_attention # register core-supported attention functions if SAGE_ATTENTION_IS_AVAILABLE: register_attention_function("sage", attention_sage) -if SAGE_ATTENTION_VARLEN_IS_AVAILABLE: - register_attention_function("var_attention_sage", var_attention_sage) if SAGE_ATTENTION3_IS_AVAILABLE: register_attention_function("sage3", attention3_sage) - register_attention_function("var_attention_sage3", var_attention_sage3) if FLASH_ATTENTION_IS_AVAILABLE: register_attention_function("flash", attention_flash) -if FLASH_ATTENTION_VARLEN_IS_AVAILABLE: - register_attention_function("var_attention_flash", var_attention_flash) -if FLASH_ATTENTION3_IS_AVAILABLE: - register_attention_function("var_attention_flash3", var_attention_flash3) if model_management.xformers_enabled(): register_attention_function("xformers", attention_xformers) register_attention_function("pytorch", attention_pytorch) -register_attention_function("var_attention_pytorch", var_attention_pytorch) register_attention_function("sub_quad", attention_sub_quad) -register_attention_function("var_attention_sub_quad", var_attention_sub_quad) register_attention_function("split", attention_split) -register_attention_function("var_attention_split", var_attention_split) +register_attention_function("var_attention_optimized_split", var_attention_optimized_split) def optimized_attention_for_device(device, mask=False, small_input=False): diff --git a/comfy/ldm/seedvr/constants.py b/comfy/ldm/seedvr/constants.py index bfd72f1a2..ed1620436 100644 --- a/comfy/ldm/seedvr/constants.py +++ b/comfy/ldm/seedvr/constants.py @@ -40,7 +40,7 @@ SEEDVR2_WAVELET_SCALE_MULTIPLIER = 10 # per-frame byte multiplier, wavelet path SEEDVR2_ADAIN_SCALE_MULTIPLIER = 6 # per-frame byte multiplier, AdaIN path. # -------------------------------------------------------------------------------------- -# C. ByteDance config / source (BYTEDANCE - cite myseedvr2/SeedVR) +# C. ByteDance config / source (BYTEDANCE - cite ByteDance-Seed/SeedVR) # -------------------------------------------------------------------------------------- BYTEDANCE_VAE_SCALING_FACTOR = 0.9152 # configs_3b/main.yaml:57 (scaling_factor); latent denorm. BYTEDANCE_VAE_SHIFTING_FACTOR = 0.0 # infer.py (shifting_factor default); latent denorm shift. diff --git a/comfy/ldm/seedvr/model.py b/comfy/ldm/seedvr/model.py index 8f248a4d2..3fa9fe07e 100644 --- a/comfy/ldm/seedvr/model.py +++ b/comfy/ldm/seedvr/model.py @@ -58,10 +58,11 @@ class CustomRMSNorm(nn.Module): dims = tuple(range(-len(self.normalized_shape), 0)) - variance = input.pow(2).mean(dim=dims, keepdim=True) + normalized = input.float() + variance = normalized.pow(2).mean(dim=dims, keepdim=True) rms = torch.sqrt(variance + self.eps) - normalized = input / rms + normalized = normalized / rms if self.elementwise_affine: return normalized * self.weight.to(input.dtype) @@ -472,8 +473,8 @@ class NaRotaryEmbedding3d(RotaryEmbedding3d): freqs = freqs.to(device=q.device) q = rearrange(q, "L h d -> h L d") k = rearrange(k, "L h d -> h L d") - q = _apply_rope1_partial(q, freqs) - k = _apply_rope1_partial(k, freqs) + q = _apply_seedvr2_rotary_emb(freqs, q.float()).to(q.dtype) + k = _apply_seedvr2_rotary_emb(freqs, k.float()).to(k.dtype) q = rearrange(q, "h L d -> L h d") k = rearrange(k, "h L d -> L h d") return q, k @@ -483,11 +484,20 @@ class NaRotaryEmbedding3d(RotaryEmbedding3d): self, shape: torch.LongTensor, ) -> torch.Tensor: + # Primary provenance: ByteDance-Seed/SeedVR models/dit/rope.py builds + # 7B pixel RoPE with the interleaved-angle convention, not Comfy's + # Flux freqs_cis matrix. + plain_rope = RotaryEmbedding( + dim=self.rope.freqs.numel() * 2, + freqs_for="pixel", + max_freq=BYTEDANCE_ROPE_MAX_FREQ, + ) + plain_rope = plain_rope.to(self.rope.dummy.device) freq_list = [] for f, h, w in shape.tolist(): - freqs = self.get_axial_freqs(f, h, w) + freqs = plain_rope.get_axial_freqs(f, h, w) freq_list.append(freqs.view(-1, freqs.size(-1))) - return _to_flux_freqs_cis(torch.cat(freq_list, dim=0)) + return torch.cat(freq_list, dim=0) class MMRotaryEmbeddingBase(RotaryEmbeddingBase): @@ -556,6 +566,36 @@ def apply_rotary_emb( out = torch.cat((t_left, t_middle_out, t_right), dim=-1) return out.type(dtype) + +def _apply_seedvr2_rotary_emb( + freqs: torch.Tensor, + t: torch.Tensor, + start_index: int = 0, + scale: float = 1.0, + seq_dim: int = -2, + freqs_seq_dim: int | None = None, +) -> torch.Tensor: + dtype = t.dtype + if freqs_seq_dim is None and (freqs.ndim == 2 or t.ndim == 3): + freqs_seq_dim = 0 + + if t.ndim == 3 or freqs_seq_dim is not None: + seq_len = t.shape[seq_dim] + freqs = slice_at_dim(freqs, slice(-seq_len, None), dim=freqs_seq_dim) + + rot_feats = freqs.shape[-1] + end_index = start_index + rot_feats + + t_left = t[..., :start_index] + t_middle = t[..., start_index:end_index] + t_right = t[..., end_index:] + + freqs = freqs.to(device=t_middle.device, dtype=t_middle.dtype) + cos = freqs.cos() * scale + sin = freqs.sin() * scale + t_middle = (t_middle * cos) + (rotate_half(t_middle) * sin) + return torch.cat((t_left, t_middle, t_right), dim=-1).to(dtype) + def _to_flux_freqs_cis(freqs_interleaved: torch.Tensor) -> torch.Tensor: """Convert lucidrains-interleaved freqs to flux-canonical fp32 freqs_cis `[..., d/2, 2, 2]` (cos/-sin/sin/cos), per `comfy/ldm/flux/math.py:rope`.""" angles = freqs_interleaved[..., ::2].float() @@ -1380,6 +1420,8 @@ class NaDiT(nn.Module): **kwargs, ): self._7b_version = vid_dim == SEEDVR2_7B_VID_DIM + if self._7b_version: + rope_type = "rope3d" self.dtype = dtype factory_kwargs = {"device": device, "dtype": dtype} window_method = num_layers // 2 * ["720pwin_by_size_bysize","720pswin_by_size_bysize"] diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py index d6d07fe1c..68b11c0ff 100644 --- a/comfy/ldm/seedvr/vae.py +++ b/comfy/ldm/seedvr/vae.py @@ -24,12 +24,7 @@ from comfy.ldm.seedvr.constants import ( BYTEDANCE_VAE_SHIFTING_FACTOR, BYTEDANCE_VAE_SPATIAL_DOWNSAMPLE, BYTEDANCE_VAE_TEMPORAL_DOWNSAMPLE, - CIELAB_DELTA, - CIELAB_KAPPA, - D65_WHITE_X, - D65_WHITE_Z, SEEDVR2_LATENT_CHANNELS, - WAVELET_DECOMP_LEVELS, ) from comfy.ldm.modules.attention import optimized_attention from comfy.ldm.modules.diffusionmodules.model import vae_attention diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py index e48d9e463..9336c7d9e 100644 --- a/comfy_extras/nodes_seedvr.py +++ b/comfy_extras/nodes_seedvr.py @@ -575,7 +575,7 @@ class SeedVR2Conditioning(io.ComfyNode): description="Build SeedVR2 positive/negative conditioning from a VAE latent.", inputs=[ io.Model.Input("model", tooltip="The SeedVR2 model."), - io.Latent.Input("vae_conditioning", tooltip="The VAE-encoded latent to condition on."), + io.Latent.Input("vae_conditioning", display_name="LATENT", tooltip="The VAE-encoded latent to condition on."), ], outputs=[ io.Model.Output(display_name = "model"), diff --git a/tests-unit/comfy_test/test_seedvr2_internals.py b/tests-unit/comfy_test/test_seedvr2_internals.py index 60ce0c5b4..5b008ea6e 100644 --- a/tests-unit/comfy_test/test_seedvr2_internals.py +++ b/tests-unit/comfy_test/test_seedvr2_internals.py @@ -6,9 +6,7 @@ Sources (all merged verbatim, helper names disambiguated where colliding): apply_rotary_emb wrapper oracle at fp32. * GroupNorm limit gate — causal_norm_wrapper at vae.py:509 must compare memory_occupy against get_norm_limit(), not float('inf'). - * var_attention backend registry. - * var_attention_pytorch SeedVR2-named guard — present-API shape contract - with AST-level pinning of the guard ordering. + * SeedVR2 variable-length attention split-loop contract. Pre-import CPU-only guard is required because comfy.ldm.seedvr.model and comfy.ldm.modules.attention transitively pull in comfy.model_management, @@ -18,11 +16,6 @@ set first. from __future__ import annotations -import ast -import inspect -import logging -import textwrap -import warnings from unittest.mock import patch import pytest @@ -45,7 +38,7 @@ from comfy.ldm.seedvr.vae import ( # noqa: E402 causal_norm_wrapper, set_norm_limit, ) -from comfy.ldm.modules.attention import var_attention_pytorch # noqa: E402 +from comfy.ldm.modules.attention import var_attention_optimized_split # noqa: E402 # --------------------------------------------------------------------------- @@ -215,13 +208,14 @@ def test_seedvr_groupnorm_low_limit_uses_chunked_groupnorm_path(groupnorm_cls): # --------------------------------------------------------------------------- -# var_attention backend tests (test_seedvr_var_attention_backends.py) +# SeedVR2 var_attention split-loop tests # --------------------------------------------------------------------------- def test_var_attention_registry_contains_always_available_entries(): - assert attention.REGISTERED_ATTENTION_FUNCTIONS["var_attention_pytorch"] is attention.var_attention_pytorch - assert attention.REGISTERED_ATTENTION_FUNCTIONS["var_attention_sub_quad"] is attention.var_attention_sub_quad - assert attention.REGISTERED_ATTENTION_FUNCTIONS["var_attention_split"] is attention.var_attention_split + assert ( + attention.REGISTERED_ATTENTION_FUNCTIONS["var_attention_optimized_split"] + is attention.var_attention_optimized_split + ) def test_seedvr2_7b_swin_attention_forward_uses_optimized_var_attention(monkeypatch): @@ -285,105 +279,63 @@ def test_seedvr2_7b_swin_attention_forward_uses_optimized_var_attention(monkeypa ) -# --------------------------------------------------------------------------- -# var_attention_pytorch SeedVR2 guard tests -# (test_var_attention_pytorch_seedvr2_guard.py) -# --------------------------------------------------------------------------- +def test_var_attention_optimized_split_calls_dense_backend_per_window(monkeypatch): + heads = 2 + head_dim = 3 + q = torch.arange(30, dtype=torch.float32).reshape(5, heads, head_dim) + k = q + 100 + v = q + 200 + cu = torch.tensor([0, 2, 5], dtype=torch.int32) + calls = [] -def _pytorch_guard_inputs(): - heads, head_dim, total_tokens = 2, 8, 6 - embed_dim = heads * head_dim - q = torch.randn(total_tokens, embed_dim) - k = torch.randn(total_tokens, embed_dim) - v = torch.randn(total_tokens, embed_dim) - cu = torch.tensor([0, 3, 6], dtype=torch.int32) - return q, k, v, heads, cu, cu, total_tokens, embed_dim + def fake_optimized_attention(q_arg, k_arg, v_arg, heads_arg, **kwargs): + calls.append( + { + "q_shape": tuple(q_arg.shape), + "k_shape": tuple(k_arg.shape), + "v_shape": tuple(v_arg.shape), + "heads": heads_arg, + "kwargs": kwargs, + } + ) + return q_arg + v_arg + monkeypatch.setattr(attention, "optimized_attention", fake_optimized_attention) -def _assert_guard_source_pin(): - src = textwrap.dedent(inspect.getsource(var_attention_pytorch)) - tree = ast.parse(src) - raise_lines = [] - nested_lines = [] - for node in ast.walk(tree): - if isinstance(node, ast.Raise) and isinstance(node.exc, ast.Call): - func = node.exc.func - if isinstance(func, ast.Name) and func.id == "RuntimeError": - raise_lines.append(node.lineno) - if isinstance(node, ast.Attribute) and node.attr == "nested_tensor_from_jagged": - nested_lines.append(node.lineno) - assert raise_lines, ( - "var_attention_pytorch has no `raise RuntimeError(...)` AST node; " - f"the SeedVR2-named guard is missing.\n--- source ---\n{src}" - ) - assert nested_lines, ( - "var_attention_pytorch source has no `nested_tensor_from_jagged` " - f"attribute access; cannot pin guard ordering.\n" - f"--- source ---\n{src}" - ) - first_raise = min(raise_lines) - first_nested = min(nested_lines) - assert first_raise < first_nested, ( - f"`raise RuntimeError(...)` first appears at line {first_raise}, " - f"but `torch.nested.nested_tensor_from_jagged` is referenced first " - f"at line {first_nested}; the guard must precede the lookup.\n" - f"--- source ---\n{src}" + out = var_attention_optimized_split( + q, + k, + v, + heads, + cu, + cu, + skip_reshape=True, + skip_output_reshape=True, ) - -def test_missing_api_raises_seedvr2_runtime_error(monkeypatch): - monkeypatch.delattr(torch.nested, "nested_tensor_from_jagged", raising=False) - q, k, v, heads, cu_q, cu_k, _, _ = _pytorch_guard_inputs() - - with pytest.raises(RuntimeError, match=r"SeedVR2.*nested_tensor_from_jagged"): - var_attention_pytorch(q, k, v, heads, cu_q, cu_k) - - _assert_guard_source_pin() + assert tuple(out.shape) == (5, heads, head_dim) + assert len(calls) == 2 + assert calls[0]["q_shape"] == (1, heads, 2, head_dim) + assert calls[1]["q_shape"] == (1, heads, 3, head_dim) + assert all(call["heads"] == heads for call in calls) + assert all(call["kwargs"]["skip_reshape"] is True for call in calls) + assert all(call["kwargs"]["skip_output_reshape"] is True for call in calls) + torch.testing.assert_close(out, q + v, rtol=0, atol=0) -def test_missing_namespace_raises_seedvr2_runtime_error(monkeypatch): - monkeypatch.delattr(torch, "nested", raising=False) - q, k, v, heads, cu_q, cu_k, _, _ = _pytorch_guard_inputs() +def test_var_attention_optimized_split_rejects_bad_offsets(): + q = torch.randn(5, 2, 3) + cu_bad = torch.tensor([0, 2, 6], dtype=torch.int32) + cu_ok = torch.tensor([0, 2, 5], dtype=torch.int32) - with pytest.raises(RuntimeError, match=r"SeedVR2.*nested_tensor_from_jagged"): - var_attention_pytorch(q, k, v, heads, cu_q, cu_k) - - _assert_guard_source_pin() - - -def test_present_api_returns_expected_shape(): - q, k, v, heads, cu_q, cu_k, total_tokens, embed_dim = _pytorch_guard_inputs() - - torch_fx_logger = logging.getLogger("torch.fx._symbolic_trace") - old_torch_fx_level = torch_fx_logger.level - torch_fx_logger.setLevel(logging.ERROR) - try: - with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", - message="The PyTorch API of nested tensors is in prototype stage.*", - category=UserWarning, - ) - out = var_attention_pytorch(q, k, v, heads, cu_q, cu_k) - finally: - torch_fx_logger.setLevel(old_torch_fx_level) - - assert tuple(out.shape) == (total_tokens, embed_dim), ( - f"expected ({total_tokens}, {embed_dim}); got {tuple(out.shape)}" - ) - - _assert_guard_source_pin() - - -def test_malformed_offsets_propagates_torch_runtime_error(): - q, k, v, heads, _, _, _, _ = _pytorch_guard_inputs() - cu_q_bad = torch.tensor([0, 3, 7], dtype=torch.int32) - cu_k_ok = torch.tensor([0, 3, 6], dtype=torch.int32) - - with pytest.raises(RuntimeError) as exc_info: - var_attention_pytorch(q, k, v, heads, cu_q_bad, cu_k_ok) - - msg = str(exc_info.value) - assert "SeedVR2" not in msg - - _assert_guard_source_pin() + with pytest.raises(ValueError, match="cu_seqlens_q does not match token count"): + var_attention_optimized_split( + q, + q, + q, + 2, + cu_bad, + cu_ok, + skip_reshape=True, + skip_output_reshape=True, + ) diff --git a/tests-unit/comfy_test/test_seedvr2_model.py b/tests-unit/comfy_test/test_seedvr2_model.py index b81ff2d71..f2b9bcbbe 100644 --- a/tests-unit/comfy_test/test_seedvr2_model.py +++ b/tests-unit/comfy_test/test_seedvr2_model.py @@ -213,11 +213,11 @@ def test_seedvr2_7b_rope3d_matches_wrapper_oracle(): shape = torch.tensor([[1, 2, 2]], dtype=torch.long) freqs = rope.get_axial_freqs(1, 2, 2).reshape(4, -1) - expected_q = seedvr_model.apply_rotary_emb( + expected_q = seedvr_model._apply_seedvr2_rotary_emb( freqs, q.permute(1, 0, 2).float(), ).to(q.dtype).permute(1, 0, 2) - expected_k = seedvr_model.apply_rotary_emb( + expected_k = seedvr_model._apply_seedvr2_rotary_emb( freqs, k.permute(1, 0, 2).float(), ).to(k.dtype).permute(1, 0, 2) diff --git a/tests-unit/comfy_test/test_seedvr2_vae_tiled.py b/tests-unit/comfy_test/test_seedvr2_vae_tiled.py index 442480149..40079bbe2 100644 --- a/tests-unit/comfy_test/test_seedvr2_vae_tiled.py +++ b/tests-unit/comfy_test/test_seedvr2_vae_tiled.py @@ -174,10 +174,7 @@ def test_decode_tiled_vae_maps_temporal_args_to_latent_slicing_min_size(): captured.update(kwargs) return torch.zeros(1, 3, 1, 16, 16) - with ( - patch.object(vae_mod, "tiled_vae", side_effect=_fake_tiled_vae), - patch.object(vae_mod, "lab_color_transfer", side_effect=lambda content, style: content), - ): + with patch.object(vae_mod, "tiled_vae", side_effect=_fake_tiled_vae): wrapper.decode(torch.zeros(1, 16, 2, 2), seedvr2_tiling=seedvr2_tiling) assert captured["temporal_overlap"] == 7