mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-16 01:00:49 +08:00
1253 lines
56 KiB
Python
1253 lines
56 KiB
Python
from __future__ import annotations
|
|
|
|
from comfy.text_encoders.llama import (
|
|
RMSNorm, MLP, Attention, LlamaRoPE, Llama2Config
|
|
)
|
|
|
|
from comfy.autoregressive_sampling import GenerationConfig, apply_logits_processing, check_stopping_criteria, Cache, StaticCache, DynamicCache
|
|
from comfy.ldm.modules.attention import optimized_attention_for_device
|
|
from comfy.ldm.modules.attention import optimized_attention
|
|
from .preprocess import _ceil_to_nearest
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from enum import Enum
|
|
from dataclasses import dataclass
|
|
from collections import OrderedDict
|
|
from typing import Optional, Tuple, Union, List
|
|
|
|
class GenerationMode(Enum):
    """Decoding state used when sampling the next token.

    Only ``AUDIO_INIT`` is interpreted in the visible part of this file:
    ``_sample_text_tokens`` reacts to it by forcing the next text token to the
    audio-out placeholder and emitting stream-BOS codes for every codebook.
    """

    # presumably ordinary text-token generation — TODO confirm against the caller
    TEXT = 0
    # first step of an audio segment (see _sample_text_tokens)
    AUDIO_INIT = 1
    # presumably: audio codes are currently being generated — TODO confirm against the caller
    AUDIO_IN_PROGRESS = 2
|
|
|
|
def _ignore_causal_mask_sdpa(
|
|
attention_mask: Optional[torch.Tensor],
|
|
inputs_embeds: torch.Tensor,
|
|
past_key_values_length: int,
|
|
sliding_window: Optional[int] = None,
|
|
is_training: bool = False,
|
|
) -> bool:
|
|
|
|
_, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1]
|
|
key_value_length = query_length + past_key_values_length
|
|
|
|
ignore_causal_mask = False
|
|
|
|
if attention_mask is None:
|
|
if (is_training and (query_length == 1 or key_value_length == query_length)):
|
|
ignore_causal_mask = True
|
|
elif sliding_window is None or key_value_length < sliding_window:
|
|
if len(attention_mask.shape) == 4:
|
|
return False
|
|
elif torch.all(attention_mask == 1):
|
|
if query_length == 1 or key_value_length == query_length:
|
|
ignore_causal_mask = True
|
|
|
|
return ignore_causal_mask
|
|
|
|
def categorical_sample(probs, generator = None):
    """Draw one category index per row of ``probs`` by inverse-CDF sampling.

    One uniform number is drawn per row and the first position whose
    cumulative probability exceeds it is returned.

    Args:
        probs: 2-D tensor of per-row probabilities (rows are indexed by
            ``probs.size(0)``, categories live on the last dimension).
        generator: Optional ``torch.Generator`` forwarded to ``torch.rand``.

    Returns:
        Integer tensor with one sampled index per row.
    """
    u = torch.rand((probs.size(0), 1), device = probs.device, generator = generator)
    cdf = probs.cumsum(dim = -1)
    hit = u < cdf
    first_hit = hit.float().argmax(dim = -1)
    # If rounding leaves the final cumulative value at or below ``u`` no
    # position matches and argmax would silently yield index 0; fall back to
    # the last category, which is where the missing mass belongs.
    last_index = torch.full_like(first_hit, probs.size(-1) - 1)
    return torch.where(hit.any(dim = -1), first_hit, last_index)
|
|
|
|
@dataclass
class HiggsAudioModelOutputWithPast:
    """Container for model outputs; every field is optional and defaults to ``None``.

    NOTE(review): the code that fills this container is outside the visible
    part of the file, so the field notes below are inferred from the names
    and should be confirmed against the producer.
    """

    # presumably text-vocabulary logits — TODO confirm
    logits: Optional[torch.FloatTensor] = None
    # presumably logits over the audio codebooks — TODO confirm
    audio_logits: Optional[torch.FloatTensor] = None
    # key/value cache object carried between decoding steps
    past_key_values: Optional[Cache] = None
    # boolean mask; presumably marks positions holding discrete audio-in codes — TODO confirm
    audio_in_discrete_codes_mask: Optional[torch.BoolTensor] = None
    # boolean mask; presumably marks audio-out positions — TODO confirm
    audio_out_mask: Optional[torch.BoolTensor] = None
