import torch
from comfy.ldm.modules import attention as _attention


def _var_attention_qkv(q, k, v, heads, skip_reshape):
    """Return ``(q, k, v, head_dim)`` with q/k/v split into heads.

    When ``skip_reshape`` is false, ``q`` must be 2D ``(total_tokens, embed_dim)``
    and q/k/v are viewed as ``(tokens, heads, embed_dim // heads)``; k and v are
    assumed to share q's embedding width. When ``skip_reshape`` is true the
    tensors are returned untouched and ``head_dim`` is taken from
    ``q.shape[-1]`` (presumably they are already ``(tokens, heads, head_dim)``;
    that is the caller's contract).
    """
    if skip_reshape:
        return q, k, v, q.shape[-1]
    total_tokens, embed_dim = q.shape
    head_dim = embed_dim // heads
    return (
        q.view(total_tokens, heads, head_dim),
        k.view(k.shape[0], heads, head_dim),
        v.view(v.shape[0], heads, head_dim),
        head_dim,
    )


def _var_attention_output(out, heads, head_dim, skip_output_reshape):
    """Flatten per-head output back to ``(tokens, heads * head_dim)``.

    ``out`` is returned unchanged when ``skip_output_reshape`` is true.
    """
    if skip_output_reshape:
        return out
    return out.reshape(-1, heads * head_dim)


def _validate_split_cu_seqlens(name, cu_seqlens, token_count):
    """Validate a cumulative-offsets tensor and return it as a CPU int64 tensor.

    ``cu_seqlens`` must be a 1D int32/int64 tensor with at least two entries
    that starts at 0, is strictly increasing (so no sequence is empty), and
    ends at ``token_count``.

    Returns:
        The offsets on the CPU with dtype ``torch.long``.

    Raises:
        ValueError: if any condition above is violated; ``name`` is used in
            the message.
    """
    if cu_seqlens.dtype not in (torch.int32, torch.int64):
        raise ValueError(f"{name} must use an integer dtype")
    if cu_seqlens.ndim != 1 or cu_seqlens.numel() < 2:
        raise ValueError(f"{name} must be a 1D tensor with at least two offsets")
    # Copy to the host once: on an accelerator every ``.item()`` below would
    # otherwise be its own blocking device-to-host synchronization.
    offsets = cu_seqlens.to(device="cpu", dtype=torch.long)
    if offsets[0].item() != 0:
        raise ValueError(f"{name} must start at 0")
    if (offsets[1:] <= offsets[:-1]).any().item():
        raise ValueError(f"{name} must be strictly increasing")
    if offsets[-1].item() != token_count:
        raise ValueError(f"{name} does not match token count")
    return offsets


def _split_indices(cu_seqlens):
    """Return the interior offsets as a CPU long tensor, the index form ``torch.tensor_split`` accepts."""
    return cu_seqlens[1:-1].to(device="cpu", dtype=torch.long)


def var_attention_optimized_split(
    q,
    k,
    v,
    heads,
    cu_seqlens_q,
    cu_seqlens_k,
    *args,
    skip_reshape=False,
    skip_output_reshape=False,
    **kwargs,
):
    """Variable-length attention done by calling the active backend once per sequence.

    Packed q/k/v tokens belonging to several sequences are split at the
    offsets in ``cu_seqlens_q`` / ``cu_seqlens_k``. Sequence ``i`` of q attends
    only to sequence ``i`` of k/v. Each pair is passed to
    ``_attention.optimized_attention`` as a batch of one, and the results are
    concatenated back in token order.

    Args:
        q, k, v: packed tokens, ``(tokens, heads * head_dim)`` unless
            ``skip_reshape`` is true, in which case presumably
            ``(tokens, heads, head_dim)``.
        heads: number of attention heads.
        cu_seqlens_q, cu_seqlens_k: 1D int32/int64 cumulative offsets
            ``[0, ..., total_tokens]``. Both must describe the same number of
            sequences. ``cu_seqlens_k`` indexes both k and v.
        *args, **kwargs: accepted and ignored (presumably for signature
            compatibility with other var-attention implementations).
        skip_reshape: q/k/v are already split into heads.
        skip_output_reshape: return ``(tokens, heads, head_dim)`` instead of
            flattening the heads.

    Raises:
        ValueError: if the offsets are malformed, disagree with the token
            counts, or describe different numbers of sequences.
    """
    q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)
    offsets_q = _validate_split_cu_seqlens("cu_seqlens_q", cu_seqlens_q, q.shape[0])
    offsets_k = _validate_split_cu_seqlens("cu_seqlens_k", cu_seqlens_k, k.shape[0])
    # offsets_k already lives on the CPU, so this check does not sync.
    if offsets_k[-1].item() != v.shape[0]:
        raise ValueError("cu_seqlens_k does not match v token count")
    # tensor_split yields len(indices) + 1 chunks, so equal offset counts means
    # equal sequence counts; reject a mismatch before doing any splitting.
    if offsets_q.numel() != offsets_k.numel():
        raise ValueError(
            "cu_seqlens_q and cu_seqlens_k must describe the same sequence count"
        )

    q_split_indices = _split_indices(offsets_q)
    k_split_indices = _split_indices(offsets_k)
    q_splits = torch.tensor_split(q, q_split_indices, dim=0)
    k_splits = torch.tensor_split(k, k_split_indices, dim=0)
    v_splits = torch.tensor_split(v, k_split_indices, dim=0)

    # The backend and the input dtype are identical for every sequence, so the
    # lookup and the cast decision are made once instead of per iteration.
    attention_fn = _attention.optimized_attention
    out_dtype = q.dtype
    # Presumably the sage kernel only accepts fp16/bf16 inputs; other dtypes
    # are run in bf16 and the result is converted back to out_dtype below.
    cast_to_bf16 = attention_fn is _attention.attention_sage and out_dtype not in (
        torch.float16,
        torch.bfloat16,
    )

    out = []
    for q_i, k_i, v_i in zip(q_splits, k_splits, v_splits):
        # (seq, heads, head_dim) -> (1, heads, seq, head_dim): a batch of one,
        # heads-first, which is the layout handed over with skip_reshape=True.
        q_i = q_i.permute(1, 0, 2).unsqueeze(0)
        k_i = k_i.permute(1, 0, 2).unsqueeze(0)
        v_i = v_i.permute(1, 0, 2).unsqueeze(0)
        if cast_to_bf16:
            q_i = q_i.to(torch.bfloat16)
            k_i = k_i.to(torch.bfloat16)
            v_i = v_i.to(torch.bfloat16)
        out_i = attention_fn(q_i, k_i, v_i, heads, skip_reshape=True, skip_output_reshape=True)
        if out_i.dtype != out_dtype:
            out_i = out_i.to(out_dtype)
        # Back to (seq, heads, head_dim) so sequences concatenate along tokens.
        out.append(out_i.squeeze(0).permute(1, 0, 2))

    out = torch.cat(out, dim=0)
    return _var_attention_output(out, heads, head_dim, skip_output_reshape)


# Generic name for the var-attention entry point; the per-sequence split
# implementation above is the only one defined in this module.
optimized_var_attention = var_attention_optimized_split