mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-01-15 16:50:57 +08:00

removed test files

This commit is contained in:
parent 254622d7c6
commit df4b6a26d1
@@ -1,560 +0,0 @@
import asyncio
import base64
import math
import torch
import numpy as np
from io import BytesIO
from dataclasses import dataclass, asdict
from typing import List, Optional, Union, Any
from copy import deepcopy
from transformers import AutoTokenizer
from transformers.cache_utils import StaticCache
from loguru import logger
import librosa
import os
import json
from tokenizer import HiggsAudioTokenizer
from huggingface_hub import snapshot_download
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import CONFIG_MAPPING

class HiggsAudioEncoderConfig(PretrainedConfig):
    """Configuration of the Audio encoder in Higgs-Audio."""

    model_type = "higgs_audio_encoder"

    def __init__(
        self,
        num_mel_bins=128,
        encoder_layers=32,
        encoder_attention_heads=20,
        encoder_ffn_dim=5120,
        encoder_layerdrop=0.0,
        d_model=1280,
        dropout=0.0,
        attention_dropout=0.0,
        activation_function="gelu",
        activation_dropout=0.0,
        scale_embedding=False,
        init_std=0.02,
        max_source_positions=1500,
        pad_token_id=128001,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.num_mel_bins = num_mel_bins
        self.d_model = d_model
        self.encoder_layers = encoder_layers
        self.encoder_attention_heads = encoder_attention_heads
        self.encoder_ffn_dim = encoder_ffn_dim
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_function = activation_function
        self.activation_dropout = activation_dropout
        self.encoder_layerdrop = encoder_layerdrop
        self.num_hidden_layers = encoder_layers
        self.init_std = init_std
        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
        self.max_source_positions = max_source_positions
        self.pad_token_id = pad_token_id

class HiggsAudioConfig(PretrainedConfig):
    """Composite configuration for Higgs-Audio: a text backbone plus an audio adapter and audio codebooks."""

    model_type = "higgs_audio"
    is_composition = True

    def __init__(
        self,
        text_config=None,
        audio_encoder_config=None,
        audio_tokenizer_config=None,
        audio_adapter_type="stack",
        audio_embed_avg=False,
        audio_ffn_hidden_size=4096,
        audio_ffn_intermediate_size=14336,
        audio_dual_ffn_layers=None,
        audio_decoder_proj_num_layers=0,
        encode_whisper_embed=True,
        encode_audio_in_tokens=False,
        use_delay_pattern=False,
        skip_audio_tower=False,
        use_audio_out_embed_projector=False,
        use_audio_out_self_attention=False,
        use_rq_transformer=False,
        rq_transformer_hidden_size=None,
        rq_transformer_intermediate_size=None,
        rq_transformer_num_attention_heads=None,
        rq_transformer_num_key_value_heads=None,
        rq_transformer_num_hidden_layers=3,
        audio_num_codebooks=12,
        audio_codebook_size=1024,
        audio_stream_bos_id=1024,
        audio_stream_eos_id=1025,
        audio_bos_token="<|audio_bos|>",
        audio_eos_token="<|audio_eos|>",
        audio_out_bos_token="<|audio_out_bos|>",
        audio_in_token="<|AUDIO|>",
        audio_out_token="<|AUDIO_OUT|>",
        audio_in_token_idx=128015,
        audio_out_token_idx=128016,
        pad_token_id=128001,
        audio_out_bos_token_id=128013,
        audio_eos_token_id=128012,
        **kwargs,
    ):
        if isinstance(text_config, dict):
            text_config["model_type"] = text_config.get("model_type", "llama")
            text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
        elif text_config is None:
            text_config = CONFIG_MAPPING["llama"]()

        assert audio_adapter_type in [
            "stack",
            "dual_ffn",
            "dual_ffn_fast_forward",
        ], f"Invalid audio adapter type: {audio_adapter_type}"
        if audio_adapter_type.startswith("dual_ffn"):
            assert audio_dual_ffn_layers is not None, (
                "audio_dual_ffn_layers must be specified when using dual_ffn adapter."
            )
        self.text_config = text_config
        self.audio_encoder_config = audio_encoder_config
        self.audio_tokenizer_config = audio_tokenizer_config
        self.audio_adapter_type = audio_adapter_type
        self.audio_embed_avg = audio_embed_avg
        self.audio_ffn_hidden_size = audio_ffn_hidden_size
        self.audio_ffn_intermediate_size = audio_ffn_intermediate_size
        self.audio_dual_ffn_layers = audio_dual_ffn_layers
        self.audio_decoder_proj_num_layers = audio_decoder_proj_num_layers
        self.encode_whisper_embed = encode_whisper_embed
        self.encode_audio_in_tokens = encode_audio_in_tokens
        self.use_delay_pattern = use_delay_pattern
        self.skip_audio_tower = skip_audio_tower
        self.use_audio_out_embed_projector = use_audio_out_embed_projector
        self.use_audio_out_self_attention = use_audio_out_self_attention

        self.use_rq_transformer = use_rq_transformer

        if self.use_rq_transformer:
            assert not self.use_delay_pattern, "Delay pattern is not supported if you turned on RQ-Transformer!"
        self.rq_transformer_hidden_size = rq_transformer_hidden_size
        self.rq_transformer_intermediate_size = rq_transformer_intermediate_size
        self.rq_transformer_num_attention_heads = rq_transformer_num_attention_heads
        self.rq_transformer_num_key_value_heads = rq_transformer_num_key_value_heads
        self.rq_transformer_num_hidden_layers = rq_transformer_num_hidden_layers

        if use_rq_transformer:
            # For RQ-Transformer, we set the hidden_size to the same as the text model's hidden size if it is not specified.
            if self.rq_transformer_hidden_size is None:
                self.rq_transformer_hidden_size = text_config.hidden_size
                assert self.rq_transformer_hidden_size % 128 == 0
            if self.rq_transformer_intermediate_size is None:
                self.rq_transformer_intermediate_size = text_config.intermediate_size
            if self.rq_transformer_num_attention_heads is None:
                self.rq_transformer_num_attention_heads = self.rq_transformer_hidden_size // 128
            if self.rq_transformer_num_key_value_heads is None:
                self.rq_transformer_num_key_value_heads = self.rq_transformer_hidden_size // 128 // 4
            assert self.rq_transformer_hidden_size % self.rq_transformer_num_attention_heads == 0
            assert self.rq_transformer_hidden_size % self.rq_transformer_num_key_value_heads == 0

        self.audio_num_codebooks = audio_num_codebooks
        self.audio_codebook_size = audio_codebook_size
        self.audio_bos_token = audio_bos_token
        self.audio_eos_token = audio_eos_token
        self.audio_out_bos_token = audio_out_bos_token
        self.audio_in_token = audio_in_token
        self.audio_out_token = audio_out_token
        self.audio_in_token_idx = audio_in_token_idx
        self.audio_out_token_idx = audio_out_token_idx
        self.audio_stream_bos_id = audio_stream_bos_id
        self.audio_stream_eos_id = audio_stream_eos_id
        self.audio_out_bos_token_id = audio_out_bos_token_id
        self.audio_eos_token_id = audio_eos_token_id

        super().__init__(**kwargs)
        self.pad_token_id = pad_token_id

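# Illustrative sketch (added for clarity, not used by the serving code below):
# the dual-FFN adapter types require explicit layer indices, which the
# assertion in __init__ enforces. The layer indices here are hypothetical.
def _demo_dual_ffn_config() -> HiggsAudioConfig:
    return HiggsAudioConfig(
        audio_adapter_type="dual_ffn",
        audio_dual_ffn_layers=[0, 2, 4],  # hypothetical layer indices
    )
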
from model import HiggsAudioModel
from preprocess import HiggsAudioSampleCollator, ChatMLSample, ChatMLDatasetSample, prepare_chatml_sample, Message

def load_higgs_audio_tokenizer(tokenizer_name_or_path, device="cuda"):
    """Load the Higgs-Audio tokenizer from a local path or a Hugging Face Hub repo."""
    is_local = os.path.exists(tokenizer_name_or_path)
    if not is_local:
        tokenizer_path = snapshot_download(tokenizer_name_or_path)
    else:
        tokenizer_path = tokenizer_name_or_path
    config_path = os.path.join(tokenizer_path, "config.json")
    model_path = os.path.join(tokenizer_path, "model.pth")
    with open(config_path) as f:
        config = json.load(f)
    model = HiggsAudioTokenizer(
        **config,
        device=device,
    )
    parameter_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(parameter_dict, strict=False)
    model.to(device)
    model.eval()
    return model

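# Round-trip sketch (hypothetical helper, relying only on the tokenizer calls
# used elsewhere in this file: encode(wav, sr) returns (1, num_codebooks, T)
# codes, and decode(codes) returns a batch of waveforms):
def _roundtrip_audio(tokenizer: HiggsAudioTokenizer, wav_path: str) -> np.ndarray:
    wav, _ = librosa.load(wav_path, sr=tokenizer.sampling_rate)
    codes = tokenizer.encode(wav, tokenizer.sampling_rate)
    return tokenizer.decode(codes)[0, 0]  # first (and only) waveform in the batch
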
def revert_delay_pattern(data):
    """Convert samples encoded with delay pattern back to the original form.

    Args:
        data (:obj:`torch.Tensor`):
            The data with delay pattern applied. It will have shape (num_codebooks, seq_len + num_codebooks - 1).

    Returns:
        ret (:obj:`torch.Tensor`):
            Recovered data with delay pattern removed. It will have shape (num_codebooks, seq_len).
    """
    assert len(data.shape) == 2
    out_l = []
    num_codebooks = data.shape[0]
    for i in range(num_codebooks):
        # Codebook i was shifted right by i steps; slice its valid window back out.
        out_l.append(data[i : (i + 1), i : (data.shape[1] - num_codebooks + 1 + i)])
    return torch.cat(out_l, dim=0)

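# Self-check sketch (illustrative): build a delayed tensor by hand, where
# codebook i is shifted right by i steps and `pad` fills the gaps, then verify
# that revert_delay_pattern recovers the original codes.
def _demo_revert_delay_pattern():
    num_codebooks, seq_len, pad = 3, 4, 0
    codes = torch.arange(num_codebooks * seq_len).reshape(num_codebooks, seq_len)
    delayed = torch.full((num_codebooks, seq_len + num_codebooks - 1), pad)
    for i in range(num_codebooks):
        delayed[i, i : i + seq_len] = codes[i]
    assert torch.equal(revert_delay_pattern(delayed), codes)
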

@dataclass
class HiggsAudioStreamerDelta:
    """Represents a chunk of generated content, either text or audio tokens."""

    text: Optional[str] = None
    text_tokens: Optional[torch.Tensor] = None
    audio_tokens: Optional[torch.Tensor] = None
    finish_reason: Optional[str] = None


@dataclass
class HiggsAudioResponse:
    """Final result of a generation request: decoded audio plus token-level outputs."""

    audio: Optional[np.ndarray] = None
    generated_audio_tokens: Optional[np.ndarray] = None
    sampling_rate: Optional[int] = None
    generated_text: str = ""
    generated_text_tokens: Optional[np.ndarray] = None
    usage: Optional[dict] = None

class HiggsAudioServeEngine:
    def __init__(
        self,
        model_name_or_path: str,
        audio_tokenizer_name_or_path: str,
        tokenizer_name_or_path: Optional[str] = None,
        device: str = "cuda",
        torch_dtype: Union[torch.dtype, str] = torch.float16,
        kv_cache_lengths: List[int] = [1024, 4096, 8192],  # Multiple KV cache sizes
    ):
        self.device = device
        self.model_name_or_path = model_name_or_path
        self.torch_dtype = torch_dtype
        torch.set_default_device(device)

        # Initialize model and tokenizer
        config = HiggsAudioConfig.from_pretrained(model_name_or_path)

        # Optional model-shrinking overrides (disabled):
        # config.num_hidden_layers = config.num_hidden_layers // 2
        # ---- Audio Config ----
        # config.audio_dual_ffn_layers = config.audio_dual_ffn_layers[:12]
        # config.audio_ffn_hidden_size //= 2  # 3072 → 1536
        # config.audio_ffn_intermediate_size //= 2  # 8192 → 4096
        # ---- Text Config ----
        # config.text_config.hidden_size //= 2  # 3072 → 1536
        # config.text_config.intermediate_size //= 2  # 8192 → 4096
        # config.text_config.num_attention_heads //= 2  # 24 → 12
        # config.text_config.num_key_value_heads //= 2  # 8 → 4
        # config.text_config.num_hidden_layers //= 2  # 28 → 14
        # config.text_config.head_dim //= 2  # 128 → 64
        # ---- Shared ----
        # config.hidden_size //= 2  # 3072 → 1536

        self.model = HiggsAudioModel.from_pretrained(
            model_name_or_path,
            torch_dtype=torch_dtype,
            config=config,
        )
        logger.debug(f"Model device: {self.model.device}")
        logger.info(f"Loaded model from {model_name_or_path}, dtype: {self.model.dtype}")

        if tokenizer_name_or_path is None:
            tokenizer_name_or_path = model_name_or_path
        logger.info(f"Loading tokenizer from {tokenizer_name_or_path}")
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)

        logger.info("Initializing Higgs Audio Tokenizer")
        self.audio_tokenizer = load_higgs_audio_tokenizer(audio_tokenizer_name_or_path, device=device)

        self.audio_num_codebooks = self.model.config.audio_num_codebooks
        self.audio_codebook_size = self.model.config.audio_codebook_size
        self.audio_tokenizer_tps = self.audio_tokenizer.tps
        self.samples_per_token = int(self.audio_tokenizer.sampling_rate // self.audio_tokenizer_tps)
        self.hamming_window_len = 2 * self.audio_num_codebooks * self.samples_per_token

        # Prepare KV caches for different lengths
        cache_config = deepcopy(self.model.config.text_config)
        if self.model.config.audio_dual_ffn_layers:
            # Account for the extra dual-FFN adapter layers in the cache.
            cache_config.num_hidden_layers += len(self.model.config.audio_dual_ffn_layers)
        # One pre-allocated static KV cache per supported length.
        self.kv_caches = {
            length: StaticCache(
                config=cache_config,
                max_batch_size=1,
                max_cache_len=length,
                device=self.model.device,
                dtype=self.model.dtype,
            )
            for length in sorted(kv_cache_lengths)
        }

        # Reuse collator to prepare inference samples
        self.collator = HiggsAudioSampleCollator(
            audio_in_token_id=self.model.config.audio_in_token_idx,
            audio_out_token_id=self.model.config.audio_out_token_idx,
            audio_stream_bos_id=self.model.config.audio_stream_bos_id,
            audio_stream_eos_id=self.model.config.audio_stream_eos_id,
            pad_token_id=self.model.config.pad_token_id,
            return_audio_in_tokens=False,
            use_delay_pattern=self.model.config.use_delay_pattern,
            audio_num_codebooks=self.model.config.audio_num_codebooks,
            round_to=1,
        )

        # Capture CUDA graphs for each KV cache length (disabled):
        # if device == "cuda":
        #     logger.info("Capturing CUDA graphs for each KV cache length")
        #     self.model.capture_model(self.kv_caches.values())

    def _prepare_inputs(self, chat_ml_sample: ChatMLSample, force_audio_gen: bool = False):
        input_tokens, audio_contents, _ = prepare_chatml_sample(
            chat_ml_sample,
            self.tokenizer,
        )

        postfix = "<|start_header_id|>assistant<|end_header_id|>\n\n"
        if force_audio_gen:
            postfix += "<|audio_out_bos|>"
        postfix = self.tokenizer.encode(postfix, add_special_tokens=False)
        input_tokens.extend(postfix)

        # Configure the audio inputs
        audio_ids_l = []
        for audio_content in audio_contents:
            if audio_content.audio_url not in ["placeholder", ""]:
                raw_audio, _ = librosa.load(audio_content.audio_url, sr=self.audio_tokenizer.sampling_rate)
            elif audio_content.raw_audio is not None:
                raw_audio, _ = librosa.load(
                    BytesIO(base64.b64decode(audio_content.raw_audio)), sr=self.audio_tokenizer.sampling_rate
                )
            else:
                raw_audio = None

            if raw_audio is not None:
                audio_ids = self.audio_tokenizer.encode(raw_audio, self.audio_tokenizer.sampling_rate)
                audio_ids_l.append(audio_ids.squeeze(0).cpu())

        if len(audio_ids_l) > 0:
            # Start offset of each clip inside the concatenated code stream.
            audio_ids_start = torch.tensor(
                np.cumsum(np.array([0] + [audio_ids.shape[1] for audio_ids in audio_ids_l])),
                dtype=torch.long,
                device=self.device,
            )[0:-1]
            audio_ids_concat = torch.cat(audio_ids_l, dim=1)
        else:
            audio_ids_start = None
            audio_ids_concat = None

        sample = ChatMLDatasetSample(
            input_ids=torch.LongTensor(input_tokens),
            label_ids=None,
            audio_ids_concat=audio_ids_concat,
            audio_ids_start=audio_ids_start,
            audio_waveforms_concat=None,
            audio_waveforms_start=None,
            audio_sample_rate=None,
            audio_speaker_indices=None,
        )
        data = self.collator([sample])
        inputs = asdict(data)
        for k, v in inputs.items():
            if isinstance(v, torch.Tensor):
                inputs[k] = v.to(self.model.device)

        return inputs

    def _prepare_kv_caches(self):
        # Reset every cache bucket before a new request.
        for kv_cache in self.kv_caches.values():
            kv_cache.reset()

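    # Bucket-selection sketch (hypothetical helper, not called by the original
    # code): auto_sample receives the whole `past_key_values_buckets` dict and
    # may choose differently; this only illustrates the smallest-fit idea.
    def _pick_kv_cache(self, needed_len: int) -> StaticCache:
        for length in sorted(self.kv_caches):
            if length >= needed_len:
                return self.kv_caches[length]
        # Fall back to the largest bucket if the request exceeds every size.
        return self.kv_caches[max(self.kv_caches)]
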
    def generate(
        self,
        chat_ml_sample: ChatMLSample,
        max_new_tokens: int,
        temperature: float = 0.7,
        top_k: Optional[int] = None,
        top_p: float = 0.95,
        stop_strings: Optional[List[str]] = None,
        force_audio_gen: bool = False,
        ras_win_len: Optional[int] = 7,
        ras_win_max_num_repeat: int = 2,
        seed: Optional[int] = None,
    ):
        # Default stop strings
        if stop_strings is None:
            stop_strings = ["<|end_of_text|>", "<|eot_id|>"]
        if ras_win_len is not None and ras_win_len <= 0:
            ras_win_len = None

        with torch.no_grad():
            inputs = self._prepare_inputs(chat_ml_sample, force_audio_gen=force_audio_gen)
            prompt_token_ids = inputs["input_ids"][0].cpu().numpy()

            self._prepare_kv_caches()
            from autoregressive_sampling import auto_sample

            outputs = auto_sample(
                self.model,
                **inputs,
                max_new_tokens=max_new_tokens,
                use_cache=True,
                stop_strings=stop_strings,
                tokenizer=self.tokenizer,
                do_sample=temperature != 0.0,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                past_key_values_buckets=self.kv_caches,
                ras_win_len=ras_win_len,
                ras_win_max_num_repeat=ras_win_max_num_repeat,
                seed=seed,
            )

            if len(outputs) > 0:
                wv_list = []
                for output_audio in outputs:
                    # Undo the delay pattern, clamp to the codebook range, and
                    # drop the stream BOS/EOS positions before decoding.
                    vq_code = revert_delay_pattern(output_audio).clip(0, self.audio_codebook_size - 1)[:, 1:-1]
                    wv_list.append(self.audio_tokenizer.decode(vq_code.unsqueeze(0))[0, 0])
                wv_numpy = torch.cat(wv_list).cpu().numpy()
            else:
                wv_numpy = None

            # We only support one request at a time for now.
            # generated_text_tokens = outputs[0][0].cpu().numpy()[len(prompt_token_ids):]
            # generated_text = self.tokenizer.decode(generated_text_tokens)
            generated_audio_tokens = outputs[0].cpu().numpy()
            return HiggsAudioResponse(
                audio=wv_numpy,
                generated_audio_tokens=generated_audio_tokens,
                sampling_rate=self.audio_tokenizer.sampling_rate,
                # generated_text=generated_text,
                # generated_text_tokens=generated_text_tokens,
                usage={
                    "prompt_tokens": prompt_token_ids.shape[0],
                    # "completion_tokens": generated_text_tokens.shape[0] + generated_audio_tokens.shape[1],
                    "total_tokens": prompt_token_ids.shape[0],  # + generated text/audio tokens
                    "cached_tokens": 0,
                },
            )

def get_zero_shot_input_sample():
    system_prompt = (
        "Generate audio following instruction.\n\n<|scene_desc_start|>\nSPEAKER0: british accent\n<|scene_desc_end|>"
    )

    messages = [
        Message(
            role="system",
            content=system_prompt,
        ),
        Message(
            role="user",
            content="Hey, everyone! Welcome back to Tech Talk Tuesdays.\n"
            "It's your host, Alex, and today, we're diving into a topic that's become absolutely crucial in the tech world — deep learning.\n"
            "And let's be honest, if you've been even remotely connected to tech, AI, or machine learning lately, you know that deep learning is everywhere.",
        ),
    ]
    chat_ml_sample = ChatMLSample(messages=messages)
    return chat_ml_sample

def _update_model_kwargs_for_generation(
    self,
    outputs,
    model_kwargs: dict[str, Any],
    num_new_tokens: int = 1,
) -> dict[str, Any]:
    """Advance `model_kwargs` by `num_new_tokens` after a decoding step.

    Closely mirrors transformers' `GenerationMixin._update_model_kwargs_for_generation`.
    """
    # past_key_values is the standard name for the KV cache.
    model_kwargs["past_key_values"] = outputs.past_key_values

    # TODO: consider removing token_type_ids.
    # Update token_type_ids with the last value.
    if "token_type_ids" in model_kwargs:
        token_type_ids = model_kwargs["token_type_ids"]
        model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)

    # Extend the attention mask by one position.
    if "attention_mask" in model_kwargs:
        attention_mask = model_kwargs["attention_mask"]
        model_kwargs["attention_mask"] = torch.cat(
            [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
        )

    if model_kwargs.get("use_cache", True):
        # With a cache, only the latest position is tracked.
        model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
    else:
        # Without a cache, append the new absolute positions.
        past_positions = model_kwargs.pop("cache_position")
        new_positions = torch.arange(
            past_positions[-1] + 1, past_positions[-1] + num_new_tokens + 1, dtype=past_positions.dtype
        ).to(past_positions.device)
        model_kwargs["cache_position"] = torch.cat((past_positions, new_positions))
    return model_kwargs

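# Positional bookkeeping sketch (illustrative): with use_cache=True the
# length-1 cache_position advances by num_new_tokens; without a cache the new
# absolute positions are appended instead, as in the function above.
def _demo_cache_position_update():
    pos = torch.tensor([5])
    assert torch.equal(pos[-1:] + 1, torch.tensor([6]))  # cached path, one new token
    new = torch.arange(pos[-1] + 1, pos[-1] + 3, dtype=pos.dtype)
    assert torch.equal(torch.cat((pos, new)), torch.tensor([5, 6, 7]))  # uncached path
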

MODEL_PATH = "bosonai/higgs-audio-v2-generation-3B-base"
AUDIO_TOKENIZER_PATH = "bosonai/higgs-audio-v2-tokenizer"


def main():
    input_sample = get_zero_shot_input_sample()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Using device: {device}")
    # Seed the CPU and all CUDA devices for reproducibility.
    torch.manual_seed(2025)
    torch.cuda.manual_seed_all(2025)
    serve_engine = HiggsAudioServeEngine(
        MODEL_PATH,
        AUDIO_TOKENIZER_PATH,
        device=device,
    )
    import time

    logger.info("Starting generation...")
    start_time = time.time()
    output: HiggsAudioResponse = serve_engine.generate(
        chat_ml_sample=input_sample,
        max_new_tokens=1024,
        top_k=50,
        stop_strings=["<|end_of_text|>", "<|eot_id|>"],
    )
    elapsed_time = time.time() - start_time
    logger.info(f"Generation time: {elapsed_time:.2f} seconds")
    print(output)
    import torchaudio

    torchaudio.save("output_.wav", torch.from_numpy(output.audio)[None, :], output.sampling_rate)
    logger.info(f"Generated text:\n{output.generated_text}")
    logger.info("Saved audio to output_.wav")


if __name__ == "__main__":
    main()