removed test files

This commit is contained in:
Yousef Rafat 2025-09-05 23:55:07 +03:00
parent 254622d7c6
commit df4b6a26d1
2 changed files with 0 additions and 2941 deletions

@@ -1,560 +0,0 @@
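# Test script for the Higgs-Audio v2 serve engine: it loads the model and audio tokenizer, then runs an end-to-end text-to-speech generation.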
import asyncio
import base64
import math
import torch
import numpy as np
from io import BytesIO
from dataclasses import dataclass, asdict
from typing import List, Optional, Union, Any
from copy import deepcopy
from transformers import AutoTokenizer
from transformers.cache_utils import StaticCache
from loguru import logger
import librosa
import os
from tokenizer import HiggsAudioTokenizer
import json
from huggingface_hub import snapshot_download
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import CONFIG_MAPPING
from model import HiggsAudioConfig
class HiggsAudioEncoderConfig(PretrainedConfig):
"""Configuration of the Audio encoder in Higgs-Audio."""
model_type = "higgs_audio_encoder"
def __init__(
self,
num_mel_bins=128,
encoder_layers=32,
encoder_attention_heads=20,
encoder_ffn_dim=5120,
encoder_layerdrop=0.0,
d_model=1280,
dropout=0.0,
attention_dropout=0.0,
activation_function="gelu",
activation_dropout=0.0,
scale_embedding=False,
init_std=0.02,
max_source_positions=1500,
pad_token_id=128001,
**kwargs,
):
super().__init__(**kwargs)
self.num_mel_bins = num_mel_bins
self.d_model = d_model
self.encoder_layers = encoder_layers
self.encoder_attention_heads = encoder_attention_heads
self.encoder_ffn_dim = encoder_ffn_dim
self.dropout = dropout
self.attention_dropout = attention_dropout
self.activation_function = activation_function
self.activation_dropout = activation_dropout
self.encoder_layerdrop = encoder_layerdrop
self.num_hidden_layers = encoder_layers
self.init_std = init_std
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
self.max_source_positions = max_source_positions
self.pad_token_id = pad_token_id
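# Composite configuration for the full Higgs-Audio model: a LLaMA-style text config plus audio encoder/tokenizer settings, the audio adapter type, and the audio special-token ids.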
class HiggsAudioConfig(PretrainedConfig):
model_type = "higgs_audio"
is_composition = True
def __init__(
self,
text_config=None,
audio_encoder_config=None,
audio_tokenizer_config=None,
audio_adapter_type="stack",
audio_embed_avg=False,
audio_ffn_hidden_size=4096,
audio_ffn_intermediate_size=14336,
audio_dual_ffn_layers=None,
audio_decoder_proj_num_layers=0,
encode_whisper_embed=True,
encode_audio_in_tokens=False,
use_delay_pattern=False,
skip_audio_tower=False,
use_audio_out_embed_projector=False,
use_audio_out_self_attention=False,
use_rq_transformer=False,
rq_transformer_hidden_size=None,
rq_transformer_intermediate_size=None,
rq_transformer_num_attention_heads=None,
rq_transformer_num_key_value_heads=None,
rq_transformer_num_hidden_layers=3,
audio_num_codebooks=12,
audio_codebook_size=1024,
audio_stream_bos_id=1024,
audio_stream_eos_id=1025,
audio_bos_token="<|audio_bos|>",
audio_eos_token="<|audio_eos|>",
audio_out_bos_token="<|audio_out_bos|>",
audio_in_token="<|AUDIO|>",
audio_out_token="<|AUDIO_OUT|>",
audio_in_token_idx=128015,
audio_out_token_idx=128016,
pad_token_id=128001,
audio_out_bos_token_id=128013,
audio_eos_token_id=128012,
**kwargs,
):
if isinstance(text_config, dict):
text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama"
text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
elif text_config is None:
text_config = CONFIG_MAPPING["llama"]()
assert audio_adapter_type in [
"stack",
"dual_ffn",
"dual_ffn_fast_forward",
], f"Invalid audio adapter type: {audio_adapter_type}"
if audio_adapter_type.startswith("dual_ffn"):
assert audio_dual_ffn_layers is not None, (
"audio_dual_ffn_layers must be specified when using dual_ffn adapter."
)
self.text_config = text_config
self.audio_encoder_config = audio_encoder_config
self.audio_tokenizer_config = audio_tokenizer_config
self.audio_adapter_type = audio_adapter_type
self.audio_embed_avg = audio_embed_avg
self.audio_ffn_hidden_size = audio_ffn_hidden_size
self.audio_ffn_intermediate_size = audio_ffn_intermediate_size
self.audio_dual_ffn_layers = audio_dual_ffn_layers
self.audio_decoder_proj_num_layers = audio_decoder_proj_num_layers
self.encode_whisper_embed = encode_whisper_embed
self.encode_audio_in_tokens = encode_audio_in_tokens
self.use_delay_pattern = use_delay_pattern
self.skip_audio_tower = skip_audio_tower
self.use_audio_out_embed_projector = use_audio_out_embed_projector
self.use_audio_out_self_attention = use_audio_out_self_attention
self.use_rq_transformer = use_rq_transformer
if self.use_rq_transformer:
assert not self.use_delay_pattern, "Delay pattern is not supported when the RQ-Transformer is enabled."
self.rq_transformer_hidden_size = rq_transformer_hidden_size
self.rq_transformer_intermediate_size = rq_transformer_intermediate_size
self.rq_transformer_num_attention_heads = rq_transformer_num_attention_heads
self.rq_transformer_num_key_value_heads = rq_transformer_num_key_value_heads
self.rq_transformer_num_hidden_layers = rq_transformer_num_hidden_layers
if use_rq_transformer:
# For RQ-Transformer, we set the hidden_size to the same as the text model's hidden size if it is not specified.
if self.rq_transformer_hidden_size is None:
self.rq_transformer_hidden_size = text_config.hidden_size
assert self.rq_transformer_hidden_size % 128 == 0
if self.rq_transformer_intermediate_size is None:
self.rq_transformer_intermediate_size = text_config.intermediate_size
if self.rq_transformer_num_attention_heads is None:
self.rq_transformer_num_attention_heads = self.rq_transformer_hidden_size // 128
if self.rq_transformer_num_key_value_heads is None:
self.rq_transformer_num_key_value_heads = self.rq_transformer_hidden_size // 128 // 4
assert self.rq_transformer_hidden_size % self.rq_transformer_num_attention_heads == 0
assert self.rq_transformer_hidden_size % self.rq_transformer_num_key_value_heads == 0
self.audio_num_codebooks = audio_num_codebooks
self.audio_codebook_size = audio_codebook_size
self.audio_bos_token = audio_bos_token
self.audio_eos_token = audio_eos_token
self.audio_out_bos_token = audio_out_bos_token
self.audio_in_token = audio_in_token
self.audio_out_token = audio_out_token
self.audio_in_token_idx = audio_in_token_idx
self.audio_out_token_idx = audio_out_token_idx
self.audio_stream_bos_id = audio_stream_bos_id
self.audio_stream_eos_id = audio_stream_eos_id
self.audio_out_bos_token_id = audio_out_bos_token_id
self.audio_eos_token_id = audio_eos_token_id
super().__init__(**kwargs)
self.pad_token_id = pad_token_id
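# Imports from the sibling model and preprocess modules (kept below the config classes, presumably to avoid circular imports).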
from model import HiggsAudioModel
from preprocess import HiggsAudioSampleCollator, ChatMLSample, ChatMLDatasetSample, prepare_chatml_sample, Message
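# Load the Higgs audio tokenizer from a local directory or a Hugging Face Hub repo id and move it to the target device.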
def load_higgs_audio_tokenizer(tokenizer_name_or_path, device="cuda"):
is_local = os.path.exists(tokenizer_name_or_path)
if not is_local:
tokenizer_path = snapshot_download(tokenizer_name_or_path)
else:
tokenizer_path = tokenizer_name_or_path
config_path = os.path.join(tokenizer_path, "config.json")
model_path = os.path.join(tokenizer_path, "model.pth")
config = json.load(open(config_path))
model = HiggsAudioTokenizer(
**config,
device=device,
)
parameter_dict = torch.load(model_path, map_location=device)
model.load_state_dict(parameter_dict, strict=False)
model.to(device)
model.eval()
return model
def revert_delay_pattern(data):
"""Convert samples encoded with delay pattern back to the original form.
Args:
data (:obj:`torch.Tensor`):
The data with delay pattern applied. It will have shape (num_codebooks, seq_len + num_codebooks - 1).
Returns:
ret (:obj:`torch.Tensor`):
Recovered data with delay pattern removed. It will have shape (num_codebooks, seq_len).
"""
assert len(data.shape) == 2
out_l = []
num_codebooks = data.shape[0]
for i in range(num_codebooks):
out_l.append(data[i : (i + 1), i : (data.shape[1] - num_codebooks + 1 + i)])
return torch.cat(out_l, dim=0)
@dataclass
class HiggsAudioStreamerDelta:
"""Represents a chunk of generated content, either text or audio tokens."""
text: Optional[str] = None
text_tokens: Optional[torch.Tensor] = None
audio_tokens: Optional[torch.Tensor] = None
finish_reason: Optional[str] = None
@dataclass
class HiggsAudioResponse:
audio: Optional[np.ndarray] = None
generated_audio_tokens: Optional[np.ndarray] = None
sampling_rate: Optional[int] = None
generated_text: str = ""
generated_text_tokens: Optional[np.ndarray] = None
usage: Optional[dict] = None
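# Minimal serving engine: wraps the language model, text tokenizer, and audio tokenizer, pre-allocates static KV caches, and exposes a blocking generate() call.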
class HiggsAudioServeEngine:
def __init__(
self,
model_name_or_path: str,
audio_tokenizer_name_or_path: str,
tokenizer_name_or_path: Optional[str] = None,
device: str = "cuda",
torch_dtype: Union[torch.dtype, str] = torch.float16,
kv_cache_lengths: List[int] = [1024, 4096, 8192], # Multiple KV cache sizes
):
self.device = device
self.model_name_or_path = model_name_or_path
self.torch_dtype = torch_dtype
torch.set_default_device(device)
# Initialize model and tokenizer
config = HiggsAudioConfig.from_pretrained(
"bosonai/higgs-audio-v2-generation-3B-base",
#trust_remote_code=True
)
#config.num_hidden_layers = config.num_hidden_layers // 2
# ---- Audio Config ----
#config.audio_dual_ffn_layers = config.audio_dual_ffn_layers[:12]
#config.audio_ffn_hidden_size //= 2 # 3072 → 1536
#config.audio_ffn_intermediate_size //= 2 # 8192 → 4096
# ---- Text Config ----
#config.text_config.hidden_size //= 2 # 3072 → 1536
#config.text_config.intermediate_size //= 2 # 8192 → 4096
#config.text_config.num_attention_heads //= 2 # 24 → 12
#config.text_config.num_key_value_heads //= 2 # 8 → 4
#config.text_config.num_hidden_layers //= 2 # 28 → 14
#config.text_config.head_dim //= 2 # 128 → 64
# ---- Shared ----
#config.hidden_size //= 2 # 3072 → 1536
self.model = HiggsAudioModel.from_pretrained(model_name_or_path, torch_dtype=torch_dtype, config=HiggsAudioConfig.from_pretrained(model_name_or_path))  # alternative: HiggsAudioModel(config=config, device=device, operations=torch.nn, dtype=torch_dtype)
print(self.model.device)
self.model.config = config
logger.info(f"Loaded model from {model_name_or_path}, dtype: {self.model.dtype}")
if tokenizer_name_or_path is None:
tokenizer_name_or_path = model_name_or_path
logger.info(f"Loading tokenizer from {tokenizer_name_or_path}")
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
logger.info(f"Initializing Higgs Audio Tokenizer")
self.audio_tokenizer = load_higgs_audio_tokenizer(audio_tokenizer_name_or_path, device=device)
self.audio_num_codebooks = self.model.config.audio_num_codebooks
self.audio_codebook_size = self.model.config.audio_codebook_size
self.audio_tokenizer_tps = self.audio_tokenizer.tps
self.samples_per_token = int(self.audio_tokenizer.sampling_rate // self.audio_tokenizer_tps)
self.hamming_window_len = 2 * self.audio_num_codebooks * self.samples_per_token
# Set the audio special tokens
# Prepare KV caches for different lengths
cache_config = deepcopy(self.model.config.text_config)
cache_config.num_hidden_layers = self.model.config.text_config.num_hidden_layers
if self.model.config.audio_dual_ffn_layers:
cache_config.num_hidden_layers += len(self.model.config.audio_dual_ffn_layers)
# A list of KV caches for different lengths
self.kv_caches = {
length: StaticCache(
config=cache_config,
max_batch_size=1,
max_cache_len=length,
device=self.model.device,
dtype=self.model.dtype,
)
for length in sorted(kv_cache_lengths)
}
# Reuse collator to prepare inference samples
self.collator = HiggsAudioSampleCollator(
audio_in_token_id=self.model.config.audio_in_token_idx,
audio_out_token_id=self.model.config.audio_out_token_idx,
audio_stream_bos_id=self.model.config.audio_stream_bos_id,
audio_stream_eos_id=self.model.config.audio_stream_eos_id,
pad_token_id=self.model.config.pad_token_id,
return_audio_in_tokens=False,
use_delay_pattern=self.model.config.use_delay_pattern,
audio_num_codebooks=self.model.config.audio_num_codebooks,
round_to=1,
)
# Capture CUDA graphs for each KV cache length
#if device == "cuda":
# logger.info(f"Capturing CUDA graphs for each KV cache length")
# self.model.capture_model(self.kv_caches.values())
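# Convert a ChatMLSample into collated model inputs: tokenize the chat text, encode any reference audio clips, and move every tensor to the model's device.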
def _prepare_inputs(self, chat_ml_sample: ChatMLSample, force_audio_gen: bool = False):
input_tokens, audio_contents, _ = prepare_chatml_sample(
chat_ml_sample,
self.tokenizer,
)
postfix = "<|start_header_id|>assistant<|end_header_id|>\n\n"
if force_audio_gen:
postfix += "<|audio_out_bos|>"
postfix = self.tokenizer.encode(postfix, add_special_tokens=False)
input_tokens.extend(postfix)
# Configure the audio inputs
audio_ids_l = []
for audio_content in audio_contents:
if audio_content.audio_url not in ["placeholder", ""]:
raw_audio, _ = librosa.load(audio_content.audio_url, sr=self.audio_tokenizer.sampling_rate)
elif audio_content.raw_audio is not None:
raw_audio, _ = librosa.load(
BytesIO(base64.b64decode(audio_content.raw_audio)), sr=self.audio_tokenizer.sampling_rate
)
else:
raw_audio = None
if raw_audio is not None:
audio_ids = self.audio_tokenizer.encode(raw_audio, self.audio_tokenizer.sampling_rate)
audio_ids_l.append(audio_ids.squeeze(0).cpu())
if len(audio_ids_l) > 0:
audio_ids_start = torch.tensor(
np.cumsum(np.array([0] + [audio_ids.shape[1] for audio_ids in audio_ids_l])),
dtype=torch.long,
device=self.device,
)[0:-1]
audio_ids_concat = torch.cat(audio_ids_l, dim=1)
else:
audio_ids_start = None
audio_ids_concat = None
sample = ChatMLDatasetSample(
input_ids=torch.LongTensor(input_tokens),
label_ids=None,
audio_ids_concat=audio_ids_concat,
audio_ids_start=audio_ids_start,
audio_waveforms_concat=None,
audio_waveforms_start=None,
audio_sample_rate=None,
audio_speaker_indices=None,
)
data = self.collator([sample])
inputs = asdict(data)
for k, v in inputs.items():
if isinstance(v, torch.Tensor):
inputs[k] = v.to(self.model.device)
return inputs
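# Reset all pre-allocated static KV cache buckets before serving a new request.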
def _prepare_kv_caches(self):
for kv_cache in self.kv_caches.values():
kv_cache.reset()
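# Run a single, non-streaming generation request and return the decoded audio together with token-usage statistics.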
def generate(
self,
chat_ml_sample: ChatMLSample,
max_new_tokens: int,
temperature: float = 0.7,
top_k: Optional[int] = None,
top_p: float = 0.95,
stop_strings: Optional[List[str]] = None,
force_audio_gen: bool = False,
ras_win_len: Optional[int] = 7,
ras_win_max_num_repeat: int = 2,
seed: Optional[int] = None,
):
# Default stop strings
if stop_strings is None:
stop_strings = ["<|end_of_text|>", "<|eot_id|>"]
if ras_win_len is not None and ras_win_len <= 0:
ras_win_len = None
with torch.no_grad():
inputs = self._prepare_inputs(chat_ml_sample, force_audio_gen=force_audio_gen)
prompt_token_ids = inputs["input_ids"][0].cpu().numpy()
self._prepare_kv_caches()
from autoregressive_sampling import auto_sample
outputs = auto_sample(
self.model,
**inputs,
max_new_tokens=max_new_tokens,
use_cache=True,
stop_strings=stop_strings,
tokenizer=self.tokenizer,
do_sample=False if temperature == 0.0 else True,
temperature=temperature,
top_k=top_k,
top_p=top_p,
past_key_values_buckets=self.kv_caches,
ras_win_len=ras_win_len,
ras_win_max_num_repeat=ras_win_max_num_repeat,
seed=seed,
)
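# Undo the delay pattern, clamp codes to the codebook range, drop the stream BOS/EOS frames, and decode each segment back to a waveform.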
if len(outputs) > 0:
wv_list = []
for output_audio in outputs:
vq_code = revert_delay_pattern(output_audio).clip(0, self.audio_codebook_size - 1)[:, 1:-1]
wv_numpy = self.audio_tokenizer.decode(vq_code.unsqueeze(0))[0, 0]
wv_list.append(wv_numpy)
wv_numpy = torch.cat(wv_list)
wv_numpy = wv_numpy.cpu().numpy()
else:
wv_numpy = None
# We only support one request at a time now
#generated_text_tokens = outputs[0][0].cpu().numpy()[len(prompt_token_ids) :]
#generated_text = self.tokenizer.decode(generated_text_tokens)
#print(generated_text)
generated_audio_tokens = outputs[0].cpu().numpy()
return HiggsAudioResponse(
audio=wv_numpy,
generated_audio_tokens=generated_audio_tokens,
sampling_rate=self.audio_tokenizer.sampling_rate,
#generated_text=generated_text,
#generated_text_tokens=generated_text_tokens,
usage={
"prompt_tokens": prompt_token_ids.shape[0],
# "completion_tokens": generated_text_tokens.shape[0] + generated_audio_tokens.shape[1],
"total_tokens": (
prompt_token_ids.shape[0] #+ generated_text_tokens.shape[0] + generated_audio_tokens.shape[1]
),
"cached_tokens": 0,
},
)
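# Build a zero-shot TTS request: a scene-description system prompt plus the user text that should be spoken.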
def get_zero_shot_input_sample():
system_prompt = (
"Generate audio following instruction.\n\n<|scene_desc_start|>\nSPEAKER0: british accent\n<|scene_desc_end|>"
)
messages = [
Message(
role="system",
content=system_prompt,
),
Message(
role="user",
content="Hey, everyone! Welcome back to Tech Talk Tuesdays.\n"
"It's your host, Alex, and today, we're diving into a topic that's become absolutely crucial in the tech world — deep learning.\n"
"And let's be honest, if you've been even remotely connected to tech, AI, or machine learning lately, you know that deep learning is everywhere.",
),
]
chat_ml_sample = ChatMLSample(messages=messages)
return chat_ml_sample
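# Note: this helper takes `self` but is defined at module level here; it appears to mirror transformers' GenerationMixin._update_model_kwargs_for_generation.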
def _update_model_kwargs_for_generation(
self,
outputs,
model_kwargs: dict[str, Any],
num_new_tokens: int = 1,
) -> dict[str, Any]:
# past_key_values will be the standard name for kv cache naming
model_kwargs["past_key_values"] = outputs.past_key_values
# thinking about removing token_type_ids
# update token_type_ids with last value
if "token_type_ids" in model_kwargs:
token_type_ids = model_kwargs["token_type_ids"]
model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
# update attention mask
if "attention_mask" in model_kwargs:
attention_mask = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = torch.cat(
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
)
if model_kwargs.get("use_cache", True):
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
else:
past_positions = model_kwargs.pop("cache_position")
new_positions = torch.arange(
past_positions[-1] + 1, past_positions[-1] + num_new_tokens + 1, dtype=past_positions.dtype
).to(past_positions.device)
model_kwargs["cache_position"] = torch.cat((past_positions, new_positions))
return model_kwargs
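# Hugging Face Hub repo ids used by the test run below.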
MODEL_PATH = "bosonai/higgs-audio-v2-generation-3B-base"
AUDIO_TOKENIZER_PATH = "bosonai/higgs-audio-v2-tokenizer"
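# End-to-end test: build the serve engine, synthesize speech for the zero-shot sample, and write the result to output_.wav.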
def main():
input_sample = get_zero_shot_input_sample()
device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {device}")
torch.manual_seed(2025)
torch.cuda.manual_seed_all(2025)
serve_engine = HiggsAudioServeEngine(
MODEL_PATH,
AUDIO_TOKENIZER_PATH,
device=device,
)
import time
logger.info("Starting generation...")
start_time = time.time()
output: HiggsAudioResponse = serve_engine.generate(
chat_ml_sample=input_sample,
max_new_tokens=1024,
top_k=50,
stop_strings=["<|end_of_text|>", "<|eot_id|>"],
)
elapsed_time = time.time() - start_time
logger.info(f"Generation time: {elapsed_time:.2f} seconds")
print(output)
import torchaudio
torchaudio.save(f"output_.wav", torch.from_numpy(output.audio)[None, :], output.sampling_rate)
logger.info(f"Generated text:\n{output.generated_text}")
logger.info(f"Saved audio to output_.wav")
if __name__ == "__main__":
main()

File diff suppressed because it is too large