diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 55360535a..8507557d5 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -32,6 +32,14 @@ except ImportError as e: raise e exit(-1) +SAGE_ATTENTION_VARLEN_IS_AVAILABLE = False +try: + from sageattention import sageattn_varlen + SAGE_ATTENTION_VARLEN_IS_AVAILABLE = True +except ImportError: + if model_management.sage_attention_enabled(): + logging.warning("SageAttention variable-length attention is unavailable, using pytorch var-len attention instead.") + SAGE_ATTENTION3_IS_AVAILABLE = False try: from sageattn3 import sageattn3_blackwell @@ -40,6 +48,7 @@ except ImportError: pass FLASH_ATTENTION_IS_AVAILABLE = False +FLASH_ATTENTION_VARLEN_IS_AVAILABLE = False try: from flash_attn import flash_attn_func FLASH_ATTENTION_IS_AVAILABLE = True @@ -48,6 +57,20 @@ except ImportError: logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn") exit(-1) +try: + from flash_attn import flash_attn_varlen_func + FLASH_ATTENTION_VARLEN_IS_AVAILABLE = True +except ImportError: + if model_management.flash_attention_enabled() and FLASH_ATTENTION_IS_AVAILABLE: + logging.warning("Flash Attention variable-length attention is unavailable, using pytorch var-len attention instead.") + +FLASH_ATTENTION3_IS_AVAILABLE = False +try: + from flash_attn_interface import flash_attn_varlen_func as flash_attn3_varlen_func + FLASH_ATTENTION3_IS_AVAILABLE = True +except ImportError: + pass + REGISTERED_ATTENTION_FUNCTIONS = {} def register_attention_function(name: str, func: Callable): # avoid replacing existing functions @@ -735,28 +758,434 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape ) return out +_VAR_ATTENTION_NESTED_API_NAME = "nested_tensor_from_jagged" +_VAR_ATTENTION_GUARD_MESSAGE = ( + "SeedVR2 var_attention_pytorch: torch.nested.nested_tensor_from_jagged " + "is required by this attention path; the installed PyTorch build " + "does not provide it" +) + +def _var_attention_max_seqlen(cu_seqlens): + return int((cu_seqlens[1:] - cu_seqlens[:-1]).max().item()) + + +def _var_attention_qkv(q, k, v, heads, skip_reshape): + if skip_reshape: + return q, k, v, q.shape[-1] + total_tokens, embed_dim = q.shape + head_dim = embed_dim // heads + return ( + q.view(total_tokens, heads, head_dim), + k.view(k.shape[0], heads, head_dim), + v.view(v.shape[0], heads, head_dim), + head_dim, + ) + + +def _var_attention_output(out, heads, head_dim, skip_output_reshape): + if skip_output_reshape: + return out + return out.reshape(-1, heads * head_dim) + + +def _use_blackwell_attention(): + device = model_management.get_torch_device() + if device.type != "cuda": + return False + major, minor = torch.cuda.get_device_capability(device) + return (major, minor) >= (12, 0) + + +def var_attention_pytorch(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, skip_reshape=False, skip_output_reshape=False): + _nested = getattr(torch, "nested", None) + if _nested is None or not hasattr(_nested, _VAR_ATTENTION_NESTED_API_NAME): + raise RuntimeError(_VAR_ATTENTION_GUARD_MESSAGE) + + if not skip_reshape: + # assumes 2D q, k,v [total_tokens, embed_dim] + total_tokens, embed_dim = q.shape + head_dim = embed_dim // heads + q = q.view(total_tokens, heads, head_dim) + k = k.view(k.shape[0], heads, head_dim) + v = v.view(v.shape[0], heads, head_dim) + + q = torch.nested.nested_tensor_from_jagged(q, offsets=cu_seqlens_q.long()) + k = torch.nested.nested_tensor_from_jagged(k, offsets=cu_seqlens_k.long()) + v = torch.nested.nested_tensor_from_jagged(v, offsets=cu_seqlens_k.long()) + + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) + + out = out.transpose(1, 2) + if not skip_output_reshape: + return out.values().reshape(-1, heads * (q.shape[-1])) + return out.values() + + +def _validate_split_cu_seqlens(name, cu_seqlens, token_count): + if cu_seqlens.dtype not in (torch.int32, torch.int64): + raise ValueError(f"{name} must use an integer dtype") + if cu_seqlens.ndim != 1 or cu_seqlens.numel() < 2: + raise ValueError(f"{name} must be a 1D tensor with at least two offsets") + if cu_seqlens[0].item() != 0: + raise ValueError(f"{name} must start at 0") + if (cu_seqlens[1:] <= cu_seqlens[:-1]).any().item(): + raise ValueError(f"{name} must be strictly increasing") + if cu_seqlens[-1].item() != token_count: + raise ValueError(f"{name} does not match token count") + + +def _split_indices(cu_seqlens): + return cu_seqlens[1:-1].to(device="cpu", dtype=torch.long) + + +def var_attention_pytorch_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, skip_reshape=False, skip_output_reshape=False): + q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape) + + _validate_split_cu_seqlens("cu_seqlens_q", cu_seqlens_q, q.shape[0]) + _validate_split_cu_seqlens("cu_seqlens_k", cu_seqlens_k, k.shape[0]) + if cu_seqlens_k[-1].item() != v.shape[0]: + raise ValueError("cu_seqlens_k does not match v token count") + + q_split_indices = _split_indices(cu_seqlens_q) + k_split_indices = _split_indices(cu_seqlens_k) + q_splits = torch.tensor_split(q, q_split_indices, dim=0) + k_splits = torch.tensor_split(k, k_split_indices, dim=0) + v_splits = torch.tensor_split(v, k_split_indices, dim=0) + if len(q_splits) != len(k_splits) or len(q_splits) != len(v_splits): + raise ValueError("cu_seqlens_q and cu_seqlens_k must describe the same sequence count") + + out = [] + for q_i, k_i, v_i in zip(q_splits, k_splits, v_splits): + q_i = q_i.permute(1, 0, 2).unsqueeze(0) + k_i = k_i.permute(1, 0, 2).unsqueeze(0) + v_i = v_i.permute(1, 0, 2).unsqueeze(0) + out_i = comfy.ops.scaled_dot_product_attention(q_i, k_i, v_i, attn_mask=None, dropout_p=0.0, is_causal=False) + out.append(out_i.squeeze(0).permute(1, 0, 2)) + + out = torch.cat(out, dim=0) + return _var_attention_output(out, heads, head_dim, skip_output_reshape) + +@torch._dynamo.disable +def var_attention_sage(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs): + if not SAGE_ATTENTION_VARLEN_IS_AVAILABLE: + return var_attention_pytorch( + q, + k, + v, + heads, + cu_seqlens_q, + cu_seqlens_k, + skip_reshape=skip_reshape, + skip_output_reshape=skip_output_reshape, + ) + q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape) + out_dtype = q.dtype + if not (q.dtype == k.dtype == v.dtype): + k = k.to(q.dtype) + v = v.to(q.dtype) + fallback_q, fallback_k, fallback_v = q, k, v + if q.dtype not in (torch.float16, torch.bfloat16): + q = q.to(torch.bfloat16) + k = k.to(torch.bfloat16) + v = v.to(torch.bfloat16) + sm_scale = kwargs.get("softmax_scale") + if sm_scale is None: + sm_scale = 1.0 / math.sqrt(head_dim) + try: + out = sageattn_varlen( + q, + k, + v, + cu_seqlens_q.int(), + cu_seqlens_k.int(), + _var_attention_max_seqlen(cu_seqlens_q), + _var_attention_max_seqlen(cu_seqlens_k), + kwargs.get("causal", False), + sm_scale, + ) + except Exception as e: + logging.error("Error running sage var-len attention: %s, using pytorch var-len attention instead.", e) + out = var_attention_pytorch( + fallback_q, + fallback_k, + fallback_v, + heads, + cu_seqlens_q, + cu_seqlens_k, + skip_reshape=True, + skip_output_reshape=skip_output_reshape, + ) + if out.dtype != out_dtype: + out = out.to(out_dtype) + return out + if out.dtype != out_dtype: + out = out.to(out_dtype) + return _var_attention_output(out, heads, head_dim, skip_output_reshape) + + +@torch._dynamo.disable +def var_attention_sage3(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs): + if not SAGE_ATTENTION3_IS_AVAILABLE: + if SAGE_ATTENTION_VARLEN_IS_AVAILABLE: + return var_attention_sage( + q, + k, + v, + heads, + cu_seqlens_q, + cu_seqlens_k, + skip_reshape=skip_reshape, + skip_output_reshape=skip_output_reshape, + **kwargs, + ) + return var_attention_pytorch( + q, + k, + v, + heads, + cu_seqlens_q, + cu_seqlens_k, + skip_reshape=skip_reshape, + skip_output_reshape=skip_output_reshape, + ) + q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape) + seq_lens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] + seq_lens_k = cu_seqlens_k[1:] - cu_seqlens_k[:-1] + uniform_q = bool((seq_lens_q == seq_lens_q[0]).all().item()) + uniform_k = bool((seq_lens_k == seq_lens_k[0]).all().item()) + if not (uniform_q and uniform_k and seq_lens_q[0] == seq_lens_k[0]): + if SAGE_ATTENTION_VARLEN_IS_AVAILABLE: + return var_attention_sage( + q, + k, + v, + heads, + cu_seqlens_q, + cu_seqlens_k, + skip_reshape=True, + skip_output_reshape=skip_output_reshape, + **kwargs, + ) + return var_attention_pytorch( + q, + k, + v, + heads, + cu_seqlens_q, + cu_seqlens_k, + skip_reshape=True, + skip_output_reshape=skip_output_reshape, + ) + out_dtype = q.dtype + if not (q.dtype == k.dtype == v.dtype): + k = k.to(q.dtype) + v = v.to(q.dtype) + fallback_q, fallback_k, fallback_v = q, k, v + if q.dtype not in (torch.float16, torch.bfloat16): + q = q.to(torch.bfloat16) + k = k.to(torch.bfloat16) + v = v.to(torch.bfloat16) + batch_size = len(cu_seqlens_q) - 1 + seq_len = int(seq_lens_q[0].item()) + q = q.view(batch_size, seq_len, heads, head_dim).transpose(1, 2) + k = k.view(batch_size, seq_len, heads, head_dim).transpose(1, 2) + v = v.view(batch_size, seq_len, heads, head_dim).transpose(1, 2) + try: + out = sageattn3_blackwell(q, k, v, is_causal=kwargs.get("causal", False)) + except Exception as e: + logging.error("Error running SageAttention3 var-len attention: %s, using fallback var-len attention instead.", e) + if SAGE_ATTENTION_VARLEN_IS_AVAILABLE: + return var_attention_sage( + fallback_q, + fallback_k, + fallback_v, + heads, + cu_seqlens_q, + cu_seqlens_k, + skip_reshape=True, + skip_output_reshape=skip_output_reshape, + **kwargs, + ) + return var_attention_pytorch( + fallback_q, + fallback_k, + fallback_v, + heads, + cu_seqlens_q, + cu_seqlens_k, + skip_reshape=True, + skip_output_reshape=skip_output_reshape, + ) + out = out.transpose(1, 2).reshape(-1, heads, head_dim).contiguous() + if out.dtype != out_dtype: + out = out.to(out_dtype) + return _var_attention_output(out, heads, head_dim, skip_output_reshape) + + +@torch._dynamo.disable +def var_attention_flash(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs): + if not FLASH_ATTENTION_VARLEN_IS_AVAILABLE: + return var_attention_pytorch( + q, + k, + v, + heads, + cu_seqlens_q, + cu_seqlens_k, + skip_reshape=skip_reshape, + skip_output_reshape=skip_output_reshape, + ) + q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape) + max_seqlen_q = _var_attention_max_seqlen(cu_seqlens_q) + max_seqlen_k = _var_attention_max_seqlen(cu_seqlens_k) + try: + out = flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=cu_seqlens_q.int(), + cu_seqlens_k=cu_seqlens_k.int(), + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + dropout_p=kwargs.get("dropout_p", 0.0), + causal=kwargs.get("causal", False), + deterministic=torch.are_deterministic_algorithms_enabled(), + ) + except Exception as e: + logging.error("Error running Flash Attention var-len attention: %s, using pytorch var-len attention instead.", e) + return var_attention_pytorch( + q, + k, + v, + heads, + cu_seqlens_q, + cu_seqlens_k, + skip_reshape=True, + skip_output_reshape=skip_output_reshape, + ) + return _var_attention_output(out, heads, head_dim, skip_output_reshape) + + +@torch._dynamo.disable +def var_attention_flash3(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs): + if not FLASH_ATTENTION3_IS_AVAILABLE: + return var_attention_pytorch( + q, + k, + v, + heads, + cu_seqlens_q, + cu_seqlens_k, + skip_reshape=skip_reshape, + skip_output_reshape=skip_output_reshape, + ) + q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape) + max_seqlen_q = _var_attention_max_seqlen(cu_seqlens_q) + max_seqlen_k = _var_attention_max_seqlen(cu_seqlens_k) + try: + out = flash_attn3_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=cu_seqlens_q.int(), + cu_seqlens_k=cu_seqlens_k.int(), + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + seqused_q=None, + seqused_k=None, + softmax_scale=kwargs.get("softmax_scale"), + causal=kwargs.get("causal", False), + deterministic=torch.are_deterministic_algorithms_enabled(), + ) + except Exception as e: + logging.error("Error running Flash Attention 3 var-len attention: %s, using pytorch var-len attention instead.", e) + return var_attention_pytorch( + q, + k, + v, + heads, + cu_seqlens_q, + cu_seqlens_k, + skip_reshape=True, + skip_output_reshape=skip_output_reshape, + ) + if isinstance(out, tuple): + out = out[0] + return _var_attention_output(out, heads, head_dim, skip_output_reshape) + + +@torch._dynamo.disable +def var_attention_sub_quad(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs): + return var_attention_pytorch( + q, + k, + v, + heads, + cu_seqlens_q, + cu_seqlens_k, + skip_reshape=skip_reshape, + skip_output_reshape=skip_output_reshape, + ) + + +@torch._dynamo.disable +def var_attention_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs): + return var_attention_pytorch_split( + q, + k, + v, + heads, + cu_seqlens_q, + cu_seqlens_k, + skip_reshape=skip_reshape, + skip_output_reshape=skip_output_reshape, + ) + + +optimized_var_attention = var_attention_pytorch optimized_attention = attention_basic if model_management.sage_attention_enabled(): logging.info("Using sage attention") optimized_attention = attention_sage + if SAGE_ATTENTION3_IS_AVAILABLE and _use_blackwell_attention(): + logging.info("Using SageAttention3 for variable-length attention") + optimized_var_attention = var_attention_sage3 + elif SAGE_ATTENTION_VARLEN_IS_AVAILABLE: + logging.info("Using SageAttention for variable-length attention") + optimized_var_attention = var_attention_sage + else: + logging.info("Using pytorch attention for variable-length attention") + optimized_var_attention = var_attention_pytorch elif model_management.flash_attention_enabled(): logging.info("Using Flash Attention") optimized_attention = attention_flash + if FLASH_ATTENTION_VARLEN_IS_AVAILABLE and model_management.get_torch_device().type == "cuda": + logging.info("Using Flash Attention 2 for variable-length attention") + optimized_var_attention = var_attention_flash + else: + logging.info("Using pytorch attention for variable-length attention") + optimized_var_attention = var_attention_pytorch elif model_management.xformers_enabled(): logging.info("Using xformers attention") optimized_attention = attention_xformers elif model_management.pytorch_attention_enabled(): logging.info("Using pytorch attention") optimized_attention = attention_pytorch + optimized_var_attention = var_attention_pytorch else: if args.use_split_cross_attention: logging.info("Using split optimization for attention") optimized_attention = attention_split + optimized_var_attention = var_attention_split else: logging.info("Using sub quadratic optimization for attention, if you have memory or speed issues try using: --use-split-cross-attention") optimized_attention = attention_sub_quad + optimized_var_attention = var_attention_sub_quad optimized_attention_masked = optimized_attention @@ -764,15 +1193,25 @@ optimized_attention_masked = optimized_attention # register core-supported attention functions if SAGE_ATTENTION_IS_AVAILABLE: register_attention_function("sage", attention_sage) +if SAGE_ATTENTION_VARLEN_IS_AVAILABLE: + register_attention_function("var_attention_sage", var_attention_sage) if SAGE_ATTENTION3_IS_AVAILABLE: register_attention_function("sage3", attention3_sage) + register_attention_function("var_attention_sage3", var_attention_sage3) if FLASH_ATTENTION_IS_AVAILABLE: register_attention_function("flash", attention_flash) +if FLASH_ATTENTION_VARLEN_IS_AVAILABLE: + register_attention_function("var_attention_flash", var_attention_flash) +if FLASH_ATTENTION3_IS_AVAILABLE: + register_attention_function("var_attention_flash3", var_attention_flash3) if model_management.xformers_enabled(): register_attention_function("xformers", attention_xformers) register_attention_function("pytorch", attention_pytorch) +register_attention_function("var_attention_pytorch", var_attention_pytorch) register_attention_function("sub_quad", attention_sub_quad) +register_attention_function("var_attention_sub_quad", var_attention_sub_quad) register_attention_function("split", attention_split) +register_attention_function("var_attention_split", var_attention_split) def optimized_attention_for_device(device, mask=False, small_input=False): @@ -1209,5 +1648,3 @@ class SpatialVideoTransformer(SpatialTransformer): x = self.proj_out(x) out = x + x_in return out - - diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index fcbaa074f..235df0b83 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -13,6 +13,7 @@ if model_management.xformers_enabled_vae(): import xformers import xformers.ops + def torch_cat_if_needed(xl, dim): xl = [x for x in xl if x is not None and x.shape[dim] > 0] if len(xl) > 1: @@ -22,7 +23,8 @@ def torch_cat_if_needed(xl, dim): else: return None -def get_timestep_embedding(timesteps, embedding_dim): + +def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1): """ This matches the implementation in Denoising Diffusion Probabilistic Models: From Fairseq. @@ -33,11 +35,13 @@ def get_timestep_embedding(timesteps, embedding_dim): assert len(timesteps.shape) == 1 half_dim = embedding_dim // 2 - emb = math.log(10000) / (half_dim - 1) + emb = math.log(10000) / (half_dim - downscale_freq_shift) emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) emb = emb.to(device=timesteps.device) emb = timesteps.float()[:, None] * emb[None, :] emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) if embedding_dim % 2 == 1: # zero pad emb = torch.nn.functional.pad(emb, (0,1,0,0)) return emb diff --git a/comfy/samplers.py b/comfy/samplers.py old mode 100755 new mode 100644