ComfyUI/comfy/ldm/higgsv2/preprocess.py
from typing import (
    Dict, List, Optional, Union, Tuple, MutableMapping, Any, Mapping, Collection, get_type_hints, get_args, get_origin
)
from dataclasses import dataclass, fields, _FIELDS, _FIELD, _FIELD_INITVAR, is_dataclass
import torch.nn.functional as F
import numpy as np
import torch
import math
from functools import lru_cache
__MAX_SIZE = 2048
def _ceil_to_nearest(n, round_to):
    return (n + round_to - 1) // round_to * round_to
@torch.jit.script
def build_delay_pattern_mask(
    input_ids: torch.Tensor,
    bos_token_id: int,
    pad_token_id: int,
):
    bsz, num_codebooks, seq_len = input_ids.shape
    new_seq_len = seq_len + num_codebooks - 1
    input_ids_with_gen_mask = torch.ones((bsz, num_codebooks, new_seq_len), dtype=torch.long, device=input_ids.device)
    bos_mask = torch.tril(input_ids_with_gen_mask, -1) > 0
    eos_mask = torch.triu(input_ids_with_gen_mask, seq_len) > 0
    input_ids_with_gen_mask[bos_mask] = bos_token_id
    input_ids_with_gen_mask[(~bos_mask) & (~eos_mask)] = input_ids.reshape(-1)
    input_ids = input_ids_with_gen_mask.clone()
    input_ids[eos_mask] = pad_token_id
    input_ids_with_gen_mask[eos_mask] = -1
    return input_ids, input_ids_with_gen_mask
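# Illustration (not part of the original file): for a batch of one with
# num_codebooks=2, tokens [[a, b, c], [d, e, f]], bos_token_id=B and pad_token_id=P,
# each codebook k is shifted right by k positions:
#     input_ids              = [[a, b, c, P],
#                               [B, d, e, f]]
# input_ids_with_gen_mask is identical except the padded slots hold -1 instead of P.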
# A re-implementation of dacite's from_dict with only the parts needed for ChatML parsing.
@lru_cache(maxsize=None)
def cache(function):
    # Memoize the wrapper itself so repeated cache(fn) calls reuse one lru_cache'd wrapper.
    return lru_cache(maxsize=__MAX_SIZE, typed=True)(function)
@cache
def get_fields(data_class):
    fields = getattr(data_class, _FIELDS)
    return [f for f in fields.values() if f._field_type is _FIELD or f._field_type is _FIELD_INITVAR]
def is_optional(type_) -> bool:
    return get_origin(type_) is Union and type(None) in get_args(type_)
def orig(data_class) -> Any:
    if is_dataclass(data_class):
        return data_class
    return get_origin(data_class)
@cache
def extract_generic(type_, defaults: Tuple = ()) -> tuple:
    try:
        if getattr(type_, "_special", False):
            return defaults
        if type_.__args__ == ():
            return (type_.__args__,)
        return type_.__args__ or defaults  # type: ignore
    except AttributeError:
        return defaults
def _build_value_for_collection(collection, data: Any) -> Any:
    if isinstance(data, Mapping):
        value_type = extract_generic(collection, defaults=(Any, Any))[1]
        return {
            key: _build_value(type_=value_type, data=value)
            for key, value in data.items()
        }
    elif isinstance(data, Collection) and not isinstance(data, (str, bytes, Mapping)):
        item_type = extract_generic(collection, defaults=(Any,))[0]
        return [
            _build_value(type_=item_type, data=item)
            for item in data
        ]
    return data
def _build_value(type_, data) -> Any:
    if is_optional(type_) and data is None:
        return data
    if get_origin(type_) is Union:
        data = _build_value_for_union(union=type_, data=data)
    elif hasattr(type_, "__origin__"):
        data = _build_value_for_collection(collection=type_, data=data)
    elif cache(is_dataclass)(orig(type_)) and isinstance(data, Mapping):
        data = from_dict(data_class=type_, data=data)
    return data
def _build_value_for_union(union: type, data: Any) -> Any:
    for inner_type in get_args(union):
        if data is None and inner_type is type(None):
            return None
        try:
            return _build_value(inner_type, data)
        except Exception:
            continue
    raise ValueError(f"Cannot match {data!r} to any type in {union}")
def is_instance(value: Any, type_) -> bool:
    if type_ is Any:
        return True
    origin = get_origin(type_)
    args = get_args(type_)
    if origin is Union:
        return any(is_instance(value, arg) for arg in args)
    if origin in (list, List):
        if not isinstance(value, list):
            return False
        (elem_type,) = args or (Any,)
        return all(is_instance(item, elem_type) for item in value)
    if origin in (dict, Dict, Mapping):
        if not isinstance(value, dict):
            return False
        key_type, val_type = args or (Any, Any)
        return all(
            is_instance(k, key_type) and is_instance(v, val_type)
            for k, v in value.items()
        )
    try:
        return isinstance(value, type_)
    except TypeError:
        return False
def from_dict(data_class, data):
    init_values: MutableMapping[str, Any] = {}
    post_init_values: MutableMapping[str, Any] = {}
    data_class_hints = get_type_hints(data_class)
    data_class_fields = cache(get_fields)(data_class)
    extra_fields = set(data.keys()) - {f.name for f in data_class_fields}
    if extra_fields:
        formatted_keys = ", ".join(f'"{key}"' for key in extra_fields)
        raise ValueError(f"cannot match {formatted_keys} to any data class field")
    for field in data_class_fields:
        field_type = data_class_hints[field.name]
        key = field.name
        if key in data:
            try:
                value = _build_value(type_=field_type, data=data[key])
            except Exception as error:
                raise ValueError(error)
            if not is_instance(value, field_type):
                raise ValueError((
                    f'wrong value type for field "{field.name}" - should be "{field_type}" '
                    f'instead of value "{value}" of type "{type(value)}"'
                ))
            init_values[field.name] = value
    instance = data_class(**init_values)
    for key, value in post_init_values.items():
        setattr(instance, key, value)
    return instance
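# Illustration (not part of the original file), using the ChatML dataclasses defined below:
#     from_dict(Message, {"role": "user", "content": "hi"})
# returns Message(role="user", content="hi", recipient=None); unknown keys raise a
# ValueError, and nested mappings are recursed into when the field type is a dataclass.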
def normalize_chinese_punctuation(text):
"""
Convert Chinese (full-width) punctuation marks to English (half-width) equivalents.
"""
# Mapping of Chinese punctuation to English punctuation
chinese_to_english_punct = {
"": ", ", # comma
"": ".", # period
"": ":", # colon
"": ";", # semicolon
"": "?", # question mark
"": "!", # exclamation mark
"": "(", # left parenthesis
"": ")", # right parenthesis
"": "[", # left square bracket
"": "]", # right square bracket
"": "<", # left angle quote
"": ">", # right angle quote
"": '"', # left double quotation
"": '"', # right double quotation
"": "'", # left single quotation
"": "'", # right single quotation
"": ",", # enumeration comma
"": "-", # em dash
"": "...", # ellipsis
"·": ".", # middle dot
"": '"', # left corner bracket
"": '"', # right corner bracket
"": '"', # left double corner bracket
"": '"', # right double corner bracket
}
# Replace each Chinese punctuation with its English counterpart
for zh_punct, en_punct in chinese_to_english_punct.items():
text = text.replace(zh_punct, en_punct)
return text
def transcript_normalize(text: str):
    transcript = normalize_chinese_punctuation(text)
    transcript = transcript.replace("(", " ")
    transcript = transcript.replace(")", " ")
    transcript = transcript.replace("°F", " degrees Fahrenheit")
    transcript = transcript.replace("°C", " degrees Celsius")
    for tag, replacement in [
        ("[laugh]", "<SE>[Laughter]</SE>"),
        ("[humming start]", "<SE_s>[Humming]</SE_s>"),
        ("[humming end]", "<SE_e>[Humming]</SE_e>"),
        ("[music start]", "<SE_s>[Music]</SE_s>"),
        ("[music end]", "<SE_e>[Music]</SE_e>"),
        ("[music]", "<SE>[Music]</SE>"),
        ("[sing start]", "<SE_s>[Singing]</SE_s>"),
        ("[sing end]", "<SE_e>[Singing]</SE_e>"),
        ("[applause]", "<SE>[Applause]</SE>"),
        ("[cheering]", "<SE>[Cheering]</SE>"),
        ("[cough]", "<SE>[Cough]</SE>"),
    ]:
        transcript = transcript.replace(tag, replacement)
    lines = transcript.split("\n")
    transcript = "\n".join([" ".join(line.split()) for line in lines if line.strip()])
    transcript = transcript.strip()
    if not any([transcript.endswith(c) for c in [".", "!", "?", ",", ";", '"', "'", "</SE_e>", "</SE>"]]):
        transcript += "."
    return transcript
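# Illustration (not part of the original file):
#     transcript_normalize("It was 30°C [laugh]")
#     -> "It was 30 degrees Celsius <SE>[Laughter]</SE>"
# A trailing period is only appended when the text does not already end with
# punctuation or a sound-effect tag.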
@dataclass
class AudioContent:
    audio_url: str
    raw_audio: Optional[str] = None
    offset: Optional[float] = None
    duration: Optional[float] = None
    row_id: Optional[int] = None
    type: str = "audio"
@dataclass
class TextContent:
    text: str
    type: str = "text"
@dataclass
class Message:
    role: str
    content: Union[str, AudioContent, TextContent, List[Union[str, AudioContent, TextContent]]]
    recipient: Optional[str] = None
@dataclass
class ChatMLSample:
    messages: List[Message]
    start_index: Optional[int] = None
    misc: Optional[Dict] = None
    speaker: Optional[str] = None
def prepare_chatml_sample(sample: Union[ChatMLSample, Dict], tokenizer):
    try:
        if not isinstance(sample, ChatMLSample):
            # replacing pd.isna
            def is_nan(x):
                if isinstance(x, float):
                    return math.isnan(x)
                if isinstance(x, np.generic):
                    return np.isnan(x)
                if isinstance(x, torch.Tensor) and x.numel() == 1:
                    return torch.isnan(x).item()
                return False

            if "speaker" in sample and is_nan(sample["speaker"]):
                sample["speaker"] = None
            if "start_index" in sample and is_nan(sample["start_index"]):
                sample["start_index"] = None
            if "content" in sample and is_nan(sample["content"]):
                sample["content"] = ""

            def convert_nan_to_none(obj):
                if isinstance(obj, np.ndarray):
                    return obj.tolist()
                elif isinstance(obj, float) and math.isnan(obj):
                    return None
                elif isinstance(obj, dict):
                    return {k: convert_nan_to_none(v) for k, v in obj.items()}
                elif isinstance(obj, (list, tuple)):
                    return [convert_nan_to_none(item) for item in obj]
                return obj

            clean_sample = convert_nan_to_none(sample)
            val_keys = []
            for field in fields(ChatMLSample):
                if field.name in clean_sample:
                    val_keys.append(field.name)
            clean_sample = {k: clean_sample[k] for k in val_keys}
            sample = from_dict(
                data_class=ChatMLSample, data=clean_sample,
            )

        input_tokens = []
        audio_contents = []
        speaker_id = None
        if sample.speaker is not None:
            speaker_id = sample.speaker
        elif sample.misc is not None:
            if "speaker" in sample.misc:
                speaker_id = sample.misc["speaker"]

        total_m = len(sample.messages)
        for turn_id, message in enumerate(sample.messages):
            role = message.role
            recipient = message.recipient
            content = message.content
            content_l = []
            if isinstance(content, str):
                content_l.append(TextContent(text=content))
            elif isinstance(content, TextContent):
                content_l.append(content)
            elif isinstance(content, AudioContent):
                content_l.append(content)
            elif isinstance(content, list):
                for ele in content:
                    if isinstance(ele, str):
                        content_l.append(TextContent(text=ele))
                    else:
                        content_l.append(ele)
            if turn_id == 0:
                prefix = f"<|begin_of_text|><|start_header_id|>{role}<|end_header_id|>\n\n"
            else:
                prefix = f"<|start_header_id|>{role}<|end_header_id|>\n\n"
            eot_postfix = "<|eot_id|>"
            eom_postfix = "<|eom_id|>"
            prefix_tokens = tokenizer.encode(prefix, add_special_tokens=False)
            input_tokens.extend(prefix_tokens)
            if recipient:
                assert role == "assistant", "Recipient is only available for assistant role."
                recipient_tokens = tokenizer.encode(f"{recipient}<|recipient|>", add_special_tokens=False)
                input_tokens.extend(recipient_tokens)
            for content in content_l:
                if content.type == "text":
                    text_tokens = tokenizer.encode(content.text, add_special_tokens=False)
                    input_tokens.extend(text_tokens)
                elif content.type == "audio":
                    audio_contents.append(content)
                    if role == "user" or role == "system":
                        text_tokens = tokenizer.encode(
                            "<|audio_bos|><|AUDIO|><|audio_eos|>",
                            add_special_tokens=False,
                        )
                        input_tokens.extend(text_tokens)
                    elif role == "assistant":
                        text_tokens = tokenizer.encode(
                            "<|audio_out_bos|><|AUDIO_OUT|><|audio_eos|>",
                            add_special_tokens=False,
                        )
                        input_tokens.extend(text_tokens)
            next_id = turn_id + 1
            if role == "assistant" and next_id != total_m and sample.messages[next_id].role == "assistant":
                postfix_tokens = tokenizer.encode(eom_postfix, add_special_tokens=False)
                input_tokens.extend(postfix_tokens)
            else:
                postfix_tokens = tokenizer.encode(eot_postfix, add_special_tokens=False)
                input_tokens.extend(postfix_tokens)
        return input_tokens, audio_contents, speaker_id
    except Exception:
        return None, None, None
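# Illustration (not part of the original file): a minimal dict input such as
#     {"messages": [{"role": "user", "content": "Please read this sentence."}]}
# is parsed into a ChatMLSample and tokenized into (input_tokens, audio_contents, speaker_id);
# any parsing or tokenization error is swallowed and (None, None, None) is returned instead.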
@dataclass
class HiggsAudioBatchInput:
    input_ids: torch.LongTensor  # shape (bsz, seq_len).
    attention_mask: torch.Tensor  # shape (bsz, seq_len).
    audio_out_ids: Optional[torch.LongTensor]  # shape (num_codebooks, audio_out_total_length)
    audio_out_ids_start: Optional[torch.LongTensor]  # shape (num_audio_out,)
    audio_out_ids_start_group_loc: Optional[torch.LongTensor]  # shape (num_audio_out,), specifies which sample in the batch each audio-out segment belongs to
    audio_in_ids: Optional[torch.LongTensor]  # shape (num_codebooks, audio_in_total_length)
    audio_in_ids_start: Optional[torch.LongTensor]  # shape (num_audio_in,)
    label_ids: Optional[torch.LongTensor]  # shape (bsz, seq_len)
    label_audio_ids: Optional[torch.LongTensor]  # shape (num_codebooks, audio_out_total_length)
    reward: Optional[float] = None
@dataclass
class ChatMLDatasetSample:
    input_ids: torch.LongTensor  # (seq_len,) Input text tokens
    label_ids: torch.LongTensor  # (seq_len,) Label IDs
    audio_ids_concat: torch.LongTensor  # (num_codebooks, audio_seq_len) Concatenated audio tokens
    audio_ids_start: torch.LongTensor  # (num_audios,) Start index of each audio segment in `audio_ids_concat`
    audio_waveforms_concat: torch.Tensor  # (total_wv_length,) Concatenated audio waveforms
    audio_waveforms_start: torch.LongTensor  # (num_audios,) Start index of each waveform in `audio_waveforms_concat`
    audio_sample_rate: torch.Tensor  # (num_audios,) Sampling rate per audio waveform
    audio_speaker_indices: torch.LongTensor  # (num_audios,) Speaker indices per audio; -1 = unknown
    audio_label_ids_concat: Optional[torch.LongTensor] = None  # (num_codebooks, audio_seq_len) Optional audio token labels
    reward: Optional[float] = None  # Optional scalar reward

    def num_audios(self):
        return max(len(self.audio_waveforms_start), len(self.audio_ids_start))

    def get_audio_codes(self, idx):
        code_start = self.audio_ids_start[idx]
        if idx < len(self.audio_ids_start) - 1:
            code_end = self.audio_ids_start[idx + 1]
        else:
            code_end = self.audio_ids_concat.shape[-1]
        return self.audio_ids_concat[:, code_start:code_end]

    def get_audio_codes_labels(self, idx):
        if self.audio_label_ids_concat is None:
            return None
        code_start = self.audio_ids_start[idx]
        if idx < len(self.audio_ids_start) - 1:
            code_end = self.audio_ids_start[idx + 1]
        else:
            code_end = self.audio_ids_concat.shape[-1]
        return self.audio_label_ids_concat[:, code_start:code_end]

    def get_wv(self, idx):
        wv_start = self.audio_waveforms_start[idx]
        sr = self.audio_sample_rate[idx]
        if idx < len(self.audio_waveforms_start) - 1:
            wv_end = self.audio_waveforms_start[idx + 1]
        else:
            wv_end = self.audio_waveforms_concat.shape[-1]
        return self.audio_waveforms_concat[wv_start:wv_end], sr
class HiggsAudioSampleCollator:
    def __init__(
        self,
        audio_in_token_id,
        audio_out_token_id,
        pad_token_id,
        audio_stream_bos_id,
        audio_stream_eos_id,
        round_to=8,
        pad_left=False,
        return_audio_in_tokens=True,
        audio_num_codebooks=None,
        use_delay_pattern=False,
        disable_audio_codes_transform=False,
        add_new_bos_eos_for_long_chunk=True,
        mask_audio_out_token_label=True,
    ):
        self.round_to = round_to
        self.pad_left = pad_left
        self.audio_in_token_id = audio_in_token_id
        self.audio_out_token_id = audio_out_token_id
        self.audio_stream_bos_id = audio_stream_bos_id
        self.audio_stream_eos_id = audio_stream_eos_id
        self.pad_token_id = pad_token_id
        self.return_audio_in_tokens = return_audio_in_tokens
        self.audio_num_codebooks = audio_num_codebooks
        self.use_delay_pattern = use_delay_pattern
        self.disable_audio_codes_transform = disable_audio_codes_transform
        self.add_new_bos_eos_for_long_chunk = add_new_bos_eos_for_long_chunk
        self.mask_audio_out_token_label = mask_audio_out_token_label

    def _process_and_duplicate_audio_tokens(
        self, input_ids: torch.Tensor, audio_idx: int, wv: torch.Tensor, labels: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, int]:
        # NOTE: relies on self.chunk_size_samples, which is not set in __init__ above;
        # it is expected to be assigned externally before this helper is used.
        total_samples = len(wv)
        num_chunks = math.ceil(total_samples / self.chunk_size_samples)
        if num_chunks <= 1:
            return input_ids, labels, 1
        audio_token_seq = input_ids[audio_idx - 1 : audio_idx + 2]
        duplicated_sequence = audio_token_seq.repeat(num_chunks)
        new_input_ids = torch.cat([input_ids[: audio_idx - 1], duplicated_sequence, input_ids[audio_idx + 2 :]])
        new_labels = None
        if labels is not None:
            label_seq = labels[audio_idx - 1 : audio_idx + 2]
            duplicated_labels = label_seq.repeat(num_chunks)
            new_labels = torch.cat([labels[: audio_idx - 1], duplicated_labels, labels[audio_idx + 2 :]])
        return new_input_ids, new_labels, num_chunks
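    # Illustration (not part of the original file): for a waveform long enough to be
    # split into num_chunks pieces, the 3-token window around position audio_idx
    # (i.e. <|audio_bos|><|AUDIO|><|audio_eos|>) is repeated num_chunks times so each
    # chunk gets its own audio placeholder in the text stream.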
    def __call__(self, batch: List[ChatMLDatasetSample]):
        label_ids = None
        label_audio_ids = None
        if all([ele.label_ids is None for ele in batch]):
            return_labels = False
        else:
            return_labels = True
        processed_batch = batch
        # Get the max sequence length based on processed batch
        max_seq_length = _ceil_to_nearest(max([len(sample.input_ids) for sample in processed_batch]), self.round_to)
        # Get the ids for audio-in and audio-out for each batch
        audio_in_ids_l = []
        audio_out_ids_l = []
        audio_out_ids_group_loc_l = []
        audio_in_label_ids_l = None
        audio_out_label_ids_l = None
        reward_l = []
        if return_labels:
            audio_out_no_train_flag = []  # Whether the audio-out data should be trained on or not.
        # Process the audio inputs and outputs
        for i in range(len(processed_batch)):
            audio_in_mask = processed_batch[i].input_ids == self.audio_in_token_id
            audio_out_mask = processed_batch[i].input_ids == self.audio_out_token_id
            audio_ids = torch.ones_like(processed_batch[i].input_ids)
            audio_ids[audio_in_mask ^ audio_out_mask] = torch.cumsum(audio_ids[audio_in_mask ^ audio_out_mask], 0) - 1
            audio_in_ids = audio_ids[audio_in_mask]
            audio_out_ids = audio_ids[audio_out_mask]
            if return_labels:
                audio_out_no_train_flag.append(processed_batch[i].label_ids[audio_out_mask] < 0)
                if self.mask_audio_out_token_label:
                    processed_batch[i].label_ids[audio_out_mask] = -100
            if self.return_audio_in_tokens:
                audio_in_ids_l.extend(
                    [processed_batch[i].get_audio_codes(idx)[: self.audio_num_codebooks, :] for idx in audio_in_ids]
                )
                if processed_batch[i].audio_label_ids_concat is not None:
                    if audio_in_label_ids_l is None:
                        audio_in_label_ids_l = []
                    audio_in_label_ids_l.extend(
                        [
                            processed_batch[i].get_audio_codes_labels(idx)[: self.audio_num_codebooks, :]
                            for idx in audio_in_ids
                        ]
                    )
            audio_out_ids_l.extend(
                [processed_batch[i].get_audio_codes(idx)[: self.audio_num_codebooks, :] for idx in audio_out_ids]
            )
            audio_out_ids_group_loc_l.append(i)
            if processed_batch[i].reward is not None:
                reward_l.append(processed_batch[i].reward)
            if processed_batch[i].audio_label_ids_concat is not None:
                if audio_out_label_ids_l is None:
                    audio_out_label_ids_l = []
                audio_out_label_ids_l.extend(
                    [
                        processed_batch[i].get_audio_codes_labels(idx)[: self.audio_num_codebooks, :]
                        for idx in audio_out_ids
                    ]
                )
        if return_labels:
            audio_out_no_train_flag = torch.cat(audio_out_no_train_flag, dim=0)
        if len(audio_in_ids_l) > 0:
            # Tried to remove the for-loop from the original implementation, but batching
            # with padding caused problems, so it is kept as list comprehensions instead.
            lengths = [seg.shape[1] for seg in audio_in_ids_l]
            aug_lengths = [length + 2 for length in lengths]
            audio_in_ids_start = torch.cumsum(
                torch.tensor([0] + aug_lengths[:-1], dtype=torch.long), dim=0
            )
            if self.disable_audio_codes_transform:
                audio_in_ids = torch.cat(audio_in_ids_l, dim=1).long()
            else:
                with_tokens = [
                    torch.cat([
                        torch.full((seg.shape[0], 1), self.audio_stream_bos_id, dtype=torch.long),
                        seg,
                        torch.full((seg.shape[0], 1), self.audio_stream_eos_id, dtype=torch.long),
                    ], dim=1)
                    for seg in audio_in_ids_l
                ]
                if self.use_delay_pattern:
                    with_tokens = [
                        build_delay_pattern_mask(
                            tok.unsqueeze(0),
                            bos_token_id=self.audio_stream_bos_id,
                            pad_token_id=self.audio_stream_eos_id,
                        )[0].squeeze(0)  # drop the batch dim, mirroring the audio-out path below
                        for tok in with_tokens
                    ]
                audio_in_ids = torch.cat(with_tokens, dim=1).long()
        else:
            audio_in_ids = torch.zeros((0, 0), dtype=torch.long)
            audio_in_ids_start = torch.zeros(0, dtype=torch.long)
        audio_out_ids_start_group_loc = None
        if len(audio_out_ids_l) > 0:
            new_audio_out_ids_l = []
            label_audio_ids_l = []
            for idx, ele in enumerate(audio_out_ids_l):
                if self.disable_audio_codes_transform:
                    audio_codes = ele
                    if return_labels:
                        label_audio_ids = audio_out_label_ids_l[idx]
                else:
                    audio_codes = torch.cat(
                        [
                            torch.full((ele.shape[0], 1), self.audio_stream_bos_id, dtype=torch.long),
                            ele,
                            torch.full((ele.shape[0], 1), self.audio_stream_eos_id, dtype=torch.long),
                        ],
                        dim=1,
                    )
                    if return_labels:
                        label_audio_ids = torch.cat(
                            [
                                torch.full((ele.shape[0], 1), -100, dtype=torch.long),
                                ele,
                                torch.full((ele.shape[0], 1), self.audio_stream_eos_id, dtype=torch.long),
                            ],
                            dim=1,
                        )
                    if self.use_delay_pattern:
                        audio_codes = build_delay_pattern_mask(
                            audio_codes.unsqueeze(0),
                            bos_token_id=self.audio_stream_bos_id,
                            pad_token_id=self.audio_stream_eos_id,
                        )[0].squeeze(0)
                        if return_labels:
                            label_audio_ids = build_delay_pattern_mask(
                                label_audio_ids.unsqueeze(0),
                                bos_token_id=-100,
                                pad_token_id=-100,
                            )[0].squeeze(0)
                new_audio_out_ids_l.append(audio_codes)
                if return_labels:
                    if audio_out_no_train_flag[idx]:
                        label_audio_ids[:] = -100
                    label_audio_ids_l.append(label_audio_ids)
            audio_out_ids = torch.cat(new_audio_out_ids_l, dim=1).long()
            if return_labels:
                label_audio_ids = torch.cat(label_audio_ids_l, dim=1).long()
            audio_out_ids_start = torch.cumsum(
                torch.tensor([0] + [audio_codes.shape[1] for audio_codes in new_audio_out_ids_l[:-1]]), dim=0
            )
            audio_out_ids_start_group_loc = torch.tensor(audio_out_ids_group_loc_l, dtype=torch.long)
        else:
            audio_out_ids = torch.zeros((0, 0), dtype=torch.long)
            audio_out_ids_start = torch.zeros(0, dtype=torch.long)
            if return_labels:
                label_audio_ids = torch.zeros((0, 0), dtype=torch.long)
        reward = torch.tensor(reward_l, dtype=torch.float32)

        # Cleaner and faster padding: pad every sequence to max_seq_length and stack.
        def pad_sequence(seq, length, pad_value):
            if self.pad_left:
                padding = (length - len(seq), 0)
            else:
                padding = (0, length - len(seq))
            return F.pad(seq, padding, value=pad_value)

        input_ids = torch.stack([
            pad_sequence(ele.input_ids, max_seq_length, self.pad_token_id)
            for ele in processed_batch
        ])
        if return_labels:
            label_ids = torch.stack([
                pad_sequence(ele.label_ids, max_seq_length, -100)
                for ele in processed_batch
            ])
        attention_mask = torch.stack([
            pad_sequence(torch.ones_like(ele.input_ids), max_seq_length, 0)
            for ele in processed_batch
        ])
        if not self.return_audio_in_tokens:
            audio_in_ids = None
            audio_in_ids_start = None
        if self.audio_num_codebooks is not None:
            if audio_in_ids is not None:
                audio_in_ids = audio_in_ids[: self.audio_num_codebooks]
            if audio_out_ids is not None:
                audio_out_ids = audio_out_ids[: self.audio_num_codebooks]
            if label_audio_ids is not None:
                label_audio_ids = label_audio_ids[: self.audio_num_codebooks]
        return HiggsAudioBatchInput(
            input_ids=input_ids,
            attention_mask=attention_mask,
            audio_out_ids=audio_out_ids,
            audio_out_ids_start=audio_out_ids_start,
            audio_out_ids_start_group_loc=audio_out_ids_start_group_loc,
            audio_in_ids=audio_in_ids,
            audio_in_ids_start=audio_in_ids_start,
            label_ids=label_ids,
            label_audio_ids=label_audio_ids,
            reward=reward,
        )
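# Minimal usage sketch (illustrative only; the token ids below are placeholders,
# not the real Higgs Audio vocabulary):
#     collator = HiggsAudioSampleCollator(
#         audio_in_token_id=128015,
#         audio_out_token_id=128016,
#         pad_token_id=128001,
#         audio_stream_bos_id=1024,
#         audio_stream_eos_id=1025,
#         use_delay_pattern=True,
#     )
#     batch_input = collator([chatml_dataset_sample])  # -> HiggsAudioBatchInput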