"""Higgs-Audio is an end-to-end multimodal model with the capability to understand and generate text / audio.""" import torch import torch.nn as nn from transformers.models.auto import CONFIG_MAPPING import math import glob import functools import os from collections import defaultdict, OrderedDict from dataclasses import dataclass from enum import Enum from safetensors.torch import load_file from typing import Optional, Tuple, Union, List, Dict, Any from transformers import AutoTokenizer from transformers.modeling_outputs import BaseModelOutput from transformers.models.whisper.modeling_whisper import WhisperEncoderLayer from transformers.modeling_attn_mask_utils import AttentionMaskConverter from transformers.cache_utils import Cache, DynamicCache, StaticCache from transformers.generation import GenerationMixin, GenerationConfig, LogitsProcessorList, StoppingCriteriaList from transformers.generation.utils import GenerateNonBeamOutput from transformers.utils import logging, ModelOutput from transformers.configuration_utils import PretrainedConfig from torch import nn def _ceil_to_nearest(n, round_to): return (n + round_to - 1) // round_to * round_to from transformers.modeling_utils import PreTrainedModel class HiggsAudioConfig(PretrainedConfig): model_type = "higgs_audio" is_composition = True def __init__( self, text_config=None, audio_encoder_config=None, audio_tokenizer_config=None, audio_adapter_type="stack", audio_embed_avg=False, audio_ffn_hidden_size=4096, audio_ffn_intermediate_size=14336, audio_dual_ffn_layers=None, audio_decoder_proj_num_layers=0, encode_whisper_embed=True, encode_audio_in_tokens=False, use_delay_pattern=False, skip_audio_tower=False, use_audio_out_embed_projector=False, use_audio_out_self_attention=False, use_rq_transformer=False, rq_transformer_hidden_size=None, rq_transformer_intermediate_size=None, rq_transformer_num_attention_heads=None, rq_transformer_num_key_value_heads=None, rq_transformer_num_hidden_layers=3, audio_num_codebooks=12, audio_codebook_size=1024, audio_stream_bos_id=1024, audio_stream_eos_id=1025, audio_bos_token="<|audio_bos|>", audio_eos_token="<|audio_eos|>", audio_out_bos_token="<|audio_out_bos|>", audio_in_token="<|AUDIO|>", audio_out_token="<|AUDIO_OUT|>", audio_in_token_idx=128015, audio_out_token_idx=128016, pad_token_id=128001, audio_out_bos_token_id=128013, audio_eos_token_id=128012, **kwargs, ): if isinstance(text_config, dict): text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama" text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) elif text_config is None: text_config = CONFIG_MAPPING["llama"]() assert audio_adapter_type in [ "stack", "dual_ffn", "dual_ffn_fast_forward", ], f"Invalid audio adapter type: {audio_adapter_type}" if audio_adapter_type.startswith("dual_ffn"): assert audio_dual_ffn_layers is not None, ( "audio_dual_ffn_layers must be specified when using dual_ffn adapter." 
) self.text_config = text_config self.audio_encoder_config = audio_encoder_config self.audio_tokenizer_config = audio_tokenizer_config self.audio_adapter_type = audio_adapter_type self.audio_embed_avg = audio_embed_avg self.audio_ffn_hidden_size = audio_ffn_hidden_size self.audio_ffn_intermediate_size = audio_ffn_intermediate_size self.audio_dual_ffn_layers = audio_dual_ffn_layers self.audio_decoder_proj_num_layers = audio_decoder_proj_num_layers self.encode_whisper_embed = encode_whisper_embed self.encode_audio_in_tokens = encode_audio_in_tokens self.use_delay_pattern = use_delay_pattern self.skip_audio_tower = skip_audio_tower self.use_audio_out_embed_projector = use_audio_out_embed_projector self.use_audio_out_self_attention = use_audio_out_self_attention self.use_rq_transformer = use_rq_transformer if self.use_rq_transformer: assert not self.use_delay_pattern, "Delay pattern is not supported if you turned on RQ-Transformer!" self.rq_transformer_hidden_size = rq_transformer_hidden_size self.rq_transformer_intermediate_size = rq_transformer_intermediate_size self.rq_transformer_num_attention_heads = rq_transformer_num_attention_heads self.rq_transformer_num_key_value_heads = rq_transformer_num_key_value_heads self.rq_transformer_num_hidden_layers = rq_transformer_num_hidden_layers if use_rq_transformer: # For RQ-Transformer, we set the hidden_size to the same as the text model's hidden size if it is not specified. if self.rq_transformer_hidden_size is None: self.rq_transformer_hidden_size = text_config.hidden_size assert self.rq_transformer_hidden_size % 128 == 0 if self.rq_transformer_intermediate_size is None: self.rq_transformer_intermediate_size = text_config.intermediate_size if self.rq_transformer_num_attention_heads is None: self.rq_transformer_num_attention_heads = self.rq_transformer_hidden_size // 128 if self.rq_transformer_num_key_value_heads is None: self.rq_transformer_num_key_value_heads = self.rq_transformer_hidden_size // 128 // 4 assert self.rq_transformer_hidden_size % self.rq_transformer_num_attention_heads == 0 assert self.rq_transformer_hidden_size % self.rq_transformer_num_key_value_heads == 0 self.audio_num_codebooks = audio_num_codebooks self.audio_codebook_size = audio_codebook_size self.audio_bos_token = audio_bos_token self.audio_eos_token = audio_eos_token self.audio_out_bos_token = audio_out_bos_token self.audio_in_token = audio_in_token self.audio_out_token = audio_out_token self.audio_in_token_idx = audio_in_token_idx self.audio_out_token_idx = audio_out_token_idx self.audio_stream_bos_id = audio_stream_bos_id self.audio_stream_eos_id = audio_stream_eos_id self.audio_out_bos_token_id = audio_out_bos_token_id self.audio_eos_token_id = audio_eos_token_id super().__init__(**kwargs) self.pad_token_id = pad_token_id class HiggsAudioPreTrainedModel(PreTrainedModel): config_class = HiggsAudioConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = [] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = True def merge_input_ids_with_audio_features( audio_in_embed, audio_in_ids_start, audio_out_embed, audio_out_ids_start, audio_in_token_idx, audio_out_token_idx, inputs_embeds, input_ids, attention_mask, label_ids, pad_token_id, ignore_index=-100, round_to=8, left_padding=True, ): def compute_audio_codes_length(ids_start, embed): return torch.concat([ ids_start[1:] - ids_start[:-1], torch.tensor([embed.shape[0] - ids_start[-1]], device=ids_start.device, dtype=torch.long), ], dim=0).long() 
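    # Illustrative sketch (for exposition only, not called anywhere): how the helper
    # above turns the `ids_start` offsets into per-placeholder code lengths. It relies
    # on the audio embeddings of all placeholders being concatenated along dim 0, so
    # consecutive start offsets (plus the total length) delimit each segment. The
    # tensor values below are made up.
    def _sketch_compute_audio_codes_length():
        ids_start = torch.tensor([0, 3, 7])  # three audio placeholders in the batch
        embed = torch.zeros(10, 8)  # 10 concatenated audio frames, hidden size 8
        lengths = torch.concat(
            [
                ids_start[1:] - ids_start[:-1],
                torch.tensor([embed.shape[0] - ids_start[-1]]),
            ],
            dim=0,
        ).long()
        # lengths == tensor([3, 4, 3]); placeholder i expands to
        # embed[ids_start[i] : ids_start[i] + lengths[i]] in the merged sequence.
        return lengths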
def fill_audio_embeddings(final_embedding, final_input_ids, final_labels, final_audio_mask, embed, token_idx, ids_start, codes_length, token_ends, batch_id, skip_labels, ignore_index): seq_indices = torch.arange(max_token_num, device=target_device).unsqueeze(0).expand(ids_start.shape[0], max_token_num) token_starts = token_ends - codes_length + 1 batch_indices, col_indices = torch.where((seq_indices >= token_starts.unsqueeze(1)) & (seq_indices <= token_ends.unsqueeze(1))) batch_indices = batch_id[batch_indices] final_embedding[batch_indices, col_indices] = embed final_input_ids[batch_indices, col_indices] = token_idx if not skip_labels: final_labels[batch_indices, col_indices] = ignore_index final_audio_mask[batch_indices, col_indices] = True skip_labels = label_ids is None if audio_in_embed is not None and audio_in_embed.shape[0] == 0: audio_in_embed = None if audio_out_embed is not None and audio_out_embed.shape[0] == 0: audio_out_embed = None batch_size, sequence_length, embed_dim = inputs_embeds.shape target_device = inputs_embeds.device if left_padding is None: left_padding = torch.any(attention_mask[:, 0] == 0) audio_in_token_mask, audio_out_token_mask = input_ids == audio_in_token_idx, input_ids == audio_out_token_idx text_token_mask = (input_ids != audio_in_token_idx) & (input_ids != audio_out_token_idx) token_placeholder_num = torch.ones_like(input_ids) if audio_in_embed is not None: audio_in_codes_length = compute_audio_codes_length(audio_in_ids_start, audio_in_embed) token_placeholder_num[audio_in_token_mask] = audio_in_codes_length if audio_out_embed is not None: audio_out_codes_length = compute_audio_codes_length(audio_out_ids_start, audio_out_embed) token_placeholder_num[audio_out_token_mask] = audio_out_codes_length new_token_positions = torch.cumsum(token_placeholder_num, -1) - 1 max_token_num = _ceil_to_nearest(token_placeholder_num.sum(-1).max(), round_to) nb_audio_pad = max_token_num - 1 - new_token_positions[:, -1] if left_padding: new_token_positions += nb_audio_pad[:, None] final_embedding = torch.zeros((batch_size, max_token_num, embed_dim), dtype=inputs_embeds.dtype, device=target_device) final_attention_mask = torch.zeros((batch_size, max_token_num), dtype=attention_mask.dtype, device=target_device) final_input_ids = torch.full((batch_size, max_token_num), pad_token_id, dtype=input_ids.dtype, device=target_device) final_labels = None if skip_labels else torch.full((batch_size, max_token_num), ignore_index, dtype=label_ids.dtype, device=target_device) final_audio_in_mask = torch.zeros((batch_size, max_token_num), dtype=torch.bool, device=target_device) final_audio_in_discrete_codes_mask = torch.zeros((batch_size, max_token_num), dtype=torch.bool, device=target_device) final_audio_out_mask = torch.zeros((batch_size, max_token_num), dtype=torch.bool, device=target_device) batch_id = torch.arange(batch_size, device=target_device).unsqueeze(1).expand(batch_size, sequence_length) audio_in_batch_id, audio_out_batch_id = batch_id[audio_in_token_mask], batch_id[audio_out_token_mask] audio_in_token_ends, audio_out_token_ends = new_token_positions[audio_in_token_mask], new_token_positions[audio_out_token_mask] if audio_in_embed is not None: fill_audio_embeddings(final_embedding, final_input_ids, final_labels, final_audio_in_mask, audio_in_embed, audio_in_token_idx, audio_in_ids_start, audio_in_codes_length, audio_in_token_ends, audio_in_batch_id, skip_labels, ignore_index) final_audio_in_discrete_codes_mask = final_audio_in_mask.clone() if audio_out_embed is not None: 
        fill_audio_embeddings(
            final_embedding,
            final_input_ids,
            final_labels,
            final_audio_out_mask,
            audio_out_embed,
            audio_out_token_idx,
            audio_out_ids_start,
            audio_out_codes_length,
            audio_out_token_ends,
            audio_out_batch_id,
            skip_labels,
            ignore_index,
        )

    batch_indices, text_indices = torch.where(text_token_mask)
    text_to_overwrite = new_token_positions[batch_indices, text_indices]
    final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, text_indices]
    if not skip_labels:
        final_labels[batch_indices, text_to_overwrite] = label_ids[batch_indices, text_indices]
    final_input_ids[batch_indices, text_to_overwrite] = input_ids[batch_indices, text_indices]
    final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, text_indices]
    final_attention_mask |= final_audio_in_mask | final_audio_out_mask

    if left_padding:
        first_non_zero_loc = (final_attention_mask.sum(0).nonzero()[0] // round_to) * round_to
        if first_non_zero_loc > 0:
            final_attention_mask = final_attention_mask[:, first_non_zero_loc:]
            final_embedding = final_embedding[:, first_non_zero_loc:]
            if not skip_labels:
                final_labels = final_labels[:, first_non_zero_loc:]
            final_input_ids = final_input_ids[:, first_non_zero_loc:]
            final_audio_in_mask = final_audio_in_mask[:, first_non_zero_loc:]
            final_audio_in_discrete_codes_mask = final_audio_in_discrete_codes_mask[:, first_non_zero_loc:]
            final_audio_out_mask = final_audio_out_mask[:, first_non_zero_loc:]
    else:
        last_non_zero_loc = ((final_attention_mask.sum(0).nonzero()[-1] + 1 + round_to - 1) // round_to) * round_to
        if last_non_zero_loc < max_token_num:
            final_attention_mask = final_attention_mask[:, :last_non_zero_loc]
            final_embedding = final_embedding[:, :last_non_zero_loc]
            if not skip_labels:
                final_labels = final_labels[:, :last_non_zero_loc]
            final_input_ids = final_input_ids[:, :last_non_zero_loc]
            final_audio_in_mask = final_audio_in_mask[:, :last_non_zero_loc]
            final_audio_in_discrete_codes_mask = final_audio_in_discrete_codes_mask[:, :last_non_zero_loc]
            final_audio_out_mask = final_audio_out_mask[:, :last_non_zero_loc]

    position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_(final_attention_mask == 0, 1)

    return (
        final_embedding,
        final_attention_mask,
        final_labels,
        position_ids,
        final_input_ids,
        final_audio_in_mask,
        final_audio_in_discrete_codes_mask,
        final_audio_out_mask,
    )


from torch.nn import RMSNorm
from cuda_graph_runner import CUDAGraphRunner


class HiggsAudioDecoderProjector(HiggsAudioPreTrainedModel):
    """Projection layers that map hidden states from the LLM component to audio / text logits.

    We support two types of audio heads:

    - Basic Audio Head: Directly maps the hidden states to audio logits for all the codebooks.
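# Minimal sketch (for exposition only, not used by the model) of what the "Basic Audio
# Head" described above computes: a single linear layer produces one flat logit vector
# per audio-out position, which is then viewed as per-codebook logits. The sizes below
# are made up; the "+ 2" accounts for the audio_stream_bos / audio_stream_eos entries
# of each codebook.
def _sketch_basic_audio_head():
    hidden_size, num_codebooks, codebook_size = 16, 4, 32
    audio_lm_head = nn.Linear(hidden_size, num_codebooks * (codebook_size + 2), bias=False)
    hidden = torch.randn(5, hidden_size)  # 5 audio-out positions
    flat_logits = audio_lm_head(hidden)  # (5, num_codebooks * (codebook_size + 2))
    per_codebook_logits = flat_logits.view(5, num_codebooks, codebook_size + 2)
    return per_codebook_logits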
""" def __init__(self, config: HiggsAudioConfig, layer_idx: Optional[int] = None): super().__init__(config) self.text_lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) self.audio_lm_head = nn.Linear( config.text_config.hidden_size, config.audio_num_codebooks * (config.audio_codebook_size + 2), bias=False ) # Initialize weights and apply final processing self.post_init() def forward( self, hidden_states, audio_out_mask, label_audio_ids=None, attention_mask=None, position_ids=None, past_key_values=None, use_cache=None, output_attentions=None, output_hidden_states=None, output_audio_hidden_states=False, cache_position=None, ): """ Args: hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_size)`): Hidden states from the LLM component audio_out_mask (`torch.Tensor` of shape `(batch_size, seq_len)`): Mask for identifying the audio out tokens. label_audio_ids (`torch.Tensor` of shape `(num_codebooks, num_audio_out_tokens)`): Label tokens for the audio-out part. This is used for calculating the logits if RQ-Transformer is used. attention_mask (`torch.Tensor` of shape `(batch_size, seq_len)`): Mask to avoid performing attention on padding token indices position_ids (`torch.Tensor` of shape `(batch_size, seq_len)`): Position ids for the input tokens Returns: logits (`torch.Tensor` of shape `(batch_size, seq_len, vocab_size)`): Logits for text tokens audio_logits (`torch.Tensor` of shape `(num_audio_out_tokens, audio_num_codebooks * audio_codebook_size)`): Logits for audio tokens. We ensure `num_text_tokens + num_audio_tokens == batch_size * seq_len` """ logits = self.text_lm_head(hidden_states) all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None next_decoder_cache = None if self.config.audio_decoder_proj_num_layers > 0: # create position embeddings to be shared across the decoder layers position_embeddings = self.rotary_emb(hidden_states, position_ids) for decoder_layer in self.transformer_layers: if output_hidden_states: all_hidden_states += (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, attention_mask, position_ids, past_key_values, output_attentions, use_cache, cache_position, position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, ) hidden_states = layer_outputs[0] hidden_states = self.norm(hidden_states) if output_hidden_states: all_hidden_states += (hidden_states,) if output_attentions: all_self_attns += (layer_outputs[1],) if use_cache: next_decoder_cache = layer_outputs[2 if output_attentions else 1] next_cache = next_decoder_cache if use_cache else None audio_logits = self.audio_lm_head(hidden_states[audio_out_mask]) if output_audio_hidden_states: audio_hidden_states = hidden_states[audio_out_mask] else: audio_hidden_states = None return logits, audio_logits logger = logging.get_logger(__name__) class GenerationMode(Enum): """Enum for different generation modes in HiggsAudio model.""" TEXT = 0 # Text generation mode AUDIO_INIT = 1 # Audio generation mode initialization AUDIO_IN_PROGRESS = 2 # Audio generation mode in progress def _whisper_encoder_zero_shape_forward(whisper_encoder, *args, **kwargs): """The whisper encoder does not 
    support zero-shape tensors by default due to the following implementation:

        key_states = self._shape(self.k_proj(current_states), -1, bsz)

    If `bsz` is 0, the "-1" dimension is ambiguous and triggers an error in the shape inference pass.

    See also: https://github.com/huggingface/transformers/blob/30335093276212ce74938bdfd85bfd5df31a668a/src/transformers/models/whisper/modeling_whisper.py#L306-L307

    This function monkey-patches the `_shape` functions in the whisper encoder's self-attention layers
    so that they support zero-shape tensors.

    #FIXME!!!! This is a temporary workaround and should be removed once the upstream issue is resolved.
    """
    global _higgs_flash_attention_forward

    def _patched_shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int):
        if seq_len == -1:
            return tensor.view(bsz, tensor.shape[1], num_heads, head_dim).transpose(1, 2).contiguous()
        else:
            return tensor.view(bsz, seq_len, num_heads, head_dim).transpose(1, 2).contiguous()

    def _patched_scaled_dot_product_attention(
        query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, enable_gqa=False
    ) -> torch.Tensor:
        # IMPORTANT! Implementation here is wrong and is only for the purpose of obtaining the correct attn_weight shape
        if enable_gqa:
            key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
            value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)
        attn_weight = query @ key.transpose(-2, -1)
        return attn_weight @ value

    # Apply monkey-patch
    if whisper_encoder.config._attn_implementation != "flash_attention_2":
        old_shape_functions = []
        for layer in whisper_encoder.layers:
            old_shape_functions.append(getattr(layer.self_attn, "_shape"))
            layer.self_attn._shape = functools.partial(
                _patched_shape, num_heads=layer.self_attn.num_heads, head_dim=layer.self_attn.head_dim
            )

    original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
    torch.nn.functional.scaled_dot_product_attention = _patched_scaled_dot_product_attention
    out = whisper_encoder(*args, **kwargs)
    torch.nn.functional.scaled_dot_product_attention = original_scaled_dot_product_attention

    # Restore the original shape functions
    if whisper_encoder.config._attn_implementation != "flash_attention_2":
        for layer, old_shape_function in zip(whisper_encoder.layers, old_shape_functions):
            layer.self_attn._shape = old_shape_function

    return out


def _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask: torch.Tensor,
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    device: torch.device,
    min_dtype: float,
    cache_position: torch.Tensor,
    batch_size: int,
):
    """
    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`, or, if the input `attention_mask` is already 4D, does nothing.

    Args:
        attention_mask (`torch.Tensor`):
            A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
            `(batch_size, 1, query_length, key_value_length)`.
        sequence_length (`int`):
            The sequence length being processed.
        target_length (`int`):
            The target length: when generating with static cache, the mask should be as long as the static cache,
            to account for the 0 padding, the part of the cache that is not filled yet.
        dtype (`torch.dtype`):
            The dtype to use for the 4D attention mask.
        device (`torch.device`):
            The device to place the 4D attention mask on.
        min_dtype (`float`):
            The minimum value representable with the dtype `dtype`.
cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. batch_size (`torch.Tensor`): Batch size. """ if attention_mask is not None and attention_mask.dim() == 4: # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. causal_mask = attention_mask else: causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit mask_length = attention_mask.shape[-1] padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype ) return causal_mask class HiggsAudioFeatureProjector(nn.Module): """Projector that maps audio features extracted by Whisper to hidden state of the text model.""" def __init__(self, config: HiggsAudioConfig): super().__init__() self.linear = nn.Linear(config.audio_encoder_config.d_model, config.text_config.hidden_size, bias=True) def forward(self, audio_features): hidden_states = self.linear(audio_features) return hidden_states def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def apply_rope(xq, xk, freqs_cis): cos = freqs_cis[0].unsqueeze(1) sin = freqs_cis[1].unsqueeze(1) q_embed = (xq * cos) + (rotate_half(xq) * sin) k_embed = (xk * cos) + (rotate_half(xk) * sin) return q_embed, k_embed, sin, cos from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS class LLama3RoPE(nn.Module): def __init__(self, config, device = None, dtype = None): super().__init__() if config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq def _dynamic_frequency_update(self, position_ids, device): seq_len = torch.max(position_ids) + 1 if seq_len > self.max_seq_len_cached: # growth inv_freq, self.attention_scaling = self.rope_init_fn( self.config, device, seq_len=seq_len, **self.rope_kwargs ) self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.max_seq_len_cached = seq_len if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) self.max_seq_len_cached = self.original_max_seq_len @torch.no_grad() def forward(self, x, position_ids): if "dynamic" in self.rope_type: self._dynamic_frequency_update(position_ids, device=x.device) inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with 
torch.autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() cos = cos * self.attention_scaling sin = sin * self.attention_scaling return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) """ batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) import torch.nn.functional as F class Attention(nn.Module): def __init__(self, config, layer_idx: int, device, dtype, **kwargs): super().__init__() self.config = config self.layer_idx = layer_idx self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout self.is_causal = True self.num_heads = config.num_attention_heads self.num_kv_heads = config.num_key_value_heads self.q_proj = nn.Linear( config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias ) self.k_proj = nn.Linear( config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias ) self.v_proj = nn.Linear( config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias ) self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) self.inner_size = config.num_attention_heads * self.head_dim def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[torch.Tensor] = None, optimized_attention=None, ): batch_size, seq_length, _ = hidden_states.shape xq = self.q_proj(hidden_states) xk = self.k_proj(hidden_states) xv = self.v_proj(hidden_states) xq = xq.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) xk = xk.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2) xv = xv.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2) xq, xk, sin, cos = apply_rope(xq, xk, freqs_cis=position_embeddings) if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} xk, xv = past_key_value.update(xk, xv, self.layer_idx, cache_kwargs) xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True) out = self.o_proj(output) return out, None, past_key_value class MLP(nn.Module): def __init__(self, config, device=None, dtype=None, ops: Any = nn): super().__init__() ops = ops or nn self.gate_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype) self.up_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, 
dtype=dtype) self.down_proj = ops.Linear(config.intermediate_size, config.hidden_size, bias=False, device=device, dtype=dtype) if config.mlp_activation == "silu": self.activation = torch.nn.functional.silu elif config.mlp_activation == "gelu_pytorch_tanh": self.activation = lambda a: torch.nn.functional.gelu(a, approximate="tanh") def forward(self, x): return self.down_proj(self.activation(self.gate_proj(x)) * self.up_proj(x)) SDP_BATCH_LIMIT = 2**15 def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, is_causal = False): if skip_reshape: b, _, _, dim_head = q.shape else: b, _, dim_head = q.shape dim_head //= heads q, k, v = map( lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2), (q, k, v), ) if mask is not None: # add a batch dimension if there isn't already one if mask.ndim == 2: mask = mask.unsqueeze(0) # add a heads dimension if there isn't already one if mask.ndim == 3: mask = mask.unsqueeze(1) SDP_BATCH_LIMIT = b if SDP_BATCH_LIMIT >= b: out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=is_causal) if not skip_output_reshape: out = ( out.transpose(1, 2).reshape(b, -1, heads * dim_head) ) else: out = torch.empty((b, q.shape[2], heads * dim_head), dtype=q.dtype, layout=q.layout, device=q.device) for i in range(0, b, SDP_BATCH_LIMIT): m = mask if mask is not None: if mask.shape[0] > 1: m = mask[i : i + SDP_BATCH_LIMIT] out[i : i + SDP_BATCH_LIMIT] = torch.nn.functional.scaled_dot_product_attention( q[i : i + SDP_BATCH_LIMIT], k[i : i + SDP_BATCH_LIMIT], v[i : i + SDP_BATCH_LIMIT], attn_mask=m, dropout_p=0.0, is_causal=False ).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head) return out from transformers.models.llama.modeling_llama import LlamaRMSNorm, LlamaMLP, LlamaRotaryEmbedding, LLAMA_ATTENTION_CLASSES class HiggsAudioDualFFNDecoderLayer(nn.Module): def __init__( self, config, layer_idx: int, fast_forward: bool = False, use_audio_attention: bool = False, device = None, dtype = None, ): super().__init__() text_config = config.text_config self.hidden_size = text_config.hidden_size self.layer_idx = layer_idx text_config.qkv_bias = text_config.mlp_bias text_config.mlp_activation = text_config.hidden_act self.self_attn = Attention(config=text_config, layer_idx=layer_idx, device = device, dtype = dtype) #self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=text_config, layer_idx=layer_idx) self.mlp = LlamaMLP(text_config) if not fast_forward: self.audio_mlp = LlamaMLP(text_config)#, device = device, dtype = dtype) self.audio_input_layernorm = LlamaRMSNorm(text_config.hidden_size, eps=text_config.rms_norm_eps)#, device = device, dtype = dtype) self.audio_post_attention_layernorm = LlamaRMSNorm(text_config.hidden_size, eps=text_config.rms_norm_eps)#, device = device, dtype = dtype) self.use_audio_attention = use_audio_attention self.fast_forward = fast_forward if self.fast_forward: assert not self.use_audio_attention, ( "We cannot use audio_attention if the layer is marked as fast-forward." 
) self.input_layernorm = LlamaRMSNorm(text_config.hidden_size, eps=text_config.rms_norm_eps)#, device = device, dtype = dtype) self.post_attention_layernorm = LlamaRMSNorm(text_config.hidden_size, eps=text_config.rms_norm_eps)#, device = device, dtype = dtype) self.text_config = text_config def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, fast_forward_attention_mask: Optional[torch.Tensor] = None, audio_out_mask: Optional[torch.BoolTensor] = None, is_decoding_audio_token: Optional[bool] = None, past_key_value: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, is_using_cuda_graph: Optional[bool] = False, position_embeddings = None, **kwargs, ): residual = hidden_states target_length = hidden_states.shape[1] use_static_cache = isinstance(past_key_value, StaticCache) decode_stage = hidden_states.shape[1] == 1 if is_using_cuda_graph: assert decode_stage and use_static_cache, ( "The CUDA graph mode should only be used in the decoding stage with static cache." ) # If we are decoding an audio token and the layer is marked as fast-forward, # we can skip it. if is_decoding_audio_token and self.fast_forward: return (hidden_states,) has_audio_out = audio_out_mask is not None and audio_out_mask.shape[0] > 0 audio_out_mask_sq = audio_out_mask # considering that hidden_state = audio + text, I chose to make small_input dynamic small_input = target_length <= 2048 optimized_attention = attention_pytorch#optimized_attention_for_device(hidden_states.device, small_input = small_input) if self.fast_forward and has_audio_out: original_hidden_states = hidden_states.clone() min_dtype = torch.finfo(hidden_states.dtype).min if attention_mask is None: attention_mask = ~audio_out_mask if optimized_attention.__name__ != "attention_flash": sequence_length = audio_out_mask.shape[1] attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( attention_mask=attention_mask, sequence_length=sequence_length, target_length=sequence_length, dtype=hidden_states.dtype, min_dtype=min_dtype, device=hidden_states.device, cache_position=cache_position, batch_size=hidden_states.shape[0], ) if use_cache: attention_mask = attention_mask[:, :, -target_length:, :] elif len(attention_mask.shape) == 2: # Attention mask has shape (batch_size, sequence_length) # We should be using flash attention 2 attention_mask = attention_mask * ~audio_out_mask elif len(attention_mask.shape) == 4: # When using static cache, the attention mask was already preprocessed in the previous layer if use_static_cache: attention_mask = fast_forward_attention_mask else: if use_cache: attention_mask = attention_mask.masked_fill( audio_out_mask[:, -target_length:].reshape(audio_out_mask.shape[0], 1, target_length, 1) | audio_out_mask.reshape(audio_out_mask.shape[0], 1, 1, audio_out_mask.shape[1]), min_dtype, ) else: attention_mask = attention_mask.masked_fill( audio_out_mask.reshape(audio_out_mask.shape[0], 1, audio_out_mask.shape[1], 1) | audio_out_mask.reshape(audio_out_mask.shape[0], 1, 1, audio_out_mask.shape[1]), min_dtype, ) else: raise NotImplementedError(f"Unsupported attention_mask format, attention_mask={attention_mask}") if ( optimized_attention.__name__ == "attention_pytorch" and attention_mask is not None and attention_mask.device.type == "cuda" ): attention_mask = attention_mask.mul(~torch.all(attention_mask == min_dtype, dim=-1, keepdim=True)) if has_audio_out and not self.fast_forward: # Apply separate layernorm layers for audio tokens 
and text tokens if use_cache: hidden_states = torch.where( audio_out_mask_sq[:, -target_length:].unsqueeze(-1), self.audio_input_layernorm(hidden_states), self.input_layernorm(hidden_states), ) else: hidden_states = torch.where( audio_out_mask_sq.unsqueeze(-1), self.audio_input_layernorm(hidden_states), self.input_layernorm(hidden_states), ) else: hidden_states = self.input_layernorm(hidden_states) # Text Attention #freqs_cis = precompute_freqs_cis(self.text_config.head_dim, hidden_states.shape[1], self.text_config.rope_theta, device = hidden_states.device) hidden_states, _, present_key_value = self.self_attn( hidden_states = hidden_states, attention_mask = attention_mask, position_embeddings = position_embeddings, past_key_value=past_key_value, optimized_attention = optimized_attention, cache_position=cache_position, ) hidden_states = residual + hidden_states residual = hidden_states if has_audio_out and not self.fast_forward: if use_cache: real_audio_out_mask = audio_out_mask_sq[:, -target_length:] else: real_audio_out_mask = audio_out_mask_sq # Make whole graph in decode stage if decode_stage and is_using_cuda_graph: assert is_decoding_audio_token is not None, ( "is_decoding_audio_token should be present in the decoding stage." ) if is_decoding_audio_token: hidden_states = self.audio_post_attention_layernorm(hidden_states) hidden_states = self.audio_mlp(hidden_states) else: hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) residual = residual + hidden_states else: text_hidden_states = self.post_attention_layernorm(hidden_states[~real_audio_out_mask]) audio_hidden_states = self.audio_post_attention_layernorm(hidden_states[real_audio_out_mask]) text_hidden_states = self.mlp(text_hidden_states) residual[~real_audio_out_mask] += text_hidden_states audio_hidden_states = self.audio_mlp(audio_hidden_states) residual[real_audio_out_mask] += audio_hidden_states hidden_states = residual else: hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states if self.fast_forward and has_audio_out: if use_cache: hidden_states = torch.where( audio_out_mask_sq[:, -target_length:].unsqueeze(-1), original_hidden_states, hidden_states ) else: hidden_states = torch.where(audio_out_mask_sq.unsqueeze(-1), original_hidden_states, hidden_states) outputs = (hidden_states,) if use_cache: outputs += (present_key_value,) return outputs @dataclass class HiggsAudioModelOutputWithPast(ModelOutput): loss: Optional[torch.FloatTensor] = None llm_loss: Optional[torch.FloatTensor] = None audio_loss: Optional[torch.FloatTensor] = None codebook_losses: Optional[torch.FloatTensor] = None logits: Optional[torch.FloatTensor] = None expanded_input_ids: Optional[torch.LongTensor] = None expanded_labels: Optional[torch.LongTensor] = None audio_in_mask: Optional[torch.BoolTensor] = None audio_in_discrete_codes_mask: Optional[torch.BoolTensor] = None audio_out_mask: Optional[torch.BoolTensor] = None attention_mask: Optional[torch.BoolTensor] = None audio_logits: Optional[torch.FloatTensor] = None past_key_values: Optional[Cache] = None hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None audio_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None attentions: Optional[Tuple[torch.FloatTensor, ...]] = None @dataclass class HiggsAudioGenerationOutput(ModelOutput): """ Outputs of HiggsAudio generation models, when using non-beam methods. 
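# Minimal sketch (for exposition only) of the routing idea used by the dual-FFN decoder
# layer above: all positions share the same self-attention, but positions flagged by
# `audio_out_mask` are normalized and projected by the audio FFN while the remaining
# positions use the text FFN, and both results are added back to the shared residual.
# The modules and shapes passed in are hypothetical stand-ins.
def _sketch_dual_ffn_routing(hidden_states, audio_out_mask, text_mlp, audio_mlp, text_norm, audio_norm):
    # hidden_states: (batch, seq_len, hidden); audio_out_mask: (batch, seq_len), bool
    residual = hidden_states.clone()
    residual[~audio_out_mask] += text_mlp(text_norm(hidden_states[~audio_out_mask]))
    residual[audio_out_mask] += audio_mlp(audio_norm(hidden_states[audio_out_mask]))
    return residual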
Args: sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`): The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches finished early due to the `eos_token_id`. audio_sequences (`tuple(torch.LongTensor)` *optional*): The generated discrete audio codes. These codes can be used to fill-in related locations of <|AUDIO_OUT|> at input sequences. scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`): Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token). If the generated token is a text token, the tensor will have shape `(batch_size, config.vocab_size)`. If the generated token is an audio token, the tensor will have shape `(config.audio_num_codebooks, self.audio_codebook_size)` logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`): Unprocessed prediction scores of the language modeling head or the audio head (scores for each vocabulary token before SoftMax) at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token). If the generated token is a text token, the tensor will have shape `(batch_size, config.vocab_size)`. If the generated token is an audio token, the tensor will have shape `(config.audio_num_codebooks, self.audio_codebook_size)` attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`. past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True`): Returns the model cache, used to speed up decoding. Different models have a different cache format, check the model's documentation. Usually, a [`~cache_utils.Cache`] instance. """ sequences: torch.LongTensor = None audio_sequences: Optional[List[torch.LongTensor]] = None scores: Optional[Tuple[torch.FloatTensor]] = None logits: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None class HiggsAudioModel(HiggsAudioPreTrainedModel, GenerationMixin): """Higgs-Audio is an end-to-end multimodal model with the capability to understand and generate text / audio. Consider the following example for mixed text/audio understanding / generation: - input_tokens: <|audio_bos|>[AUDIO]<|audio_eos|><|audio_bos|>[AUDIO]<|audio_eos|> - input_tokens: <|audio_bos|>[AUDIO]<|audio_eos|><|audio_out_bos|>[AUDIO_OUT]<|audio_eos|> We will fill [AUDIO] with the audio features extracted by Whisper and fill [AUDIO_OUT] with the audio tokens. 
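# Illustrative sketch (for exposition only) of one common way a delayed codebook layout
# like the one enabled by `use_delay_pattern` can be constructed: codebook k is shifted
# right by k steps, with the leading slots filled by the stream-BOS id and the trailing
# slots by the stream-EOS id. This is a generic sketch, not necessarily the exact scheme
# used by this model; the real ids come from `audio_stream_bos_id` / `audio_stream_eos_id`.
def _sketch_delay_pattern(codes, bos_id, eos_id):
    # codes: (num_codebooks, length) un-delayed codebook grid
    num_codebooks, length = codes.shape
    delayed = torch.full((num_codebooks, length + num_codebooks), eos_id, dtype=codes.dtype)
    for k in range(num_codebooks):
        delayed[k, :k] = bos_id
        delayed[k, k : k + length] = codes[k]
    return delayed  # (num_codebooks, length + num_codebooks)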
Consider the following example for mixed text/audio generation: text: <|audio_out_bos|> MASK MASK MASK MASK MASK <|audio_eos|> [text_token1] audio: MASK <|audio_stream_bos|> [audio_token1] [audio_token2] [audio_token3] <|audio_stream_eos|> MASK MASK token_type: 0 1 1 1 1 1 0 0 """ _supports_cache_class = True _supports_static_cache = True def __init__(self, config: HiggsAudioConfig): super().__init__(config) self.padding_idx = config.pad_token_id self.audio_in_token_idx = config.audio_in_token_idx self.audio_out_token_idx = config.audio_out_token_idx self.audio_out_bos_token_id = config.audio_out_bos_token_id if "audio_out_bos_token_id" in config else None self.audio_eos_token_id = config.audio_eos_token_id if "audio_eos_token_id" in config else None self.vocab_size = config.text_config.vocab_size self.audio_num_codebooks = config.audio_num_codebooks self.use_delay_pattern = config.use_delay_pattern self.use_audio_out_embed_projector = config.use_audio_out_embed_projector self.use_audio_out_self_attention = config.use_audio_out_self_attention self.embed_tokens = nn.Embedding(self.vocab_size, config.text_config.hidden_size, self.padding_idx) if config.audio_adapter_type == "dual_ffn": layer_idx = 0 layers = [] for j in range(config.text_config.num_hidden_layers): if j in config.audio_dual_ffn_layers: layers.append( HiggsAudioDualFFNDecoderLayer( config, layer_idx, use_audio_attention=self.use_audio_out_self_attention ) ) layer_idx += 2 if self.use_audio_out_self_attention else 1 else: layers.append(LlamaDecoderLayer(config.text_config, layer_idx)) layer_idx += 1 self.layers = nn.ModuleList(layers) elif config.audio_adapter_type == "dual_ffn_fast_forward": layer_idx = 0 layers = [] for j in range(config.text_config.num_hidden_layers): if j in config.audio_dual_ffn_layers: layers.append( HiggsAudioDualFFNDecoderLayer( config, layer_idx, fast_forward=False, use_audio_attention=self.use_audio_out_self_attention, ) ) layer_idx += 2 if self.use_audio_out_self_attention else 1 else: layers.append( HiggsAudioDualFFNDecoderLayer(config, layer_idx, fast_forward=True, use_audio_attention=False) ) layer_idx += 1 self.layers = nn.ModuleList(layers) elif config.audio_adapter_type == "stack": self.layers = nn.ModuleList( [ LlamaDecoderLayer(config.text_config, layer_idx) for layer_idx in range(config.text_config.num_hidden_layers) ] ) layer_idx = config.text_config.num_hidden_layers else: raise NotImplementedError(f"Audio adapter type {config.audio_adapter_type} not implemented.") self.num_activation_checkpointing_layers = len(self.layers) self.decode_graph_runners = defaultdict(dict[bool, CUDAGraphRunner]) self.norm = RMSNorm(config.text_config.hidden_size, eps=config.text_config.rms_norm_eps) self.rotary_emb = LLama3RoPE(config = config.text_config)#LlamaRotaryEmbedding(config=config.text_config) if not config.skip_audio_tower: self.audio_tower = HiggsAudioEncoder(config.audio_encoder_config) self.audio_encoder_proj = HiggsAudioFeatureProjector(config) else: self.audio_tower = None self.audio_encoder_proj = None self.audio_decoder_proj = HiggsAudioDecoderProjector(config, layer_idx=layer_idx) self.audio_codebook_size = ( config.audio_codebook_size + 2 ) # We add 1 for the audio_stream_bos token and 1 for the audio_stream_eos token if config.use_audio_out_embed_projector: self.audio_out_embed_projector = nn.Linear( config.text_config.hidden_size, config.text_config.hidden_size, bias=False ) self.audio_codebook_embeddings = nn.Embedding( config.audio_num_codebooks * self.audio_codebook_size, 
config.text_config.hidden_size ) self.audio_codebook_weights = ( torch.ones(config.audio_num_codebooks) / config.audio_num_codebooks ) # default to equal weights self.post_init() def set_num_activation_checkpointing_layers(self, num_layers): self.num_activation_checkpointing_layers = num_layers def set_delay_pattern(self): self.config.use_delay_pattern = True self.use_delay_pattern = True def set_audio_special_tokens(self, tokenizer: AutoTokenizer): self.audio_out_bos_token_id = tokenizer.convert_tokens_to_ids("<|audio_out_bos|>") self.audio_eos_token_id = tokenizer.convert_tokens_to_ids("<|audio_eos|>") def _embed_audio_ids(self, audio_ids): """Embed the audio ids Args: audio_ids: torch.LongTensor of shape (num_codebooks, audio_in_total_length) Returns: audio_embed: torch.LongTensor of shape (audio_in_total_length, hidden_size) """ codebook_shift = ( torch.arange(self.config.audio_num_codebooks, device=audio_ids.device) * self.audio_codebook_size ) audio_embed = self.audio_codebook_embeddings(audio_ids + codebook_shift.unsqueeze(-1)) if self.config.audio_embed_avg: audio_embed = torch.mean(audio_embed, dim=0) else: audio_embed = torch.sum(audio_embed, dim=0) if self.use_audio_out_embed_projector: audio_embed = self.audio_out_embed_projector(audio_embed) return audio_embed def _apply_audio_tower(self, audio_features, audio_feature_attention_mask): """Apply the audio tower to the audio features""" if audio_features.shape[0] == 0: if torch.is_grad_enabled(): # FIXME!!!!!!!! # This is a hack to ensure that the forward+backward pass of audio_tower and audio_encoder_proj get triggered. # The monkey patch won't overwrite the backward pass of nn.Module. audio_outputs = _whisper_encoder_zero_shape_forward( self.audio_tower, audio_features, attention_mask=None, check_seq_length=False ) selected_audio_feature = audio_outputs.last_hidden_state audio_features_embed = self.audio_encoder_proj(selected_audio_feature) audio_feat_out_lengths = None return audio_features_embed, audio_feat_out_lengths else: return None, None audio_feat_lengths, audio_feat_out_lengths = self.audio_tower._get_feat_extract_output_lengths( audio_feature_attention_mask.sum(-1) ) batch_size, _, max_mel_seq_len = audio_features.shape max_seq_len = (max_mel_seq_len - 1) // 2 + 1 # Create a sequence tensor of shape (batch_size, max_seq_len) seq_range = ( torch.arange(0, max_seq_len, dtype=audio_feat_lengths.dtype, device=audio_feat_lengths.device) .unsqueeze(0) .expand(batch_size, max_seq_len) ) lengths_expand = audio_feat_lengths.unsqueeze(1).expand(batch_size, max_seq_len) # Create mask padding_mask = seq_range < lengths_expand if self.config._attn_implementation != "flash_attention_2": audio_attention_mask = padding_mask.view(batch_size, 1, 1, max_seq_len).expand( batch_size, 1, max_seq_len, max_seq_len ) else: audio_attention_mask = padding_mask audio_outputs = self.audio_tower(audio_features, attention_mask=audio_attention_mask) selected_audio_feature = audio_outputs.last_hidden_state audio_features_embed = self.audio_encoder_proj(selected_audio_feature) return audio_features_embed, audio_feat_out_lengths def _update_causal_mask( self, attention_mask: torch.Tensor, input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, output_attentions: bool, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # 
order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens, is_training=self.training, ): return None dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_length() else: target_length = ( attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1 ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, dtype=dtype, device=device, min_dtype=min_dtype, cache_position=cache_position, batch_size=input_tensor.shape[0], ) if ( self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type == "cuda" and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask def _prepare_all_static_kv_cache_masks(self, hidden_states, attention_mask, audio_out_mask, past_key_values): target_length = hidden_states.shape[1] cur_pos = audio_out_mask.shape[1] min_dtype = torch.finfo(hidden_states.dtype).min assert len(attention_mask.shape) == 4, "Only support SDPA for now" kv_cache_len = past_key_values.get_max_cache_shape() audio_out_mask_padded = torch.nn.functional.pad(audio_out_mask, (0, kv_cache_len - cur_pos), value=True) fast_forward_attention_mask = attention_mask.masked_fill( audio_out_mask_padded[:, audio_out_mask.shape[1] - target_length : audio_out_mask.shape[1]].reshape( audio_out_mask_padded.shape[0], 1, target_length, 1 ) | audio_out_mask_padded.reshape(audio_out_mask_padded.shape[0], 1, 1, audio_out_mask_padded.shape[1]), min_dtype, ) no_audio_out_mask = ~audio_out_mask no_audio_out_mask = torch.nn.functional.pad( no_audio_out_mask, (0, kv_cache_len - audio_out_mask.shape[1]), value=False ) no_audio_out_mask = no_audio_out_mask[ :, audio_out_mask.shape[1] - target_length : audio_out_mask.shape[1] ].reshape(audio_out_mask.shape[0], 1, target_length, 1) | no_audio_out_mask.reshape( audio_out_mask.shape[0], 1, 1, kv_cache_len ) audio_attention_mask = attention_mask.masked_fill(no_audio_out_mask, min_dtype) return fast_forward_attention_mask, audio_attention_mask def _forward_core( self, hidden_states: torch.Tensor, causal_mask: torch.Tensor, position_ids: torch.Tensor, audio_discrete_codes_mask: torch.Tensor, cache_position: torch.Tensor, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]], use_cache: bool, audio_attention_mask: torch.Tensor, fast_forward_attention_mask: torch.Tensor, 
output_attentions: bool, output_hidden_states: bool, is_decoding_audio_token: Optional[bool] = None, is_using_cuda_graph: Optional[bool] = False, ): # create position embeddings to be shared across the decoder layers # When past_key_values is passed in, we need to offset the position ids when calculating the position embeddings. # Therefore, cache_position is used. position_id_offset = cache_position[0] if use_cache else 0 position_embeddings = self.rotary_emb(hidden_states, position_ids + position_id_offset) # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None for decoder_layer in self.layers: if output_hidden_states: all_hidden_states += (hidden_states,) if isinstance(decoder_layer, HiggsAudioDualFFNDecoderLayer): layer_outputs = decoder_layer( hidden_states, attention_mask=causal_mask, audio_attention_mask=audio_attention_mask, fast_forward_attention_mask=fast_forward_attention_mask, position_ids=position_ids, audio_out_mask=audio_discrete_codes_mask, is_decoding_audio_token=is_decoding_audio_token, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, is_using_cuda_graph=is_using_cuda_graph, ) else: layer_outputs = decoder_layer( hidden_states, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, ) hidden_states = layer_outputs[0] if output_attentions: all_self_attns += (layer_outputs[1],) return hidden_states, all_hidden_states, all_self_attns def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.BoolTensor] = None, audio_in_ids: Optional[torch.LongTensor] = None, audio_in_ids_start: Optional[torch.LongTensor] = None, audio_out_ids: Optional[torch.LongTensor] = None, audio_out_ids_start: Optional[torch.LongTensor] = None, label_ids: Optional[torch.LongTensor] = None, label_audio_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_audio_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, cache_audio_discrete_codes_mask: Optional[torch.LongTensor] = None, past_key_values_buckets: Optional[OrderedDict[int, Cache]] = None, reward = None, audio_features = None, **kwargs #**kwargs ): target_device = input_ids.device inputs_embeds = self.embed_tokens(input_ids) if self.config.encode_audio_in_tokens: if audio_in_ids is not None and audio_in_ids.shape[-1] > 0: audio_in_ids = audio_in_ids.to(target_device) else: audio_in_ids = torch.zeros((self.audio_num_codebooks, 0), device=target_device, dtype=torch.long) audio_in_embed = self._embed_audio_ids(audio_in_ids) else: audio_in_embed = None if audio_out_ids is not None and audio_out_ids.shape[-1] > 0: audio_out_ids = audio_out_ids.to(target_device) else: audio_out_ids = torch.zeros((self.audio_num_codebooks, 0), device=target_device, dtype=torch.long) audio_out_embed = self._embed_audio_ids(audio_out_ids) round_to = 1 if use_cache else 8 left_padding = True if use_cache or input_ids.shape[0] == 1 else False ( inputs_embeds, attention_mask, labels, position_ids, input_ids, audio_in_mask, audio_in_discrete_codes_mask, audio_out_mask, ) = 
merge_input_ids_with_audio_features( audio_in_embed, audio_in_ids_start, audio_out_embed, audio_out_ids_start, self.audio_in_token_idx, self.audio_out_token_idx, inputs_embeds, input_ids, attention_mask, label_ids, pad_token_id=self.padding_idx, round_to=round_to, left_padding=left_padding, ) # re-check if we use the correct kv cache bucket after # the input_embeds has been merged with audio features if past_key_values_buckets is not None and inputs_embeds.shape[1] > past_key_values.get_max_cache_shape(): past_key_values, self.current_past_key_values_bucket = self._prepare_kv_cache( inputs_embeds.shape[1], None, past_key_values_buckets ) if use_cache and past_key_values is None: past_key_values = DynamicCache() if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) if isinstance(past_key_values, StaticCache) and past_seen_tokens >= past_key_values.get_max_cache_shape(): raise ValueError( f"The current sequence length ({past_seen_tokens}) exceeds " f"the maximum cache shape. " f"Please consider increasing the cache size." ) # Use torch compile use_static_cache = isinstance(past_key_values, StaticCache) # Apply the LLM component causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) hidden_states = inputs_embeds audio_discrete_codes_mask = audio_in_discrete_codes_mask | audio_out_mask if cache_audio_discrete_codes_mask is not None and use_cache: audio_discrete_codes_mask = torch.concat( [cache_audio_discrete_codes_mask, audio_discrete_codes_mask], dim=1 ) # Generate the audio attention mask outside the layer to avoid recompilation if use_static_cache: fast_forward_attention_mask, audio_attention_mask = self._prepare_all_static_kv_cache_masks( hidden_states, causal_mask, audio_discrete_codes_mask, past_key_values ) # Set the audio out mask to the last token if hidden_states.shape[1] == 1: audio_discrete_codes_mask = audio_discrete_codes_mask[:, -1:] audio_discrete_codes_mask = audio_discrete_codes_mask.reshape((-1, 1)).contiguous() is_decoding_audio_token = audio_discrete_codes_mask.item() else: is_decoding_audio_token = False if ( past_key_values is not None and past_key_values.get_max_cache_shape() in self.decode_graph_runners and (input_ids.shape[-1] == 1) ): _forward_core = self.decode_graph_runners[past_key_values.get_max_cache_shape()][is_decoding_audio_token] is_using_cuda_graph = True else: _forward_core = self._forward_core is_using_cuda_graph = False hidden_states = _forward_core( hidden_states=hidden_states, causal_mask=causal_mask, position_ids=position_ids, audio_discrete_codes_mask=audio_discrete_codes_mask, is_decoding_audio_token=is_decoding_audio_token if use_static_cache else None, cache_position=cache_position, past_key_values=past_key_values, use_cache=use_cache, audio_attention_mask=audio_attention_mask if use_static_cache else None, fast_forward_attention_mask=fast_forward_attention_mask if use_static_cache else None, is_using_cuda_graph=is_using_cuda_graph, output_hidden_states = False, output_attentions = False ) #print(hidden_states) hidden_states = self.norm(hidden_states[0]) # Apply the audio decoder projector logits, audio_logits = ( self.audio_decoder_proj( hidden_states, audio_out_mask, label_audio_ids=label_audio_ids, attention_mask=causal_mask, position_ids=position_ids, past_key_values=past_key_values, 
use_cache=use_cache, output_audio_hidden_states=output_audio_hidden_states, cache_position=cache_position, ) ) if audio_logits is not None: audio_logits = audio_logits.view( audio_logits.shape[0], self.audio_num_codebooks, self.audio_codebook_size ).float() next_cache = past_key_values if use_cache else None ret = HiggsAudioModelOutputWithPast( logits=logits, audio_logits=audio_logits, past_key_values=next_cache, audio_out_mask = audio_out_mask, audio_in_discrete_codes_mask = audio_in_discrete_codes_mask ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if not return_dict: outputs = ret.to_tuple() return outputs return ret # Overwrite GenerationMixin._update_model_kwargs_for_generation def _update_model_kwargs_for_generation( self, outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False, num_new_tokens: int = 1, extend_attention_mask: bool = True, ) -> Dict[str, Any]: """Update the model kwargs for each step.""" model_kwargs["past_key_values"] = outputs.past_key_values # update attention mask if "attention_mask" in model_kwargs: attention_mask = model_kwargs["attention_mask"] if extend_attention_mask: model_kwargs["attention_mask"] = torch.cat( [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 ) if "cache_audio_discrete_codes_mask" in model_kwargs: if model_kwargs["cache_audio_discrete_codes_mask"] is None: model_kwargs["cache_audio_discrete_codes_mask"] = ( outputs.audio_in_discrete_codes_mask | outputs.audio_out_mask ) else: model_kwargs["cache_audio_discrete_codes_mask"] = torch.concat( [ model_kwargs["cache_audio_discrete_codes_mask"], outputs.audio_in_discrete_codes_mask | outputs.audio_out_mask, ], 1, ) return model_kwargs def _copy_kv_cache(self, from_cache: Cache, to_cache: Cache): num_layers = self.config.text_config.num_hidden_layers if self.config.audio_dual_ffn_layers is not None: num_layers += len(self.config.audio_dual_ffn_layers) """ Copy the key-value pairs from one cache to another. """ for layer_idx in range(num_layers): from_cache_size = from_cache.get_max_cache_shape() assert to_cache.get_max_cache_shape() >= from_cache_size, ( f"The target cache size {to_cache.get_max_cache_shape()} is smaller than the source cache size {from_cache_size}." ) to_cache.key_cache[layer_idx][:, :, :from_cache_size, :] = from_cache.key_cache[layer_idx] to_cache.value_cache[layer_idx][:, :, :from_cache_size, :] = from_cache.value_cache[layer_idx] def _prepare_kv_cache( self, current_sequence_length: int, current_past_key_values_bucket: Optional[int], past_key_values_buckets: OrderedDict[int, Cache], ) -> Tuple[Optional[Cache], Optional[int]]: """Prepare the KV cache for the current sequence length.""" for cache_length in past_key_values_buckets.keys(): if cache_length >= current_sequence_length: # Promote to the next KV cache bucket, copy the current KV cache bucket # to the new one. if current_past_key_values_bucket is not None and cache_length != current_past_key_values_bucket: self._copy_kv_cache( past_key_values_buckets[current_past_key_values_bucket], past_key_values_buckets[cache_length] ) return past_key_values_buckets[cache_length], cache_length raise ValueError( f"The current sequence length {current_sequence_length} is larger than " f"all past key values buckets {past_key_values_buckets.keys()}." 
) def _sample_audio_tokens( self, hidden_states: torch.Tensor, audio_logits: torch.Tensor, audio_out_ids: torch.Tensor, do_sample: bool, logits_processor: LogitsProcessorList, device: torch.device, torch_generator: Optional[torch.Generator], generation_config: GenerationConfig, num_delay: int, num_remaining_delays: Optional[int], ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int, Optional[int]]: """Sample audio tokens and its corresponding text tokens from the logits""" # parameters related to repetition aware sampling ras_win_len = generation_config.generation_kwargs.get("ras_win_len", None) ras_win_max_num_repeat = generation_config.generation_kwargs.get("ras_win_max_num_repeat", 2) audio_eos_token_id = generation_config.generation_kwargs.get("audio_eos_token_id", None) # In the audio generation mode, we sample from audio_logits and keep updating audio_out_ids. next_audio_token_logits = audio_logits.clone()[-1, :, :].float().to(device) # TopP, TopK logits processor supports empty input_ids next_audio_token_scores = logits_processor(None, next_audio_token_logits) # token selection if do_sample: # next_audio_token_scores has been applied top_p, top_k, and temperature. probs = nn.functional.softmax(next_audio_token_scores, dim=-1) # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution next_audio_tokens = torch.multinomial(probs, num_samples=1, generator=torch_generator).squeeze(1) else: next_audio_tokens = torch.argmax(next_audio_token_scores, dim=-1) # next_tokens: (num_codebooks, ) if ras_win_len is not None: # check if there are repetitions over a window of tokens. rep_num = (audio_out_ids[:, -ras_win_len:] == next_audio_tokens.unsqueeze(1)).sum(dim=1) # if we saw repeated tokens in the most recent window of tokens, resample without temperature. row_indices = torch.nonzero(rep_num >= ras_win_max_num_repeat).squeeze(1) resampled_next_tokens = ( next_audio_token_logits[row_indices] .softmax(dim=-1) .multinomial(1, replacement=True, generator=torch_generator) .squeeze(1) ) next_audio_tokens[row_indices] = resampled_next_tokens # Force the next text tokens to be <|AUDIO_OUT|> in audio generation mode next_tokens = torch.full( (audio_logits.shape[0],), self.config.audio_out_token_idx, dtype=torch.long, device=device, ) # Handle delay_pattern if self.use_delay_pattern: if num_delay + 1 < next_audio_tokens.shape[0]: next_audio_tokens[(num_delay + 1) :] = self.config.audio_stream_bos_id num_delay += 1 if num_remaining_delays is not None: next_audio_tokens[: (self.audio_num_codebooks - num_remaining_delays)] = ( self.config.audio_stream_eos_id ) num_remaining_delays -= 1 else: all_eos_indices = (next_audio_tokens == self.config.audio_stream_eos_id).nonzero() if torch.numel(all_eos_indices) > 0: all_eos_indices = all_eos_indices[0] last_eos_idx = all_eos_indices[-1] next_audio_tokens[:last_eos_idx] = self.config.audio_stream_eos_id num_remaining_delays = self.audio_num_codebooks - last_eos_idx - 1 if num_remaining_delays is not None and num_remaining_delays <= 0: next_tokens[...] 
= audio_eos_token_id num_delay = 0 num_remaining_delays = None return ( next_tokens, next_audio_tokens, next_audio_token_logits, next_audio_token_scores, num_delay, num_remaining_delays, ) def _sample_text_tokens( self, logits: torch.Tensor, input_ids: torch.Tensor, do_sample: bool, logits_processor: LogitsProcessorList, device: torch.device, generation_mode: GenerationMode, torch_generator: Optional[torch.Generator], ) -> torch.Tensor: """Sample text tokens from the logits""" # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration # (the clone itself is always small) next_token_logits = logits.clone()[:, -1, :].float() next_token_logits = next_token_logits.to(input_ids.device) # pre-process distribution next_token_scores = logits_processor(input_ids, next_token_logits) if generation_mode == GenerationMode.AUDIO_INIT: # See the audio bos token, we should start generating audio tokens next_tokens = torch.full( (input_ids.shape[0],), self.audio_out_token_idx, dtype=torch.long, device=device, ) next_audio_tokens = torch.full( (self.config.audio_num_codebooks,), self.config.audio_stream_bos_id, dtype=torch.long, device=device, ) else: if do_sample: probs = nn.functional.softmax(next_token_scores, dim=-1) # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution next_tokens = torch.multinomial(probs, num_samples=1, generator=torch_generator).squeeze(1) else: next_tokens = torch.argmax(next_token_scores, dim=-1) next_audio_tokens = None return next_tokens, next_audio_tokens, next_token_logits, next_token_scores # Built on top of GenerationMixin._sample. # We revise the implementation to support generating both audio / text. def _sample( self, input_ids: torch.LongTensor, logits_processor: LogitsProcessorList, stopping_criteria: StoppingCriteriaList, generation_config: GenerationConfig, synced_gpus: bool, streamer: Optional["BaseStreamer"], past_key_values_buckets: Optional[OrderedDict[int, Cache]], **model_kwargs, ) -> Union[GenerateNonBeamOutput, torch.LongTensor]: r""" Generates sequences of token ids for joint text/audio models using **multinomial sampling**. This function may also be revised to support generating samples from HiggsAudio-like end-to-end text/audio models built on top of LLMs. If the input_ids ends with <|audio_out_bos|>, we will switch to the audio-generation mode. ``` ...<|start_header_id|>assistant<|end_header_id|>\n\n<|audio_out_bos|> ``` Otherwise, we will keep generating the text tokens. Parameters: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): The sequence used as a prompt for the generation. logits_processor (`LogitsProcessorList`): An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] used to modify the prediction scores of the language modeling head applied at each generation step. stopping_criteria (`StoppingCriteriaList`): An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] used to tell if the generation loop should stop. generation_config ([`~generation.GenerationConfig`]): The generation configuration to be used as parametrization of the decoding method. synced_gpus (`bool`): Whether to continue running the while loop until max_length (needed to avoid deadlocking with `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). streamer (`BaseStreamer`, *optional*): Streamer object that will be used to stream the generated sequences. 
Generated tokens are passed through `streamer.put(token_ids)` and the streamer is responsible for any further processing. model_kwargs: Additional model-specific kwargs will be forwarded to the `forward` function of the model. If the model is an encoder-decoder model the kwargs should include `encoder_outputs`. Return: [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if `model.config.is_encoder_decoder=True`. """ assert input_ids.shape[0] == 1, "Only batch_size=1 is supported in _sample()" audio_out_bos_token_id = generation_config.generation_kwargs.get("audio_out_bos_token_id", None) # torch generator for sampling seed = generation_config.generation_kwargs.get("seed", None) if seed is not None: torch_generator = torch.Generator(device=input_ids.device).manual_seed(seed) else: torch_generator = None # init values pad_token_id = generation_config._pad_token_tensor output_attentions = generation_config.output_attentions output_hidden_states = generation_config.output_hidden_states output_scores = generation_config.output_scores output_logits = generation_config.output_logits return_dict_in_generate = generation_config.return_dict_in_generate max_length = generation_config.max_length has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria) do_sample = generation_config.do_sample # Used to track which past_key_values bucket is currently in use. self.current_past_key_values_bucket = None # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None raw_logits = () if (return_dict_in_generate and output_logits) else None decoder_attentions = () if (return_dict_in_generate and output_attentions) else None decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None # keep track of which sequences are already finished batch_size, cur_len = input_ids.shape this_peer_finished = False unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) if generation_config.use_cache: model_kwargs["cache_audio_discrete_codes_mask"] = None init_model_input = True num_delay = 0 num_remaining_delays = None audio_sequences = [] # Keeps the full generated token stream, including the audio placeholder tokens. input_ids_full = input_ids.clone() # Initialize the audio variables based on the input prompt.
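        # Clarifying note on the delay-pattern bookkeeping used below and in `_sample_audio_tokens`:
        # - `num_delay` counts how many codebooks have already started emitting real tokens; with
        #   `use_delay_pattern`, codebook k starts one step after codebook k-1, and codebooks that
        #   have not started yet are filled with `audio_stream_bos_id`.
        # - `num_remaining_delays` counts how many codebooks still have to emit `audio_stream_eos_id`
        #   once the first codebook has finished, so the staggered streams are closed step by step.
        # If the prompt already ends inside an audio segment (its last token is <|AUDIO_OUT|>),
        # both counters are re-derived from the tail of `audio_out_ids` in the block below.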
if input_ids[0][-1] == self.config.audio_out_token_idx: audio_sequences = [model_kwargs["audio_out_ids"][:, model_kwargs["audio_out_ids_start"][-1] :]] if self.use_delay_pattern: num_delay = ( self.audio_num_codebooks - (model_kwargs["audio_out_ids"][:, -1] == self.config.audio_stream_bos_id).sum() ) all_eos_indices = (model_kwargs["audio_out_ids"][:, -1] == self.config.audio_stream_eos_id).nonzero() if torch.numel(all_eos_indices) > 0: all_eos_indices = all_eos_indices[0] last_eos_idx = all_eos_indices[-1] num_remaining_delays = self.audio_num_codebooks - last_eos_idx - 1 while self._has_unfinished_sequences( this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length ): # Check which multimodal stage we are in # FIXME: Assume single input generation if input_ids[0][-1] == audio_out_bos_token_id: generation_mode = GenerationMode.AUDIO_INIT elif input_ids[0][-1] == self.audio_out_token_idx: generation_mode = GenerationMode.AUDIO_IN_PROGRESS else: generation_mode = GenerationMode.TEXT is_audio_generation_mode = generation_mode == GenerationMode.AUDIO_IN_PROGRESS if init_model_input or not generation_config.use_cache: model_inputs = {"input_ids": input_ids, **model_kwargs} else: model_inputs = {"input_ids": input_ids[:, -1:], **model_kwargs} if is_audio_generation_mode and generation_config.use_cache: model_inputs["audio_out_ids"] = model_kwargs["audio_out_ids"][:, -1:] model_inputs["audio_out_ids_start"] = torch.tensor([0], dtype=torch.long, device=input_ids.device) elif not is_audio_generation_mode: del model_inputs["audio_out_ids"] del model_inputs["audio_out_ids_start"] if generation_config.use_cache: if "audio_features" in model_inputs and model_inputs["audio_features"] is not None: model_inputs["audio_features"] = model_inputs["audio_features"][:0, ...] model_inputs["audio_feature_attention_mask"] = model_inputs["audio_feature_attention_mask"][ :0, ... ] if "audio_in_ids" in model_inputs and model_inputs["audio_in_ids"] is not None: model_inputs["audio_in_ids"] = None model_inputs["audio_in_ids_start"] = None # prepare variable output controls (note: some models won't accept all output controls) model_inputs.update({"output_attentions": output_attentions} if output_attentions else {}) model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {}) if past_key_values_buckets is not None: past_key_values, self.current_past_key_values_bucket = self._prepare_kv_cache( cur_len, self.current_past_key_values_bucket, past_key_values_buckets ) if past_key_values is not None: model_inputs.update({"past_key_values": past_key_values}) model_inputs["past_key_values_buckets"] = past_key_values_buckets # forward pass to get next token outputs = self(**model_inputs, return_dict=True) # Update the actual sequence length after the first forward pass if init_model_input and past_key_values_buckets is not None: cur_len = past_key_values_buckets[self.current_past_key_values_bucket].get_seq_length().item() # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, extend_attention_mask=True, ) # After the first forward pass, we can set init_model_input to False. init_model_input = False if synced_gpus and this_peer_finished: continue if is_audio_generation_mode: # In audio generation mode, we sample the audio tokens from audio logits. 
# It might also generate the audio eos token to end the audio generation. ( next_tokens, next_audio_tokens, next_audio_token_logits, next_audio_token_scores, num_delay, num_remaining_delays, ) = self._sample_audio_tokens( hidden_states=outputs.audio_hidden_states, audio_logits=outputs.audio_logits, audio_out_ids=model_kwargs["audio_out_ids"], do_sample=do_sample, logits_processor=logits_processor, device=input_ids.device, torch_generator=torch_generator, generation_config=generation_config, num_delay=num_delay, num_remaining_delays=num_remaining_delays, ) # update generated ids, model inputs, and length for next step model_kwargs["audio_out_ids"] = torch.cat( [model_kwargs["audio_out_ids"], next_audio_tokens[:, None]], dim=-1 ) audio_sequences[-1] = torch.cat([audio_sequences[-1], next_audio_tokens[:, None]], dim=-1) if streamer is not None: streamer.put(next_audio_tokens.cpu()) else: # In text generation mode, we sample the text tokens from text logits. # It might also generate the audio placeholder token to start the audio generation. next_tokens, next_audio_tokens, next_token_logits, next_token_scores = self._sample_text_tokens( input_ids=input_ids, logits=outputs.logits, do_sample=do_sample, logits_processor=logits_processor, device=input_ids.device, generation_mode=generation_mode, torch_generator=torch_generator, ) if streamer is not None: streamer.put(next_tokens.cpu()) if next_audio_tokens is not None: # If the token is audio bos token, we will generate the audio placeholder token # and the corrensponding audio stream bos token to start the audio generation. audio_sequences.append(next_audio_tokens[:, None]) if streamer is not None: streamer.put(next_audio_tokens.cpu()) if model_kwargs["audio_out_ids"] is None or model_kwargs["audio_out_ids"].shape[0] == 0: # Initialize audio_out_ids model_kwargs["audio_out_ids"] = next_audio_tokens[:, None] model_kwargs["audio_out_ids_start"] = torch.tensor( [0], dtype=torch.long, device=input_ids.device ) else: model_kwargs["audio_out_ids_start"] = torch.concat( [ model_kwargs["audio_out_ids_start"], torch.tensor( [model_kwargs["audio_out_ids"].shape[1]], dtype=torch.long, device=input_ids.device ), ], dim=0, ) model_kwargs["audio_out_ids"] = torch.concat( [model_kwargs["audio_out_ids"], next_audio_tokens[:, None]], dim=1 ) if return_dict_in_generate: if output_scores: if is_audio_generation_mode: scores += (next_audio_token_scores,) else: scores += (next_token_scores,) if output_logits: if is_audio_generation_mode: raw_logits += (next_audio_token_logits,) else: raw_logits += (next_token_logits,) if output_attentions: decoder_attentions += (outputs.attentions,) if output_hidden_states: decoder_hidden_states += (outputs.hidden_states,) # finished sentences should have their next token be a padding token if has_eos_stopping_criteria: next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) if "tokenizer_length" in generation_config.generation_kwargs: tokenizer_length = generation_config.generation_kwargs["tokenizer_length"] if torch.max(next_tokens) >= tokenizer_length: raise ValueError( f"Next generated token has max value {torch.max(next_tokens)} which is greater than the tokenizer's vocabulary size {tokenizer_length}, this is undesired behavior." ) # update generated ids, model inputs, and length for next step if not is_audio_generation_mode or next_tokens[0] != self.audio_out_token_idx: # We only add one <|AUDIO_OUT|> token to the input_ids for simplicity. 
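                # While audio frames are being generated, every step forces the same <|AUDIO_OUT|>
                # placeholder as the text token, so appending it once is enough: the actual audio
                # content lives in `model_kwargs["audio_out_ids"]` / `audio_sequences`, and the
                # placeholder is re-expanded against those ids inside `forward`.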
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) input_ids_full = torch.cat([input_ids_full, next_tokens[:, None]], dim=-1) unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids_full, scores) this_peer_finished = unfinished_sequences.max() == 0 cur_len += 1 # This is needed to properly delete outputs.logits, which may be very large for the first iteration. # Otherwise a reference to outputs is kept, which keeps the logits alive in the next iteration. del outputs if streamer is not None: streamer.end() if return_dict_in_generate: return HiggsAudioGenerationOutput( sequences=input_ids, audio_sequences=audio_sequences, scores=scores, logits=raw_logits, attentions=decoder_attentions, hidden_states=decoder_hidden_states, past_key_values=model_kwargs.get("past_key_values"), ) else: return input_ids, audio_sequences @torch.inference_mode() def generate( self, input_ids: Optional[torch.LongTensor] = None, audio_features: Optional[torch.FloatTensor] = None, audio_feature_attention_mask: Optional[torch.BoolTensor] = None, audio_in_ids: Optional[torch.LongTensor] = None, audio_in_ids_start: Optional[torch.LongTensor] = None, audio_out_ids: Optional[torch.LongTensor] = None, audio_out_ids_start: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, audio_out_bos_token_id: Optional[int] = None, audio_eos_token_id: Optional[int] = None, past_key_values_buckets: Optional[OrderedDict[int, Cache]] = None, seed: Optional[int] = None, **kwargs, ): """ The generate function in HuggingFace generally follows a step-by-step sampling loop (sample_step = 1, 2, 3, ...). This override prepares the audio-specific generation kwargs and then delegates to `GenerationMixin.generate`, which drives `_sample`. """ # Right now, this is a very simplified version of generate; we should revisit it after our model architecture stabilizes. assert input_ids.shape[0] == 1, ( "Currently HiggsAudioModel.generate() only supports batch_size=1." ) generation_config, kwargs = self._prepare_generation_config(kwargs.pop("generation_config", None), **kwargs) if audio_out_bos_token_id is not None: generation_config.generation_kwargs["audio_out_bos_token_id"] = audio_out_bos_token_id else: try: generation_config.generation_kwargs["audio_out_bos_token_id"] = self.audio_out_bos_token_id except AttributeError: generation_config.generation_kwargs["audio_out_bos_token_id"] = None if audio_eos_token_id is not None: generation_config.generation_kwargs["audio_eos_token_id"] = audio_eos_token_id else: try: generation_config.generation_kwargs["audio_eos_token_id"] = self.audio_eos_token_id except AttributeError: generation_config.generation_kwargs["audio_eos_token_id"] = None has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None generation_config.generation_kwargs["ras_win_len"] = kwargs.pop("ras_win_len", None) generation_config.generation_kwargs["ras_win_max_num_repeat"] = kwargs.pop("ras_win_max_num_repeat", 2) # Set the generation seed if deterministic generation is required if seed is not None: generation_config.generation_kwargs["seed"] = seed # Store the tokenizer's vocabulary size in the generation config if a tokenizer is passed in kwargs (without popping it) if "tokenizer" in kwargs: generation_config.generation_kwargs["tokenizer_length"] = len(kwargs["tokenizer"]) # input_ids: [bsz, seq_len] # The merging of audio features happens inside the forward path. The input_ids does not need to change.
# TODO: prepare the final input embeddings to improve generation performance input_ids_length = input_ids.shape[-1] generation_config = self._prepare_generated_length( generation_config=generation_config, has_default_max_length=has_default_max_length, has_default_min_length=has_default_min_length, model_input_name=None, inputs_tensor=None, input_ids_length=input_ids_length, ) assert generation_config.num_beams == 1, "Currently, we only support num_beams=1 (beam search is not supported)" return_dict_in_generate = generation_config.return_dict_in_generate output_scores = generation_config.output_scores # When attn_implementation is sdpa or flash-attention, the causal mask is created automatically. attention_mask = kwargs.pop("attention_mask", None) return super().generate( input_ids=input_ids, attention_mask=attention_mask, audio_features=audio_features, audio_feature_attention_mask=audio_feature_attention_mask, audio_in_ids=audio_in_ids, audio_in_ids_start=audio_in_ids_start, audio_out_ids=audio_out_ids, audio_out_ids_start=audio_out_ids_start, past_key_values=past_key_values, generation_config=generation_config, output_scores=output_scores, return_dict_in_generate=return_dict_in_generate, past_key_values_buckets=past_key_values_buckets, **kwargs, ) def parameter_count_per_component(self): """Count the number of parameters per component in the model. HiggsAudio has the following main components: audio_tower: Maps audio features to hidden states llm_embed: The embedding layer of the LLM llm_non_embed: The non-embedding layers of the LLM audio_embed: The audio codebook embeddings audio_adapter: The additional layers used for audio generation """ trainable_stats = { "audio_tower": 0, "llm_embed": 0, "llm_non_embed": 0, "audio_embed": 0, "audio_adapter": 0, "overall": 0, } total_stats = { "audio_tower": 0, "llm_embed": 0, "llm_non_embed": 0, "audio_embed": 0, "audio_adapter": 0, "overall": 0, } total_stats["overall"] = count_parameters(self, trainable_only=False) trainable_stats["overall"] = count_parameters(self, trainable_only=True) for mod in [self.audio_tower]: if mod is not None: total_stats["audio_tower"] += count_parameters(mod, trainable_only=False) trainable_stats["audio_tower"] += count_parameters(mod, trainable_only=True) total_stats["llm_embed"] = count_parameters(self.embed_tokens, trainable_only=False) trainable_stats["llm_embed"] = count_parameters(self.embed_tokens, trainable_only=True) total_stats["audio_embed"] = count_parameters(self.audio_codebook_embeddings, trainable_only=False) trainable_stats["audio_embed"] = count_parameters(self.audio_codebook_embeddings, trainable_only=True) # Calculate the number of parameters for the LLM for layer in self.layers: if isinstance(layer, HiggsAudioDualFFNDecoderLayer): total_param_count = count_parameters(layer, trainable_only=False) total_trainable_param_count = count_parameters(layer, trainable_only=True) total_stats["llm_non_embed"] += total_param_count trainable_stats["llm_non_embed"] += total_trainable_param_count if not layer.fast_forward: audio_mlp_param_count = count_parameters(layer.audio_mlp, trainable_only=False) audio_mlp_trainable_param_count = count_parameters(layer.audio_mlp, trainable_only=True) audio_norm_param_count = count_parameters( layer.audio_post_attention_layernorm, trainable_only=False ) + count_parameters(layer.audio_input_layernorm, trainable_only=False) audio_norm_trainable_param_count = count_parameters( layer.audio_post_attention_layernorm, trainable_only=True ) + count_parameters(layer.audio_input_layernorm, trainable_only=True)
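                    # The dual-FFN layer was counted in full under "llm_non_embed" above, so the
                    # audio-specific MLP / layernorm (and optional audio attention) parameters are
                    # subtracted back out and re-attributed to "audio_adapter" below.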
total_stats["llm_non_embed"] -= audio_mlp_param_count + audio_norm_param_count trainable_stats["llm_non_embed"] -= ( audio_mlp_trainable_param_count + audio_norm_trainable_param_count ) total_stats["audio_adapter"] += audio_mlp_param_count + audio_norm_param_count trainable_stats["audio_adapter"] += ( audio_mlp_trainable_param_count + audio_norm_trainable_param_count ) if layer.use_audio_attention: audio_attn_param_count = count_parameters( layer.audio_attn, trainable_only=False ) + count_parameters(layer.audio_post_audio_attn_layer_norm, trainable_only=False) audio_attn_trainable_param_count = count_parameters( layer.audio_attn, trainable_only=True ) + count_parameters(layer.audio_post_audio_attn_layer_norm, trainable_only=True) total_stats["llm_non_embed"] -= audio_attn_param_count trainable_stats["llm_non_embed"] -= audio_attn_trainable_param_count total_stats["audio_adapter"] += audio_attn_param_count trainable_stats["audio_adapter"] += audio_attn_trainable_param_count else: total_stats["llm_non_embed"] += count_parameters(layer, trainable_only=False) trainable_stats["llm_non_embed"] += count_parameters(layer, trainable_only=True) total_stats["llm_non_embed"] += count_parameters(self.norm, trainable_only=False) trainable_stats["llm_non_embed"] += count_parameters(self.norm, trainable_only=True) total_stats["audio_adapter"] += count_parameters(self.audio_decoder_proj.audio_lm_head, trainable_only=False) trainable_stats["audio_adapter"] += count_parameters( self.audio_decoder_proj.audio_lm_head, trainable_only=True ) total_stats["llm_embed"] += count_parameters(self.audio_decoder_proj.text_lm_head, trainable_only=False) trainable_stats["llm_embed"] += count_parameters(self.audio_decoder_proj.text_lm_head, trainable_only=True) other_audio_modules = [self.audio_encoder_proj] if self.use_audio_out_embed_projector: other_audio_modules.append(self.audio_out_embed_projector) for mod in other_audio_modules: if mod is not None: total_stats["audio_adapter"] += count_parameters(mod, trainable_only=False) trainable_stats["audio_adapter"] += count_parameters(mod, trainable_only=True) return {"trainable": trainable_stats, "total": total_stats} def set_skip_audio_tower(self): self.config.skip_audio_tower = True self.config.encode_whisper_embed = False def set_encode_audio_in_tokens(self): self.config.encode_audio_in_tokens = True def freeze_audio_tower(self): if self.audio_tower is not None: for param in self.audio_tower.parameters(): param.requires_grad = False def freeze_audio_encoder_proj(self): if self.audio_encoder_proj is not None: for param in self.audio_encoder_proj.parameters(): param.requires_grad = False def freeze_llm(self, freeze_embed=True, freeze_embed_until_idx: Optional[int] = None): for layer in self.layers: if isinstance(layer, HiggsAudioDualFFNDecoderLayer): for param in layer.self_attn.parameters(): param.requires_grad = False for param in layer.mlp.parameters(): param.requires_grad = False for param in layer.post_attention_layernorm.parameters(): param.requires_grad = False for param in layer.input_layernorm.parameters(): param.requires_grad = False else: for param in layer.parameters(): param.requires_grad = False for param in self.norm.parameters(): param.requires_grad = False if freeze_embed: if freeze_embed_until_idx is None: for param in self.embed_tokens.parameters(): param.requires_grad = False else: assert isinstance(self.embed_tokens, nn.Embedding) self.embed_tokens = PartiallyFrozenEmbedding( original_embedding=self.embed_tokens, freeze_until_idx=freeze_embed_until_idx 
) def freeze_text_head(self, freeze_text_head_until_idx: Optional[int] = None): """Freeze the final text head""" if freeze_text_head_until_idx is None: for param in self.audio_decoder_proj.text_lm_head.parameters(): param.requires_grad = False else: assert isinstance(self.audio_decoder_proj.text_lm_head, nn.Linear) self.audio_decoder_proj.text_lm_head = PartiallyFrozenLinear( original_linear=self.audio_decoder_proj.text_lm_head, freeze_until_idx=freeze_text_head_until_idx ) @classmethod def merge_weights_from_checkpoint(cls, checkpoint_dir: str, merged_output_dir: str, *model_args, **kwargs): # For users' convenience, we merge the embedding and text_lm_head back together if they were split split_model = super().from_pretrained( checkpoint_dir, *model_args, torch_dtype=torch.bfloat16, device_map="cpu", **{**kwargs, "state_dict": None}, # Prevent auto-loading the state_dict ) # Load all safetensors shards state_dict = {} shard_paths = sorted(glob.glob(os.path.join(checkpoint_dir, "*.safetensors"))) for shard_path in shard_paths: shard_dict = load_file(shard_path) # Load each shard state_dict.update(shard_dict) # Merge into a single dict # Merge weights if ( "audio_decoder_proj.text_lm_head.linear_frozen.weight" in state_dict and "audio_decoder_proj.text_lm_head.linear_trainable.weight" in state_dict ): state_dict["audio_decoder_proj.text_lm_head.weight"] = torch.cat( [ state_dict["audio_decoder_proj.text_lm_head.linear_frozen.weight"], state_dict["audio_decoder_proj.text_lm_head.linear_trainable.weight"], ], dim=0, ) del state_dict["audio_decoder_proj.text_lm_head.linear_frozen.weight"] del state_dict["audio_decoder_proj.text_lm_head.linear_trainable.weight"] if ( "embed_tokens.embedding_frozen.weight" in state_dict and "embed_tokens.embedding_trainable.weight" in state_dict ): state_dict["embed_tokens.weight"] = torch.cat( [ state_dict["embed_tokens.embedding_frozen.weight"], state_dict["embed_tokens.embedding_trainable.weight"], ], dim=0, ) del state_dict["embed_tokens.embedding_frozen.weight"] del state_dict["embed_tokens.embedding_trainable.weight"] # Load the final state_dict split_model.load_state_dict(state_dict, strict=True) if merged_output_dir: split_model.save_pretrained(merged_output_dir, is_main_process=True, state_dict=state_dict) @torch.inference_mode() def capture_model(self, past_key_values: list[Union[Cache, List[torch.FloatTensor]]]) -> None: """Capture CUDA graphs for the model's forward pass with different KV cache lengths.
Args: past_key_values: List of KV caches to capture graphs for """ for past_key_value in past_key_values: kv_cache_length = past_key_value.get_max_cache_shape() # We capture two graphs: one for decoding audio tokens and one for decoding text tokens for is_decoding_audio_token in [True, False]: runner = CUDAGraphRunner(self._forward_core) # Create dummy inputs for graph capture. They should match the dtype/device of the real decode-step inputs so the captured graph can be replayed directly. batch_size = 1 hidden_dim = self.config.text_config.hidden_size hidden_states = torch.zeros( (batch_size, 1, hidden_dim), dtype=self.dtype, device=self.device ) causal_mask = torch.ones( (batch_size, 1, 1, kv_cache_length), dtype=self.dtype, device=self.device ) position_ids = torch.zeros((batch_size, 1), dtype=torch.long, device=self.device) audio_discrete_codes_mask = torch.tensor( [[is_decoding_audio_token]], dtype=torch.bool, device=self.device ) cache_position = torch.tensor([kv_cache_length - 1], dtype=torch.long, device=self.device) audio_attention_mask = torch.ones_like(causal_mask) fast_forward_attention_mask = torch.ones_like(causal_mask) runner.capture( hidden_states=hidden_states, causal_mask=causal_mask, position_ids=position_ids, audio_discrete_codes_mask=audio_discrete_codes_mask, cache_position=cache_position, past_key_values=past_key_value, use_cache=True, audio_attention_mask=audio_attention_mask, fast_forward_attention_mask=fast_forward_attention_mask, output_attentions=False, output_hidden_states=False, is_decoding_audio_token=is_decoding_audio_token, is_using_cuda_graph=True, #stream=torch.cuda.Stream(device=self.device), ) self.decode_graph_runners[kv_cache_length][is_decoding_audio_token] = runner
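# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). It shows how the
# generate() API above could be driven for a single-sample text-to-audio request.
# The checkpoint path, chat-template layout, and sampling parameters below are
# assumptions for illustration only; `HiggsAudioModel` refers to the full model
# class whose generate() is implemented in this file.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    MODEL_PATH = "path/to/higgs-audio-checkpoint"  # assumption: a local or Hub checkpoint directory

    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    model = HiggsAudioModel.from_pretrained(MODEL_PATH, torch_dtype=torch.bfloat16).eval()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)

    # Per the _sample() docstring, a prompt that ends with <|audio_out_bos|> switches
    # generation into audio mode right away.
    prompt = (
        "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
        "Please read this sentence aloud.<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n" + model.config.audio_out_bos_token
    )
    input_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).input_ids.to(device)

    outputs = model.generate(
        input_ids=input_ids,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
        ras_win_len=7,                 # repetition-aware sampling window (see _sample_audio_tokens)
        ras_win_max_num_repeat=2,
        seed=1234,                     # deterministic sampling via the internal torch.Generator
        tokenizer=tokenizer,           # enables the vocabulary-size sanity check in _sample()
        return_dict_in_generate=True,
    )

    # outputs.sequences holds the text-token stream (audio segments appear as <|AUDIO_OUT|>);
    # outputs.audio_sequences holds one (num_codebooks, T) tensor of discrete codes per audio
    # segment, which a separate audio tokenizer / codec would decode to a waveform.
    print(outputs.sequences.shape, [a.shape for a in outputs.audio_sequences])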