|
|
|
|
@torch.jit.script
|
|
def _prepare_4d_causal_attention_mask_with_cache_position(
|
|
attention_mask: torch.Tensor,
|
|
sequence_length: int,
|
|
target_length: int,
|
|
dtype: torch.dtype,
|
|
device: torch.device,
|
|
min_dtype: float,
|
|
cache_position: torch.Tensor,
|
|
batch_size: int,
|
|
):
|
|
if attention_mask is not None and attention_mask.dim() == 4:
|
|
causal_mask = attention_mask
|
|
else:
|
|
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
|
|
if sequence_length != 1:
|
|
causal_mask = torch.triu(causal_mask, diagonal=1)
|
|
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
|
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
|
if attention_mask is not None:
|
|
causal_mask = causal_mask.clone()
|
|
mask_length = attention_mask.shape[-1]
|
|
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
|
padding_mask = padding_mask == 0
|
|
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
|
padding_mask, min_dtype
|
|
)
|
|
|
|
return causal_mask
|
|
|
|
# refactored to decrease the length
|
|
def merge_input_ids_with_audio_features(
    audio_in_embed, audio_in_ids_start, audio_out_embed, audio_out_ids_start,
    audio_in_token_idx, audio_out_token_idx, inputs_embeds, input_ids,
    attention_mask, label_ids, pad_token_id, ignore_index=-100,
    round_to=8, left_padding=True,
):
    """Expand audio placeholder tokens into their audio embeddings.

    Every ``audio_in_token_idx`` / ``audio_out_token_idx`` entry of
    ``input_ids`` is replaced by as many positions as the matching audio
    segment has rows; text embeddings are moved to their shifted positions.

    Args:
        audio_in_embed / audio_out_embed: Stacked segment embeddings or ``None``;
            an input with zero rows is treated as ``None``.
        audio_in_ids_start / audio_out_ids_start: Start row of each segment
            inside the matching embedding tensor.
        audio_in_token_idx / audio_out_token_idx: Placeholder token ids.
        inputs_embeds: Text embeddings, unpacked as (batch, sequence, embed_dim).
        input_ids: Token ids aligned with ``inputs_embeds``.
        attention_mask: Mask aligned with ``input_ids``.
        label_ids: Optional labels; when ``None`` no labels are produced.
        pad_token_id: Fill value for unused positions of the merged ids.
        ignore_index: Label value written at audio and unused positions.
        round_to: Granularity for the merged length and for trimming.
        left_padding: Pad on the left when truthy; ``None`` infers it from the
            first column of ``attention_mask``.

    Returns:
        Tuple ``(embeddings, attention_mask, labels, position_ids, input_ids,
        audio_in_mask, audio_in_discrete_codes_mask, audio_out_mask)``.
    """
    no_labels = label_ids is None
    if audio_in_embed is not None and audio_in_embed.shape[0] == 0:
        audio_in_embed = None
    if audio_out_embed is not None and audio_out_embed.shape[0] == 0:
        audio_out_embed = None

    batch_size, sequence_length, embed_dim = inputs_embeds.shape
    device = inputs_embeds.device
    if left_padding is None:
        left_padding = torch.any(attention_mask[:, 0] == 0)

    def segment_lengths(starts, embed):
        # Gap between consecutive starts, plus the tail segment up to the last row.
        inner = starts[1:] - starts[:-1]
        tail = torch.tensor([embed.shape[0] - starts[-1]], device=starts.device, dtype=torch.long)
        return torch.concat([inner, tail], dim=0).long()

    def scatter_audio(embed, token_idx, starts, lengths, ends, owner_rows, audio_mask):
        # Write one audio stream into the merged buffers (which are read from
        # the enclosing scope) and flag the touched positions in ``audio_mask``.
        positions = torch.arange(total_len, device=device).unsqueeze(0).expand(starts.shape[0], total_len)
        first = (ends - lengths + 1).unsqueeze(1)
        last = ends.unsqueeze(1)
        segment_ids, cols = torch.where((positions >= first) & (positions <= last))
        rows = owner_rows[segment_ids]

        if embed.dtype != merged_embeds.dtype:
            embed = embed.to(merged_embeds.dtype)

        merged_embeds[rows, cols] = embed
        merged_ids[rows, cols] = token_idx
        if not no_labels:
            merged_labels[rows, cols] = ignore_index
        audio_mask[rows, cols] = True

    is_audio_in = input_ids == audio_in_token_idx
    is_audio_out = input_ids == audio_out_token_idx
    is_text = (input_ids != audio_in_token_idx) & (input_ids != audio_out_token_idx)

    # Number of merged positions each original token occupies.
    slots = torch.ones_like(input_ids)
    if audio_in_embed is not None:
        audio_in_lengths = segment_lengths(audio_in_ids_start, audio_in_embed)
        slots[is_audio_in] = audio_in_lengths
    if audio_out_embed is not None:
        audio_out_lengths = segment_lengths(audio_out_ids_start, audio_out_embed)
        slots[is_audio_out] = audio_out_lengths

    # Merged index of the LAST position owned by each original token.
    new_positions = torch.cumsum(slots, -1) - 1
    total_len = _ceil_to_nearest(slots.sum(-1).max(), round_to)
    pad_amount = total_len - 1 - new_positions[:, -1]
    if left_padding:
        new_positions += pad_amount[:, None]

    merged_embeds = torch.zeros((batch_size, total_len, embed_dim), dtype=inputs_embeds.dtype, device=device)
    merged_attention = torch.zeros((batch_size, total_len), dtype=attention_mask.dtype, device=device)
    merged_ids = torch.full((batch_size, total_len), pad_token_id, dtype=input_ids.dtype, device=device)
    if no_labels:
        merged_labels = None
    else:
        merged_labels = torch.full((batch_size, total_len), ignore_index, dtype=label_ids.dtype, device=device)
    audio_in_mask = torch.zeros((batch_size, total_len), dtype=torch.bool, device=device)
    audio_in_codes_mask = torch.zeros((batch_size, total_len), dtype=torch.bool, device=device)
    audio_out_mask = torch.zeros((batch_size, total_len), dtype=torch.bool, device=device)

    row_ids = torch.arange(batch_size, device=device).unsqueeze(1).expand(batch_size, sequence_length)

    if audio_in_embed is not None:
        scatter_audio(
            audio_in_embed, audio_in_token_idx, audio_in_ids_start, audio_in_lengths,
            new_positions[is_audio_in], row_ids[is_audio_in], audio_in_mask,
        )
        audio_in_codes_mask = audio_in_mask.clone()

    if audio_out_embed is not None:
        scatter_audio(
            audio_out_embed, audio_out_token_idx, audio_out_ids_start, audio_out_lengths,
            new_positions[is_audio_out], row_ids[is_audio_out], audio_out_mask,
        )

    # Text tokens are copied after the audio so they land on their shifted slots.
    src_rows, src_cols = torch.where(is_text)
    dst_cols = new_positions[src_rows, src_cols]
    merged_embeds[src_rows, dst_cols] = inputs_embeds[src_rows, src_cols]
    if not no_labels:
        merged_labels[src_rows, dst_cols] = label_ids[src_rows, src_cols]
    merged_ids[src_rows, dst_cols] = input_ids[src_rows, src_cols]
    merged_attention[src_rows, dst_cols] = attention_mask[src_rows, src_cols]

    merged_attention |= audio_in_mask | audio_out_mask

    # Drop columns that are padding in every row, in multiples of ``round_to``.
    used_columns = merged_attention.sum(0).nonzero()
    keep = None
    if left_padding:
        first_used = (used_columns[0] // round_to) * round_to
        if first_used > 0:
            keep = slice(first_used, None)
    else:
        end_used = ((used_columns[-1] + 1 + round_to - 1) // round_to) * round_to
        if end_used < total_len:
            keep = slice(None, end_used)

    if keep is not None:
        merged_attention = merged_attention[:, keep]
        merged_embeds = merged_embeds[:, keep]
        if not no_labels:
            merged_labels = merged_labels[:, keep]
        merged_ids = merged_ids[:, keep]
        audio_in_mask = audio_in_mask[:, keep]
        audio_in_codes_mask = audio_in_codes_mask[:, keep]
        audio_out_mask = audio_out_mask[:, keep]

    # Masked-out positions get position id 1.
    position_ids = (merged_attention.cumsum(-1) - 1).masked_fill_(merged_attention == 0, 1)

    return (merged_embeds, merged_attention, merged_labels, position_ids, merged_ids,
            audio_in_mask, audio_in_codes_mask, audio_out_mask)
|
|
|
|
class HiggsAudioDualFFNDecoderLayer(nn.Module):
    """Decoder layer with shared self-attention and optional audio-specific FFN.

    With ``fast_forward=False`` the layer owns a second MLP and a second pair
    of RMSNorm layers that are applied to the positions flagged by
    ``audio_out_mask``. With ``fast_forward=True`` only the text modules exist
    and flagged positions leave the layer with their input hidden state.
    """

    def __init__(
        self, config, llama_config, layer_idx: int, fast_forward: bool = False, device = None, dtype = None,
    ):
        """Create the layer.

        Args:
            config: Mapping that contains ``"text_config"`` with
                ``"hidden_size"`` and ``"rms_norm_eps"``.
            llama_config: Configuration forwarded to ``Attention`` and ``MLP``.
            layer_idx: Index of this layer, forwarded to ``Attention``.
            fast_forward: When true, no audio-specific modules are created.
            device: Device forwarded to every submodule.
            dtype: dtype forwarded to every submodule.
        """
        super().__init__()
        text_config = config["text_config"]
        self.hidden_size = text_config["hidden_size"]
        self.layer_idx = layer_idx
        self.self_attn = Attention(config = llama_config, layer_idx = layer_idx, device = device, dtype = dtype)

        # Built with the same device/dtype as audio_mlp and the other
        # submodules; previously it silently used the defaults.
        self.mlp = MLP(llama_config, device = device, dtype = dtype)

        if not fast_forward:
            self.audio_mlp = MLP(llama_config, device = device, dtype = dtype)
            self.audio_input_layernorm = RMSNorm(text_config["hidden_size"], eps = text_config["rms_norm_eps"], device = device, dtype = dtype)
            self.audio_post_attention_layernorm = RMSNorm(text_config["hidden_size"], eps = text_config["rms_norm_eps"], device = device, dtype = dtype)

        self.fast_forward = fast_forward

        self.input_layernorm = RMSNorm(text_config["hidden_size"], eps = text_config["rms_norm_eps"], device = device, dtype = dtype)
        self.post_attention_layernorm = RMSNorm(text_config["hidden_size"], eps = text_config["rms_norm_eps"], device = device, dtype = dtype)

        self.text_config = text_config

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        fast_forward_attention_mask: Optional[torch.Tensor] = None,
        audio_out_mask: Optional[torch.BoolTensor] = None,
        is_decoding_audio_token: Optional[bool] = None,
        past_key_value: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        is_using_cuda_graph: Optional[bool] = False,
        position_embeddings = None,
        **kwargs,
    ):
        """Run attention followed by the text and/or audio feed-forward path.

        Args:
            hidden_states: Layer input; dimension 1 is read as the query length.
            attention_mask: ``None``, a 2-D mask or a 4-D additive mask.
            fast_forward_attention_mask: Precomputed mask used by fast-forward
                layers when ``past_key_value`` is a ``StaticCache``.
            audio_out_mask: Boolean mask of audio positions; with ``use_cache``
                only its last ``hidden_states.shape[1]`` columns are used to
                select rows.
            is_decoding_audio_token: Whether the token being decoded is audio;
                required when ``is_using_cuda_graph`` is set.
            past_key_value: Cache forwarded to the attention module.
            use_cache: When true the attention cache entry is also returned.
            cache_position: Forwarded to mask construction and attention.
            is_using_cuda_graph: Only valid for single-token decoding with a
                ``StaticCache``.
            position_embeddings: Forwarded to attention as ``freqs_cis``.
            **kwargs: Extra keyword arguments are accepted and ignored.

        Returns:
            ``(hidden_states,)`` or ``(hidden_states, present_key_value)`` when
            ``use_cache`` is truthy. A fast-forward layer that is decoding an
            audio token returns ``(hidden_states,)`` immediately.
        """
        residual = hidden_states
        target_length = hidden_states.shape[1]
        use_static_cache = isinstance(past_key_value, StaticCache)
        decode_stage = hidden_states.shape[1] == 1
        if is_using_cuda_graph:
            assert decode_stage and use_static_cache, (
                "The CUDA graph mode should only be used in the decoding stage with static cache."
            )

        # A fast-forward layer leaves audio tokens untouched.
        if is_decoding_audio_token and self.fast_forward:
            return (hidden_states,)

        has_audio_out = audio_out_mask is not None and audio_out_mask.shape[0] > 0

        audio_out_mask_sq = audio_out_mask

        small_input = target_length <= 2048
        optimized_attention = optimized_attention_for_device(hidden_states.device, small_input = small_input)

        if self.fast_forward and has_audio_out:
            # Kept so audio positions can be restored after the layer ran.
            original_hidden_states = hidden_states.clone()
            min_dtype = torch.finfo(hidden_states.dtype).min
            if attention_mask is None:
                attention_mask = ~audio_out_mask

                if optimized_attention.__name__ != "attention_flash":
                    sequence_length = audio_out_mask.shape[1]
                    attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
                        attention_mask=attention_mask,
                        sequence_length=sequence_length,
                        target_length=sequence_length,
                        dtype=hidden_states.dtype,
                        min_dtype=min_dtype,
                        device=hidden_states.device,
                        cache_position=cache_position,
                        batch_size=hidden_states.shape[0],
                    )
                    if use_cache:
                        attention_mask = attention_mask[:, :, -target_length:, :]
            elif len(attention_mask.shape) == 2:
                attention_mask = attention_mask * ~audio_out_mask
            elif len(attention_mask.shape) == 4:
                if use_static_cache:
                    attention_mask = fast_forward_attention_mask
                else:
                    # Mask every query row and key column that is an audio position.
                    if use_cache:
                        attention_mask = attention_mask.masked_fill(
                            audio_out_mask[:, -target_length:].reshape(audio_out_mask.shape[0], 1, target_length, 1)
                            | audio_out_mask.reshape(audio_out_mask.shape[0], 1, 1, audio_out_mask.shape[1]),
                            min_dtype,
                        )
                    else:
                        attention_mask = attention_mask.masked_fill(
                            audio_out_mask.reshape(audio_out_mask.shape[0], 1, audio_out_mask.shape[1], 1)
                            | audio_out_mask.reshape(audio_out_mask.shape[0], 1, 1, audio_out_mask.shape[1]),
                            min_dtype,
                        )
            else:
                raise NotImplementedError(f"Unsupported attention_mask format, attention_mask={attention_mask}")

            if (
                optimized_attention.__name__ == "attention_pytorch"
                and attention_mask is not None
                and attention_mask.device.type == "cuda"
            ):
                # Rows that are masked everywhere are zeroed, i.e. fully unmasked.
                attention_mask = attention_mask.mul(~torch.all(attention_mask == min_dtype, dim=-1, keepdim=True))

        if has_audio_out and not self.fast_forward:
            # Audio and text positions use different input layer norms.
            if use_cache:
                hidden_states = torch.where(
                    audio_out_mask_sq[:, -target_length:].unsqueeze(-1),
                    self.audio_input_layernorm(hidden_states),
                    self.input_layernorm(hidden_states),
                )
            else:
                hidden_states = torch.where(
                    audio_out_mask_sq.unsqueeze(-1),
                    self.audio_input_layernorm(hidden_states),
                    self.input_layernorm(hidden_states),
                )
        else:
            hidden_states = self.input_layernorm(hidden_states)

        hidden_states, present_key_value = self.self_attn(
            hidden_states = hidden_states,
            attention_mask = attention_mask,
            freqs_cis = position_embeddings,
            past_key_value=past_key_value,
            optimized_attention = optimized_attention,
            cache_position=cache_position,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        if has_audio_out and not self.fast_forward:
            if use_cache:
                real_audio_out_mask = audio_out_mask_sq[:, -target_length:]
            else:
                real_audio_out_mask = audio_out_mask_sq

            if decode_stage and is_using_cuda_graph:
                # Single-token decode: pick one path without boolean indexing.
                assert is_decoding_audio_token is not None, (
                    "is_decoding_audio_token should be present in the decoding stage."
                )
                if is_decoding_audio_token:
                    hidden_states = self.audio_post_attention_layernorm(hidden_states)
                    hidden_states = self.audio_mlp(hidden_states)
                else:
                    hidden_states = self.post_attention_layernorm(hidden_states)
                    hidden_states = self.mlp(hidden_states)
                residual = residual + hidden_states
            else:
                text_hidden_states = self.post_attention_layernorm(hidden_states[~real_audio_out_mask])
                audio_hidden_states = self.audio_post_attention_layernorm(hidden_states[real_audio_out_mask])

                # Guard against a text MLP whose parameters use another dtype.
                mlp_dtype = next(iter(self.mlp.parameters())).dtype
                if text_hidden_states.dtype != mlp_dtype:
                    text_hidden_states = text_hidden_states.to(mlp_dtype)

                text_hidden_states = self.mlp(text_hidden_states)
                residual[~real_audio_out_mask] += text_hidden_states

                audio_hidden_states = self.audio_mlp(audio_hidden_states)
                residual[real_audio_out_mask] += audio_hidden_states

            hidden_states = residual
        else:
            hidden_states = self.post_attention_layernorm(hidden_states)
            hidden_states = self.mlp(hidden_states)
            hidden_states = residual + hidden_states

        if self.fast_forward and has_audio_out:
            # Restore the untouched input at audio positions.
            if use_cache:
                hidden_states = torch.where(
                    audio_out_mask_sq[:, -target_length:].unsqueeze(-1), original_hidden_states, hidden_states
                )
            else:
                hidden_states = torch.where(audio_out_mask_sq.unsqueeze(-1), original_hidden_states, hidden_states)

        outputs = (hidden_states,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs
|
|
|
|
class HiggsAudioDecoderProjector(nn.Module):
    """Projects hidden states to text logits and audio-codebook logits."""

    def __init__(self, config, device = None, dtype = None, operations = None):
        """Create the two bias-free output heads.

        Args:
            config: Mapping with ``"text_config"`` (``"hidden_size"``,
                ``"vocab_size"``), ``"audio_num_codebooks"`` and
                ``"audio_codebook_size"``.
            device: Device forwarded to both heads.
            dtype: dtype forwarded to both heads.
            operations: Namespace that provides the ``Linear`` class to use.
        """
        super().__init__()

        text_cfg = config["text_config"]
        head_kwargs = dict(bias=False, device = device, dtype = dtype)

        self.text_lm_head = operations.Linear(text_cfg["hidden_size"], text_cfg["vocab_size"], **head_kwargs)

        # Each codebook has two entries beyond ``audio_codebook_size``.
        audio_out_features = config["audio_num_codebooks"] * (config["audio_codebook_size"] + 2)
        self.audio_lm_head = operations.Linear(text_cfg["hidden_size"], audio_out_features, **head_kwargs)

    def forward(self, hidden_states, audio_out_mask, **kwargs):
        """Return ``(text_logits, audio_logits)``.

        Text logits are computed for every position; audio logits only for
        the positions selected by ``audio_out_mask``.
        """
        text_logits = self.text_lm_head(hidden_states)
        audio_positions = hidden_states[audio_out_mask]
        return text_logits, self.audio_lm_head(audio_positions)
|
|
|
|
class HiggsAudioModel(nn.Module):
|
|
|
|
_supports_cache_class = True
|
|
_supports_static_cache = True
|
|
|
|
def __init__(self, device = None, dtype = None, operations = None, **kwargs):
    """Build the model from a flat configuration passed as keyword arguments.

    Args:
        device: Device forwarded to the created submodules.
        dtype: dtype forwarded to the created submodules.
        operations: Namespace that provides the ``Embedding`` / ``Linear``
            classes to use.
        **kwargs: Model configuration. Required keys read here:
            ``pad_token_id``, ``audio_in_token_idx``, ``audio_out_token_idx``,
            ``audio_num_codebooks``, ``audio_codebook_size``,
            ``use_delay_pattern``, ``audio_adapter_type``,
            ``audio_dual_ffn_layers`` and ``text_config``. Optional keys:
            ``audio_out_bos_token_id`` (default 128013),
            ``audio_eos_token_id`` (default 128012) and ``use_kv_buckets``.

    Raises:
        NotImplementedError: If ``audio_adapter_type`` is not
            ``"dual_ffn_fast_forward"``.
    """
    super().__init__()

    text_config = kwargs["text_config"]
    hidden_size = text_config["hidden_size"]

    self.padding_idx = kwargs["pad_token_id"]
    self.audio_in_token_idx = kwargs["audio_in_token_idx"]
    self.audio_out_token_idx = kwargs["audio_out_token_idx"]

    # Honour configured ids; the previously hard-coded values are only the
    # fallback when the configuration omits them (or sets them to None).
    audio_out_bos_token_id = kwargs.get("audio_out_bos_token_id", None)
    self.audio_out_bos_token_id = 128013 if audio_out_bos_token_id is None else audio_out_bos_token_id
    audio_eos_token_id = kwargs.get("audio_eos_token_id", None)
    self.audio_eos_token_id = 128012 if audio_eos_token_id is None else audio_eos_token_id

    self.vocab_size = text_config["vocab_size"]
    self.audio_num_codebooks = kwargs["audio_num_codebooks"]
    self.use_delay_pattern = kwargs["use_delay_pattern"]

    # for autoregressive sampling
    self.num_hidden_layers = text_config["num_hidden_layers"]
    self.cache_config = text_config
    self.hidden_dim = hidden_size
    self.max_seq_len = text_config["max_position_embeddings"]
    self.cache_implementation = "static"
    self.use_kv_buckets = kwargs.get("use_kv_buckets", False)

    self.dtype = dtype
    self.device = device
    self.config = kwargs

    # NOTE(review): qkv_bias is taken from "mlp_bias"; confirm this is
    # intended rather than a dedicated attention-bias setting.
    llama_config = Llama2Config(num_attention_heads = text_config["num_attention_heads"],
                                num_key_value_heads = text_config["num_key_value_heads"],
                                hidden_size = hidden_size,
                                head_dim = text_config["head_dim"],
                                qkv_bias = text_config["mlp_bias"],
                                intermediate_size = text_config["intermediate_size"])

    self.embed_tokens = operations.Embedding(self.vocab_size, hidden_size, self.padding_idx, device = device, dtype = dtype)
    self.attn_implementation = optimized_attention.__name__

    if kwargs["audio_adapter_type"] != "dual_ffn_fast_forward":
        raise NotImplementedError(f"Audio adapter type {kwargs['audio_adapter_type']} not implemented.")

    # Layers listed in audio_dual_ffn_layers get the audio FFN; all others
    # are fast-forward layers. The layer index is the position in the stack.
    self.layers = nn.ModuleList([
        HiggsAudioDualFFNDecoderLayer(
            kwargs,
            llama_config,
            layer_idx,
            fast_forward = layer_idx not in kwargs["audio_dual_ffn_layers"],
            device = device, dtype = dtype
        )
        for layer_idx in range(self.num_hidden_layers)
    ])

    self.num_activation_checkpointing_layers = len(self.layers)
    self.decode_graph_runners = []
    # device/dtype are forwarded like for every other submodule.
    self.norm = RMSNorm(hidden_size, eps = text_config["rms_norm_eps"], device = device, dtype = dtype)
    self.rotary_emb = LlamaRoPE(config = llama_config)

    self.audio_tower = None
    self.audio_encoder_proj = None

    self.audio_decoder_proj = HiggsAudioDecoderProjector(
        kwargs, device=device, dtype=dtype, operations=operations
    )
    # Two entries are added on top of the configured codebook size.
    self.audio_codebook_size = kwargs["audio_codebook_size"] + 2

    # One embedding table shared by all codebooks; each codebook owns a
    # contiguous range of audio_codebook_size rows (see _embed_audio_ids).
    self.audio_codebook_embeddings = operations.Embedding(
        kwargs["audio_num_codebooks"] * self.audio_codebook_size,
        hidden_size,
        device = device, dtype = dtype,
    )

    self.audio_codebook_weights = (
        torch.ones(kwargs["audio_num_codebooks"]) / kwargs["audio_num_codebooks"]
    )

    self.stop_strings = [[128009], [128001]]
|
|
|
|
def _sample_audio_tokens(
    self,
    audio_logits: torch.Tensor,
    audio_out_ids: torch.Tensor,
    logits_processing_list,
    device: torch.device,
    torch_generator: Optional[torch.Generator],
    generation_config: GenerationConfig,
    num_delay: int,
    num_remaining_delays: Optional[int],
    is_using_cuda_graphs,
    do_sample = False,
) -> Tuple[torch.Tensor, torch.Tensor, int, Optional[int]]:
    """Sample audio tokens and its corresponding text tokens from the logits

    Args:
        audio_logits: Audio logits; the last entry along dimension 0 is used
            and treated as one row of scores per codebook.
        audio_out_ids: Previously generated audio codes, compared against the
            new tokens for the repetition check.
        logits_processing_list: Processors applied through
            ``apply_logits_processing``.
        device: Device for the sampled tensors.
        torch_generator: Optional random generator.
        generation_config: Its ``generation_kwargs`` may provide
            ``ras_win_len``, ``ras_win_max_num_repeat`` (default 2) and
            ``audio_eos_token_id``.
        num_delay: Delay-pattern counter of codebooks already started.
        num_remaining_delays: Steps left until the audio segment ends, or
            ``None`` when no stream-EOS has been seen yet.
        is_using_cuda_graphs: Selects a sampling routine that avoids
            ``torch.multinomial``.
        do_sample: Sample from the distribution instead of taking the argmax.

    Returns:
        ``(next_tokens, next_audio_tokens, num_delay, num_remaining_delays)``.
    """
    generation_kwargs = generation_config.generation_kwargs
    ras_win_len = generation_kwargs.get("ras_win_len", None)
    ras_win_max_num_repeat = generation_kwargs.get("ras_win_max_num_repeat", 2)
    audio_eos_token_id = generation_kwargs.get("audio_eos_token_id", None)
    if audio_eos_token_id is None:
        # Assigning None into a tensor raises; use the model-level id instead.
        audio_eos_token_id = self.audio_eos_token_id

    next_audio_token_logits = audio_logits.clone()[-1, :, :].float().to(device)
    next_audio_token_scores = apply_logits_processing(None, next_audio_token_logits, logits_processing_list)

    if do_sample:
        probs = nn.functional.softmax(next_audio_token_scores, dim = -1)
        # torch.multinomial doesn't work with cuda graphs, replaced with mathematically eqv. fn
        if not is_using_cuda_graphs:
            next_audio_tokens = torch.multinomial(probs, num_samples = 1, generator=torch_generator).squeeze(1)
        else:
            next_audio_tokens = categorical_sample(probs, generator = torch_generator)
    else:
        next_audio_tokens = torch.argmax(next_audio_token_scores, dim=-1)

    if ras_win_len is not None:
        # Count how often each new token already occurs in the recent window.
        rep_num = (audio_out_ids[:, -ras_win_len:] == next_audio_tokens.unsqueeze(1)).sum(dim=1)

        # Codebooks that repeat too often are resampled from the unprocessed logits.
        row_indices = torch.nonzero(rep_num >= ras_win_max_num_repeat).squeeze(1)
        resampled_next_tokens = (
            next_audio_token_logits[row_indices]
            .softmax(dim=-1)
            .multinomial(1, replacement=True, generator=torch_generator)
            .squeeze(1)
        )
        next_audio_tokens[row_indices] = resampled_next_tokens

    # Force the next text tokens to be <|AUDIO_OUT|> in audio generation mode
    next_tokens = torch.full(
        (audio_logits.shape[0],),
        self.config["audio_out_token_idx"],
        dtype=torch.long,
        device=device,
    )

    # Handle delay_pattern
    if self.use_delay_pattern:
        stream_bos_id = self.config["audio_stream_bos_id"]
        stream_eos_id = self.config["audio_stream_eos_id"]

        if num_delay + 1 < next_audio_tokens.shape[0]:
            # Codebooks that have not started yet keep emitting stream-BOS.
            next_audio_tokens[(num_delay + 1) :] = stream_bos_id
            num_delay += 1
        if num_remaining_delays is not None:
            # Codebooks that already finished keep emitting stream-EOS.
            next_audio_tokens[: (self.audio_num_codebooks - num_remaining_delays)] = stream_eos_id
            num_remaining_delays -= 1
        else:
            all_eos_indices = (next_audio_tokens == stream_eos_id).nonzero()
            if torch.numel(all_eos_indices) > 0:
                # NOTE(review): this selects the first row of nonzero(), i.e.
                # the lowest codebook index holding EOS, despite the name.
                all_eos_indices = all_eos_indices[0]
                last_eos_idx = all_eos_indices[-1]
                next_audio_tokens[:last_eos_idx] = stream_eos_id
                num_remaining_delays = self.audio_num_codebooks - last_eos_idx - 1
        if num_remaining_delays is not None and num_remaining_delays <= 0:
            # Every codebook has finished: emit the audio EOS text token and reset.
            next_tokens[...] = audio_eos_token_id
            num_delay = 0
            num_remaining_delays = None

    return (
        next_tokens,
        next_audio_tokens,
        num_delay,
        num_remaining_delays,
    )
|
|
|
|
def _sample_text_tokens(
    self,
    logits: torch.Tensor,
    input_ids: torch.Tensor,
    logits_processing_list,
    device: torch.device,
    generation_mode: GenerationMode,
    torch_generator,
    is_using_cuda_graphs,
    do_sample = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Sample text tokens from the logits

    Args:
        logits: Text logits; the last position along dimension 1 is used.
        input_ids: Token ids generated so far; passed to the logits
            processors and used for the batch size and device.
        logits_processing_list: Processors applied through
            ``apply_logits_processing``.
        device: Device for the tensors created in ``AUDIO_INIT`` mode.
        generation_mode: Current ``GenerationMode``.
        torch_generator: Optional random generator.
        is_using_cuda_graphs: Selects a sampling routine that avoids
            ``torch.multinomial``.
        do_sample: Sample from the distribution instead of taking the argmax.

    Returns:
        ``(next_tokens, next_audio_tokens)``. ``next_audio_tokens`` is
        ``None`` unless ``generation_mode`` is ``AUDIO_INIT``, in which case
        it holds one stream-BOS id per codebook.
    """
    next_token_logits = logits.clone()[:, -1, :].float()
    next_token_logits = next_token_logits.to(input_ids.device)

    # pre-process distribution
    next_token_scores = apply_logits_processing(input_ids, next_token_logits, logits_processing_list)

    if generation_mode == GenerationMode.AUDIO_INIT:
        # See the audio bos token, we should start generating audio tokens
        next_tokens = torch.full(
            (input_ids.shape[0],),
            self.audio_out_token_idx,
            dtype=torch.long,
            device=device,
        )
        next_audio_tokens = torch.full(
            (self.config["audio_num_codebooks"],),
            self.config["audio_stream_bos_id"],
            dtype=torch.long,
            device=device,
        )
        return next_tokens, next_audio_tokens

    if do_sample:
        probs = nn.functional.softmax(next_token_scores, dim = -1)
        # torch.multinomial doesn't work with cuda graphs (same as for audio)
        if not is_using_cuda_graphs:
            next_tokens = torch.multinomial(probs, num_samples = 1, generator=torch_generator).squeeze(1)
        else:
            next_tokens = categorical_sample(probs, generator = torch_generator)
    else:
        next_tokens = torch.argmax(next_token_scores, dim=-1)

    return next_tokens, None
|
|
|
|
def _update_causal_mask(
|
|
self,
|
|
attention_mask: torch.Tensor,
|
|
input_tensor: torch.Tensor,
|
|
cache_position: torch.Tensor,
|
|
past_key_values: Cache,
|
|
output_attentions: bool,
|
|
):
|
|
if self.attn_implementation == "attention_flash":
|
|
if attention_mask is not None and 0.0 in attention_mask:
|
|
return attention_mask
|
|
return None
|
|
|
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
|
using_static_cache = isinstance(past_key_values, StaticCache)
|
|
|
|
if self.attn_implementation == "attention_pytorch" and not using_static_cache and not output_attentions:
|
|
if _ignore_causal_mask_sdpa(
|
|
attention_mask,
|
|
inputs_embeds=input_tensor,
|
|
past_key_values_length=past_seen_tokens,
|
|
is_training=self.training,
|
|
):
|
|
return None
|
|
|
|
dtype, device = input_tensor.dtype, input_tensor.device
|
|
min_dtype = torch.finfo(dtype).min
|
|
sequence_length = input_tensor.shape[1]
|
|
if using_static_cache:
|
|
if hasattr(past_key_values, "get_max_length"):
|
|
target_length = past_key_values.get_max_length()
|
|
else:
|
|
target_length = past_key_values.max_cache_len
|
|
else:
|
|
target_length = (
|
|
attention_mask.shape[-1]
|
|
if isinstance(attention_mask, torch.Tensor)
|
|
else past_seen_tokens + sequence_length + 1
|
|
)
|
|
|
|
causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
|
|
attention_mask,
|
|
sequence_length=sequence_length,
|
|
target_length=target_length,
|
|
dtype=dtype,
|
|
device=device,
|
|
min_dtype=min_dtype,
|
|
cache_position=cache_position,
|
|
batch_size=input_tensor.shape[0],
|
|
)
|
|
|
|
if (
|
|
self.attn_implementation == "attention_pytorch"
|
|
and attention_mask is not None
|
|
and attention_mask.device.type == "cuda"
|
|
and not output_attentions
|
|
):
|
|
causal_mask = causal_mask.mul(~torch.all(causal_mask == min_dtype, dim=-1, keepdim=True))
|
|
|
|
return causal_mask
|
|
|
|
def _embed_audio_ids(self, audio_ids):
|
|
codebook_shift = (
|
|
torch.arange(self.config["audio_num_codebooks"], device=audio_ids.device) * self.audio_codebook_size
|
|
)
|
|
audio_embed = self.audio_codebook_embeddings(audio_ids + codebook_shift.unsqueeze(-1))
|
|
audio_embed = torch.sum(audio_embed, dim=0)
|
|
return audio_embed
|
|
|
|
def _prepare_all_static_kv_cache_masks(self, hidden_states, attention_mask, audio_out_mask, past_key_values):
|
|
target_length = hidden_states.shape[1]
|
|
cur_pos = audio_out_mask.shape[1]
|
|
min_dtype = torch.finfo(hidden_states.dtype).min
|
|
assert len(attention_mask.shape) == 4, "Only support SDPA for now"
|
|
kv_cache_len = past_key_values.get_max_cache_shape()
|
|
audio_out_mask_padded = torch.nn.functional.pad(audio_out_mask, (0, kv_cache_len - cur_pos), value=True)
|
|
fast_forward_attention_mask = attention_mask.masked_fill(
|
|
audio_out_mask_padded[:, audio_out_mask.shape[1] - target_length : audio_out_mask.shape[1]].reshape(
|
|
audio_out_mask_padded.shape[0], 1, target_length, 1
|
|
)
|
|
| audio_out_mask_padded.reshape(audio_out_mask_padded.shape[0], 1, 1, audio_out_mask_padded.shape[1]),
|
|
min_dtype,
|
|
)
|
|
|
|
no_audio_out_mask = ~audio_out_mask
|
|
no_audio_out_mask = torch.nn.functional.pad(
|
|
no_audio_out_mask, (0, kv_cache_len - audio_out_mask.shape[1]), value=False
|
|
)
|
|
no_audio_out_mask = no_audio_out_mask[
|
|
:, audio_out_mask.shape[1] - target_length : audio_out_mask.shape[1]
|
|
].reshape(audio_out_mask.shape[0], 1, target_length, 1) | no_audio_out_mask.reshape(
|
|
audio_out_mask.shape[0], 1, 1, kv_cache_len
|
|
)
|
|
audio_attention_mask = attention_mask.masked_fill(no_audio_out_mask, min_dtype)
|
|
return fast_forward_attention_mask, audio_attention_mask
|
|
|
|
def _forward_core(
|
|
self,
|
|
hidden_states: torch.Tensor,
|
|
causal_mask: torch.Tensor,
|
|
position_ids: torch.Tensor,
|
|
audio_discrete_codes_mask: torch.Tensor,
|
|
cache_position: torch.Tensor,
|
|
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]],
|
|
use_cache: bool,
|
|
audio_attention_mask: torch.Tensor,
|
|
fast_forward_attention_mask: torch.Tensor,
|
|
is_decoding_audio_token: Optional[bool] = None,
|
|
is_using_cuda_graph: Optional[bool] = False,
|
|
):
|
|
|
|
position_id_offset = cache_position[0] if use_cache else 0
|
|
position_embeddings = self.rotary_emb(hidden_states, position_ids + position_id_offset)
|
|
|
|
for decoder_layer in self.layers:
|
|
layer_outputs = decoder_layer(
|
|
hidden_states,
|
|
attention_mask=causal_mask,
|
|
audio_attention_mask=audio_attention_mask,
|
|
fast_forward_attention_mask=fast_forward_attention_mask,
|
|
position_ids=position_ids,
|
|
audio_out_mask=audio_discrete_codes_mask,
|
|
is_decoding_audio_token=is_decoding_audio_token,
|
|
past_key_value=past_key_values,
|
|
use_cache=use_cache,
|
|
cache_position=cache_position,
|
|
position_embeddings=position_embeddings,
|
|
is_using_cuda_graph=is_using_cuda_graph,
|
|
)
|
|
|
|
hidden_states = layer_outputs[0]
|
|
|
|
return hidden_states
|
|
|
|
def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.BoolTensor] = None,
    audio_in_ids: Optional[torch.LongTensor] = None,
    audio_in_ids_start: Optional[torch.LongTensor] = None,
    audio_out_ids: Optional[torch.LongTensor] = None,
    audio_out_ids_start: Optional[torch.LongTensor] = None,
    label_ids: Optional[torch.LongTensor] = None,
    label_audio_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_audio_hidden_states: Optional[bool] = False,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    cache_audio_discrete_codes_mask: Optional[torch.LongTensor] = None,
    past_key_values_buckets: Optional[OrderedDict[int, Cache]] = None,
    is_using_cuda_graphs: bool = None,
    **kwargs
):
    """Embed text and audio ids, run the decoder stack and project to logits.

    Steps, in order:
      1. Embed ``input_ids`` and the audio-in / audio-out code ids.
      2. Merge them into one sequence with
         ``merge_input_ids_with_audio_features``.
      3. Pick or create the KV cache and derive ``cache_position``.
      4. Build the causal mask (and, for a ``StaticCache``, the audio and
         fast-forward masks).
      5. Run ``self._forward_core`` or a pre-built runner from
         ``self.decode_graph_runners``, apply ``self.norm`` and
         ``self.audio_decoder_proj``.

    ``return_dict`` and ``**kwargs`` are accepted but not read by this method.

    Returns:
        A ``HiggsAudioModelOutputWithPast`` holding ``logits``,
        ``audio_logits``, ``past_key_values`` (``None`` unless ``use_cache``),
        ``audio_out_mask`` and ``audio_in_discrete_codes_mask``.

    Raises:
        ValueError: ``cache_position`` was not supplied and a ``StaticCache``
            already holds at least as many tokens as its maximum shape.
    """
    target_device = input_ids.device
    text_embeds = self.embed_tokens(input_ids)

    def _codes_on_device(codes):
        # Absent or zero-length code tensors become an empty
        # (num_codebooks, 0) placeholder so embedding can still be called.
        if codes is not None and codes.shape[-1] > 0:
            return codes.to(target_device)
        return torch.zeros((self.audio_num_codebooks, 0), device=target_device, dtype=torch.long)

    audio_in_embed = None
    if self.config["encode_audio_in_tokens"]:
        audio_in_embed = self._embed_audio_ids(_codes_on_device(audio_in_ids))

    audio_out_embed = self._embed_audio_ids(_codes_on_device(audio_out_ids))

    round_to = 1 if use_cache else 8
    left_padding = bool(use_cache or input_ids.shape[0] == 1)
    (
        inputs_embeds,
        attention_mask,
        _labels,
        position_ids,
        input_ids,
        _audio_in_mask,
        audio_in_discrete_codes_mask,
        audio_out_mask,
    ) = merge_input_ids_with_audio_features(
        audio_in_embed,
        audio_in_ids_start,
        audio_out_embed,
        audio_out_ids_start,
        self.audio_in_token_idx,
        self.audio_out_token_idx,
        text_embeds,
        input_ids,
        attention_mask,
        label_ids,
        pad_token_id=self.padding_idx,
        round_to=round_to,
        left_padding=left_padding,
    )

    merged_length = inputs_embeds.shape[1]

    # After merging, the sequence may have outgrown the selected cache
    # bucket; if so, pick a bucket that fits the merged length.
    if past_key_values_buckets is not None and merged_length > past_key_values.get_max_cache_shape():
        past_key_values, self.current_past_key_values_bucket = self._prepare_kv_cache(
            merged_length, None, past_key_values_buckets
        )

    if use_cache and past_key_values is None:
        past_key_values = DynamicCache()

    if cache_position is None:
        if past_key_values is not None:
            past_seen_tokens = past_key_values.get_seq_length()
        else:
            past_seen_tokens = 0
        cache_position = torch.arange(
            past_seen_tokens, past_seen_tokens + merged_length, device=inputs_embeds.device
        )
        if isinstance(past_key_values, StaticCache) and past_seen_tokens >= past_key_values.get_max_cache_shape():
            raise ValueError(
                f"The current sequence length ({past_seen_tokens}) exceeds "
                f"the maximum cache shape. "
                f"Please consider increasing the cache size."
            )

    # A static cache switches on the precomputed-mask path below.
    use_static_cache = isinstance(past_key_values, StaticCache)

    # Apply the LLM component
    causal_mask = self._update_causal_mask(
        attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
    )

    hidden_states = inputs_embeds

    audio_discrete_codes_mask = audio_in_discrete_codes_mask | audio_out_mask
    if cache_audio_discrete_codes_mask is not None and use_cache:
        # Prepend the mask of already-cached positions.
        audio_discrete_codes_mask = torch.concat(
            [cache_audio_discrete_codes_mask, audio_discrete_codes_mask], dim=1
        )

    # Generate the audio attention mask outside the layer to avoid recompilation
    if use_static_cache:
        fast_forward_attention_mask, audio_attention_mask = self._prepare_all_static_kv_cache_masks(
            hidden_states, causal_mask, audio_discrete_codes_mask, past_key_values
        )
        if hidden_states.shape[1] == 1:
            # Single-token step: keep only the mask entry of the last token.
            audio_discrete_codes_mask = audio_discrete_codes_mask[:, -1:]
            audio_discrete_codes_mask = audio_discrete_codes_mask.reshape((-1, 1)).contiguous()
            is_decoding_audio_token = audio_discrete_codes_mask.item()
        else:
            is_decoding_audio_token = False

    if (
        past_key_values is not None
        and is_using_cuda_graphs
        and past_key_values.get_max_cache_shape() in self.decode_graph_runners
        and (input_ids.shape[-1] == 1)
    ):
        # Runners are keyed by cache size, then by is_decoding_audio_token.
        core_fn = self.decode_graph_runners[past_key_values.get_max_cache_shape()][is_decoding_audio_token]
        local_cuda_graph = True
    else:
        core_fn = self._forward_core
        local_cuda_graph = False

    hidden_states = core_fn(
        hidden_states=hidden_states,
        causal_mask=causal_mask,
        position_ids=position_ids,
        audio_discrete_codes_mask=audio_discrete_codes_mask,
        is_decoding_audio_token=is_decoding_audio_token if use_static_cache else None,
        cache_position=cache_position,
        past_key_values=past_key_values,
        use_cache=use_cache,
        audio_attention_mask=audio_attention_mask if use_static_cache else None,
        fast_forward_attention_mask=fast_forward_attention_mask if use_static_cache else None,
        is_using_cuda_graph=local_cuda_graph,
    )
    hidden_states = self.norm(hidden_states)

    logits, audio_logits = self.audio_decoder_proj(
        hidden_states,
        audio_out_mask,
        label_audio_ids=label_audio_ids,
        attention_mask=causal_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        use_cache=use_cache,
        output_audio_hidden_states=output_audio_hidden_states,
        cache_position=cache_position,
    )

    if audio_logits is not None:
        # View as (rows, codebooks, codebook_size) and cast to float32.
        audio_logits = audio_logits.view(
            audio_logits.shape[0], self.audio_num_codebooks, self.audio_codebook_size
        ).float()

    return HiggsAudioModelOutputWithPast(
        logits=logits,
        audio_logits=audio_logits,
        past_key_values=past_key_values if use_cache else None,
        audio_out_mask=audio_out_mask,
        audio_in_discrete_codes_mask=audio_in_discrete_codes_mask,
    )
|
|
|
|
def _update_model_kwargs_for_generation(
|
|
self,
|
|
outputs,
|
|
model_kwargs,
|
|
):
|
|
model_kwargs["past_key_values"] = outputs.past_key_values
|
|
|
|
# update attention mask
|
|
if "attention_mask" in model_kwargs:
|
|
attention_mask = model_kwargs["attention_mask"]
|
|
model_kwargs["attention_mask"] = torch.cat(
|
|
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
|
|
)
|
|
if "cache_audio_discrete_codes_mask" in model_kwargs:
|
|
if model_kwargs["cache_audio_discrete_codes_mask"] is None:
|
|
model_kwargs["cache_audio_discrete_codes_mask"] = (
|
|
outputs.audio_in_discrete_codes_mask | outputs.audio_out_mask
|
|
)
|
|
else:
|
|
model_kwargs["cache_audio_discrete_codes_mask"] = torch.concat(
|
|
[
|
|
model_kwargs["cache_audio_discrete_codes_mask"],
|
|
outputs.audio_in_discrete_codes_mask | outputs.audio_out_mask,
|
|
],
|
|
1,
|
|
)
|
|
|
|
return model_kwargs
|
|
|
|
def _copy_kv_cache(self, from_cache: Cache, to_cache: Cache):
|
|
from_cache_size = from_cache.get_max_cache_shape()
|
|
assert to_cache.get_max_cache_shape() >= from_cache_size, (
|
|
f"The target cache size {to_cache.get_max_cache_shape()} is smaller than the source cache size {from_cache_size}."
|
|
)
|
|
|
|
n_layers = self.num_hidden_layers
|
|
|
|
for i in range(n_layers):
|
|
from_layer = from_cache.layers[i]
|
|
to_layer = to_cache.layers[i]
|
|
|
|
seq_len = from_cache_size
|
|
|
|
to_layer.lazy_init(from_layer.keys)
|
|
to_layer.keys[:, :, :seq_len, :] = from_layer.keys
|
|
to_layer.values[:, :, :seq_len, :] = from_layer.values
|
|
|
|
def _prepare_kv_cache(
|
|
self,
|
|
current_sequence_length: int,
|
|
current_past_key_values_bucket: Optional[int],
|
|
past_key_values_buckets: OrderedDict[int, Cache],
|
|
) -> Tuple[Optional[Cache], Optional[int]]:
|
|
|
|
for cache_length in past_key_values_buckets.keys():
|
|
if cache_length >= current_sequence_length:
|
|
|
|
if current_past_key_values_bucket is not None and cache_length != current_past_key_values_bucket:
|
|
self._copy_kv_cache(
|
|
past_key_values_buckets[current_past_key_values_bucket], past_key_values_buckets[cache_length]
|
|
)
|
|
|
|
return past_key_values_buckets[cache_length], cache_length
|
|
|
|
raise ValueError(
|
|
f"The current sequence length {current_sequence_length} is larger than "
|
|
f"all past key values buckets {past_key_values_buckets.keys()}."
|
|
)
|
|
|
|
def _sample(
    self,
    input_ids: torch.LongTensor,
    logits_processing_list,
    generation_config: GenerationConfig,
    past_key_values_buckets: Optional[OrderedDict[int, Cache]],
    **model_kwargs,
):
    """Autoregressive decoding loop that interleaves text and audio tokens.

    Each iteration decides the generation mode from the last token of the
    first row of ``input_ids``, runs a forward pass, samples either audio
    codes (``_sample_audio_tokens``) or a text token
    (``_sample_text_tokens``), appends the result to the running state and
    evaluates ``check_stopping_criteria``.

    ``torch_generator`` and ``pbar`` are popped from ``model_kwargs``; the
    remaining entries are forwarded to the model.

    Args:
        input_ids: Prompt token ids, unpacked as ``(batch_size, length)``.
        logits_processing_list: Passed through to the sampling helpers.
        generation_config: Supplies ``pad_token_id``, ``max_length``,
            ``use_cache``, ``do_sample``, ``is_using_cuda_graphs`` and
            ``generation_kwargs``.
        past_key_values_buckets: Optional mapping of bucket length to cache.

    Returns:
        ``audio_sequences``: a list of tensors, one per generated audio
        segment, each grown by one column per sampled audio step.

    Raises:
        ValueError: A sampled token id is ``>=`` the ``tokenizer_length``
            entry of ``generation_kwargs`` (when that entry is present).
    """
    # NOTE: only non-mixed batches are supported; the mode is read from row 0.
    audio_out_bos_token_id = generation_config.generation_kwargs.get("audio_out_bos_token_id", None)

    # Optional sampling generator and progress bar supplied by the caller.
    torch_generator = model_kwargs.pop("torch_generator", None)
    pbar = model_kwargs.pop("pbar", None)

    pad_token_id = generation_config.pad_token_id
    # Key of the KV-cache bucket currently in use (none yet).
    self.current_past_key_values_bucket = None
    max_length = generation_config.max_length
    device = input_ids.device

    batch_size, cur_len = input_ids.shape
    this_peer_finished = False
    unfinished_sequences = torch.ones(batch_size, dtype=torch.bool, device=device)
    if generation_config.use_cache:
        model_kwargs["cache_audio_discrete_codes_mask"] = None

    do_sample = generation_config.do_sample
    is_using_cuda_graphs = generation_config.is_using_cuda_graphs

    init_model_input = True
    num_delay = 0
    num_remaining_delays = None
    audio_sequences = []
    # Full token history, including every audio placeholder token.
    input_ids_full = input_ids.clone()

    # If the prompt already ends inside an audio segment, resume from it.
    if input_ids[0][-1] == self.config["audio_out_token_idx"]:
        prompt_audio = model_kwargs["audio_out_ids"]
        audio_sequences = [prompt_audio[:, model_kwargs["audio_out_ids_start"][-1] :]]
        if self.use_delay_pattern:
            last_frame = prompt_audio[:, -1]
            num_delay = (
                self.audio_num_codebooks
                - (last_frame == self.config["audio_stream_bos_id"]).sum()
            )
            all_eos_indices = (last_frame == self.config['audio_stream_eos_id']).nonzero()
            if torch.numel(all_eos_indices) > 0:
                all_eos_indices = all_eos_indices[0]
                last_eos_idx = all_eos_indices[-1]
                num_remaining_delays = self.audio_num_codebooks - last_eos_idx - 1

    while not this_peer_finished:
        # Text EOS by default; replaced by the audio EOS while in an audio segment.
        eos_token_id = self.config["text_config"]["eos_token_id"]

        if input_ids[0][-1] == audio_out_bos_token_id:
            generation_mode = GenerationMode.AUDIO_INIT
        elif input_ids[0][-1] == self.audio_out_token_idx:
            generation_mode = GenerationMode.AUDIO_IN_PROGRESS
            eos_token_id = self.config["audio_eos_token_id"]
        else:
            generation_mode = GenerationMode.TEXT

        eos_token_tensor = torch.tensor([eos_token_id], device=device)
        is_audio_generation_mode = generation_mode == GenerationMode.AUDIO_IN_PROGRESS

        if init_model_input or not generation_config.use_cache:
            # First pass, or no cache: feed the whole sequence.
            model_inputs = {"input_ids": input_ids, **model_kwargs}
        else:
            # Cached step: feed only the newest token.
            model_inputs = {"input_ids": input_ids[:, -1:], **model_kwargs}

            if is_audio_generation_mode and generation_config.use_cache:
                model_inputs["audio_out_ids"] = model_kwargs["audio_out_ids"][:, -1:]
                model_inputs["audio_out_ids_start"] = torch.tensor([0], dtype=torch.long, device=device)
            elif not is_audio_generation_mode:
                del model_inputs["audio_out_ids"]
                del model_inputs["audio_out_ids_start"]

            if generation_config.use_cache:
                # Audio-in codes are not re-fed on cached steps.
                if model_inputs.get("audio_in_ids") is not None:
                    model_inputs["audio_in_ids"] = None
                    model_inputs["audio_in_ids_start"] = None

        if past_key_values_buckets is not None:
            past_key_values, self.current_past_key_values_bucket = self._prepare_kv_cache(
                cur_len, self.current_past_key_values_bucket, past_key_values_buckets
            )
            if past_key_values is not None:
                model_inputs["past_key_values"] = past_key_values
            model_inputs["past_key_values_buckets"] = past_key_values_buckets

        outputs = self(**model_inputs, is_using_cuda_graphs=is_using_cuda_graphs, return_dict=True)

        # After the first pass the real length is read back from the cache.
        if init_model_input and past_key_values_buckets is not None:
            cur_len = past_key_values_buckets[self.current_past_key_values_bucket].get_seq_length().item()

        model_kwargs = self._update_model_kwargs_for_generation(
            outputs,
            model_kwargs,
        )

        init_model_input = False

        if this_peer_finished:
            continue

        if is_audio_generation_mode:
            (
                next_tokens,
                next_audio_tokens,
                num_delay,
                num_remaining_delays,
            ) = self._sample_audio_tokens(
                audio_logits=outputs.audio_logits,
                audio_out_ids=model_kwargs["audio_out_ids"],
                logits_processing_list=logits_processing_list,
                device=device,
                torch_generator=torch_generator,
                generation_config=generation_config,
                num_delay=num_delay,
                num_remaining_delays=num_remaining_delays,
                do_sample=do_sample,
                is_using_cuda_graphs=is_using_cuda_graphs,
            )

            # Append the new frame to both the model input and the open segment.
            new_frame = next_audio_tokens[:, None]
            model_kwargs["audio_out_ids"] = torch.cat([model_kwargs["audio_out_ids"], new_frame], dim=-1)
            audio_sequences[-1] = torch.cat([audio_sequences[-1], new_frame], dim=-1)

        else:
            next_tokens, next_audio_tokens = self._sample_text_tokens(
                input_ids=input_ids,
                logits=outputs.logits,
                logits_processing_list=logits_processing_list,
                device=device,
                generation_mode=generation_mode,
                torch_generator=torch_generator,
                do_sample=do_sample,
                is_using_cuda_graphs=is_using_cuda_graphs,
            )

            if next_audio_tokens is not None:
                # A new audio segment starts here.
                new_frame = next_audio_tokens[:, None]
                audio_sequences.append(new_frame)
                existing_audio = model_kwargs["audio_out_ids"]
                if existing_audio is None or existing_audio.shape[0] == 0:
                    model_kwargs["audio_out_ids"] = new_frame
                    model_kwargs["audio_out_ids_start"] = torch.tensor(
                        [0], dtype=torch.long, device=device
                    )
                else:
                    segment_start = torch.tensor(
                        [existing_audio.shape[1]], dtype=torch.long, device=device
                    )
                    model_kwargs["audio_out_ids_start"] = torch.concat(
                        [model_kwargs["audio_out_ids_start"], segment_start],
                        dim=0,
                    )
                    model_kwargs["audio_out_ids"] = torch.concat([existing_audio, new_frame], dim=1)

        # Rows that already finished receive the padding token instead.
        next_tokens = next_tokens * unfinished_sequences + pad_token_id * (~unfinished_sequences)

        if "tokenizer_length" in generation_config.generation_kwargs:
            tokenizer_length = generation_config.generation_kwargs["tokenizer_length"]
            if torch.max(next_tokens) >= tokenizer_length:
                raise ValueError(
                    f"Next generated token has max value {torch.max(next_tokens)} which is greater than the tokenizer's vocabulary size {tokenizer_length}, this is undesired behavior."
                )

        if pbar is not None:
            pbar.update(1)

        # While inside an audio segment, repeated audio placeholder tokens
        # are not appended to input_ids (they still go into input_ids_full).
        if not is_audio_generation_mode or next_tokens[0] != self.audio_out_token_idx:
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)

        input_ids_full = torch.cat([input_ids_full, next_tokens[:, None]], dim=-1)
        finished, unfinished_sequences = check_stopping_criteria(
            input_ids_full, max_length, eos_token=eos_token_tensor, stop_strings=self.stop_strings
        )
        this_peer_finished = finished.all()
        cur_len += 1

        del outputs
        torch.cuda.empty_cache()

    # Bring the progress bar to completion if generation stopped early.
    if pbar is not None and pbar.total != pbar.current:
        pbar.update(pbar.total - pbar.current)

    return audio_sequences
|
|
|
|
@torch.inference_mode()
def generate(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    audio_out_bos_token_id: int = None,
    audio_eos_token_id: int = None,
    generation_config = None,
    generation_functions = None,
    **kwargs,
):
    """Build the generation config used for sampling and return it.

    Args:
        input_ids: Prompt token ids; only ``shape[-1]`` is read here.
        audio_out_bos_token_id: Explicit audio-out BOS id. When ``None``,
            falls back to ``self.audio_out_bos_token_id`` if the model has
            that attribute, else ``None``.
        audio_eos_token_id: Explicit audio EOS id, with the same fallback
            to ``self.audio_eos_token_id``.
        generation_config: Existing config; a fresh ``GenerationConfig()``
            is created when ``None``.
        generation_functions: Object providing ``_prepare_generation_config``
            and ``_prepare_generated_length``.
        **kwargs: Forwarded to ``_prepare_generation_config``. From what it
            returns, ``ras_win_max_num_repeat`` (default ``2``) is popped and
            ``tokenizer`` (if present) is used only for its ``len()``.

    Returns:
        The prepared generation config.
    """
    if generation_config is None:
        generation_config = GenerationConfig()

    generation_config, kwargs = generation_functions._prepare_generation_config(generation_config, **kwargs)
    generation_kwargs = generation_config.generation_kwargs

    # An explicit argument wins; otherwise use the model attribute if it
    # exists. getattr with a default only absorbs a *missing* attribute,
    # unlike a bare ``except`` which would also hide unrelated errors
    # (including KeyboardInterrupt).
    explicit_token_ids = {
        "audio_out_bos_token_id": audio_out_bos_token_id,
        "audio_eos_token_id": audio_eos_token_id,
    }
    for name, token_id in explicit_token_ids.items():
        if token_id is None:
            token_id = getattr(self, name, None)
        generation_kwargs[name] = token_id

    generation_kwargs["ras_win_len"] = 7
    generation_kwargs["ras_win_max_num_repeat"] = kwargs.pop("ras_win_max_num_repeat", 2)

    if "tokenizer" in kwargs:
        generation_kwargs["tokenizer_length"] = len(kwargs["tokenizer"])

    generation_config = generation_functions._prepare_generated_length(
        generation_config=generation_config,
        input_ids_length=input_ids.shape[-1],
    )

    return generation_config
|