from typing import (
    Dict,
    List,
    Optional,
    Union,
    Tuple,
    MutableMapping,
    Any,
    Mapping,
    Collection,
    get_type_hints,
    get_args,
    get_origin,
)
from dataclasses import dataclass, fields, _FIELDS, _FIELD, _FIELD_INITVAR, is_dataclass
import torch.nn.functional as F
import numpy as np
import torch
import math
from functools import lru_cache

__MAX_SIZE = 2048


def _ceil_to_nearest(n, round_to):
    return (n + round_to - 1) // round_to * round_to


@torch.jit.script
def build_delay_pattern_mask(
    input_ids: torch.Tensor,
    bos_token_id: int,
    pad_token_id: int,
):
    """Shift each codebook row k of `input_ids` right by k positions.

    The first k positions of row k are filled with `bos_token_id`, the trailing
    positions with `pad_token_id`. Returns the delayed ids plus a variant where the
    trailing (not-yet-generated) region is marked with -1 instead of `pad_token_id`.
    """
    bsz, num_codebooks, seq_len = input_ids.shape
    new_seq_len = seq_len + num_codebooks - 1
    input_ids_with_gen_mask = torch.ones(
        (bsz, num_codebooks, new_seq_len), dtype=torch.long, device=input_ids.device
    )
    bos_mask = torch.tril(input_ids_with_gen_mask, -1) > 0
    eos_mask = torch.triu(input_ids_with_gen_mask, seq_len) > 0
    input_ids_with_gen_mask[bos_mask] = bos_token_id
    input_ids_with_gen_mask[(~bos_mask) & (~eos_mask)] = input_ids.reshape(-1)
    input_ids = input_ids_with_gen_mask.clone()
    input_ids[eos_mask] = pad_token_id
    input_ids_with_gen_mask[eos_mask] = -1
    return input_ids, input_ids_with_gen_mask
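# Worked example (illustration only, values not from the original source): with
# num_codebooks=3, seq_len=4, bos=B and pad=P, the delayed `input_ids` look like
#
#   codebook 0: c00 c01 c02 c03  P    P
#   codebook 1:  B  c10 c11 c12 c13   P
#   codebook 2:  B   B  c20 c21 c22 c23
#
# while `input_ids_with_gen_mask` carries -1 in place of the trailing P entries, so a
# generation loop can tell "already padded" apart from "still to be generated".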
# A trimmed-down reimplementation of dacite's `from_dict`, keeping only the parts
# needed to parse ChatML samples.


@lru_cache(maxsize=None)
def cache(function):
    # Wrap `function` in an LRU cache; the outer lru_cache ensures each function
    # gets exactly one cached wrapper.
    return lru_cache(maxsize=__MAX_SIZE, typed=True)(function)


@cache
def get_fields(data_class):
    fields = getattr(data_class, _FIELDS)
    return [f for f in fields.values() if f._field_type is _FIELD or f._field_type is _FIELD_INITVAR]


def is_optional(type_) -> bool:
    return get_origin(type_) is Union and type(None) in get_args(type_)


def orig(data_class) -> Any:
    if is_dataclass(data_class):
        return data_class
    return get_origin(data_class)


@cache
def extract_generic(type_, defaults: Tuple = ()) -> tuple:
    try:
        if getattr(type_, "_special", False):
            return defaults
        if type_.__args__ == ():
            return (type_.__args__,)
        return type_.__args__ or defaults  # type: ignore
    except AttributeError:
        return defaults


def _build_value_for_collection(collection, data: Any) -> Any:
    if isinstance(data, Mapping):
        value_type = extract_generic(collection, defaults=(Any, Any))[1]
        return {key: _build_value(type_=value_type, data=value) for key, value in data.items()}
    elif isinstance(data, Collection) and not isinstance(data, (str, bytes, Mapping)):
        item_type = extract_generic(collection, defaults=(Any,))[0]
        return [_build_value(type_=item_type, data=item) for item in data]
    return data


def _build_value(type_, data) -> Any:
    if is_optional(type_) and data is None:
        return data
    if get_origin(type_) is Union:
        data = _build_value_for_union(union=type_, data=data)
    elif hasattr(type_, "__origin__"):
        data = _build_value_for_collection(collection=type_, data=data)
    elif cache(is_dataclass)(orig(type_)) and isinstance(data, Mapping):
        data = from_dict(data_class=type_, data=data)
    return data


def _build_value_for_union(union: type, data: Any) -> Any:
    for inner_type in get_args(union):
        if data is None and inner_type is type(None):
            return None
        try:
            return _build_value(inner_type, data)
        except Exception:
            continue
    raise ValueError(f"Cannot match {data!r} to any type in {union}")


def is_instance(value: Any, type_) -> bool:
    if type_ is Any:
        return True
    origin = get_origin(type_)
    args = get_args(type_)
    if origin is Union:
        return any(is_instance(value, arg) for arg in args)
    if origin in (list, List):
        if not isinstance(value, list):
            return False
        (elem_type,) = args or (Any,)
        return all(is_instance(item, elem_type) for item in value)
    if origin in (dict, Dict, Mapping):
        if not isinstance(value, dict):
            return False
        key_type, val_type = args or (Any, Any)
        return all(is_instance(k, key_type) and is_instance(v, val_type) for k, v in value.items())
    try:
        return isinstance(value, type_)
    except TypeError:
        return False


def from_dict(data_class, data):
    init_values: MutableMapping[str, Any] = {}
    post_init_values: MutableMapping[str, Any] = {}  # kept for parity with dacite; never populated here
    data_class_hints = get_type_hints(data_class)
    data_class_fields = cache(get_fields)(data_class)
    extra_fields = set(data.keys()) - {f.name for f in data_class_fields}
    if extra_fields:
        formatted_keys = ", ".join(f'"{key}"' for key in extra_fields)
        raise ValueError(f"cannot match {formatted_keys} to any data class field")
    for field in data_class_fields:
        field_type = data_class_hints[field.name]
        key = field.name
        if key in data:
            try:
                value = _build_value(type_=field_type, data=data[key])
            except Exception as error:
                raise ValueError(error)
            if not is_instance(value, field_type):
                raise ValueError(
                    f'wrong value type for field "{field.name}" - should be "{field_type}" '
                    f'instead of value "{value}" of type "{type(value)}"'
                )
            init_values[field.name] = value
    instance = data_class(**init_values)
    for key, value in post_init_values.items():
        setattr(instance, key, value)
    return instance


def normalize_chinese_punctuation(text):
    """Convert Chinese (full-width) punctuation marks to English (half-width) equivalents."""
    # Mapping of Chinese punctuation to English punctuation
    chinese_to_english_punct = {
        ",": ", ",  # comma
        "。": ".",  # period
        ":": ":",  # colon
        ";": ";",  # semicolon
        "?": "?",  # question mark
        "!": "!",  # exclamation mark
        "(": "(",  # left parenthesis
        ")": ")",  # right parenthesis
        "【": "[",  # left square bracket
        "】": "]",  # right square bracket
        "《": "<",  # left angle quote
        "》": ">",  # right angle quote
        "“": '"',  # left double quotation
        "”": '"',  # right double quotation
        "‘": "'",  # left single quotation
        "’": "'",  # right single quotation
        "、": ",",  # enumeration comma
        "—": "-",  # em dash
        "…": "...",  # ellipsis
        "·": ".",  # middle dot
        "「": '"',  # left corner bracket
        "」": '"',  # right corner bracket
        "『": '"',  # left double corner bracket
        "』": '"',  # right double corner bracket
    }
    # Replace each Chinese punctuation mark with its English counterpart
    for zh_punct, en_punct in chinese_to_english_punct.items():
        text = text.replace(zh_punct, en_punct)
    return text


def transcript_normalize(text: str) -> str:
    transcript = normalize_chinese_punctuation(text)
    transcript = transcript.replace("(", " ")
    transcript = transcript.replace(")", " ")
    transcript = transcript.replace("°F", " degrees Fahrenheit")
    transcript = transcript.replace("°C", " degrees Celsius")
    for tag, replacement in [
        ("[laugh]", "[Laughter]"),
        ("[humming start]", "[Humming]"),
        ("[humming end]", "[Humming]"),
        ("[music start]", "[Music]"),
        ("[music end]", "[Music]"),
        ("[music]", "[Music]"),
        ("[sing start]", "[Singing]"),
        ("[sing end]", "[Singing]"),
        ("[applause]", "[Applause]"),
        ("[cheering]", "[Cheering]"),
        ("[cough]", "[Cough]"),
    ]:
        transcript = transcript.replace(tag, replacement)
    lines = transcript.split("\n")
    transcript = "\n".join([" ".join(line.split()) for line in lines if line.strip()])
    transcript = transcript.strip()
    # The list must not contain the empty string: `endswith("")` is always True and
    # would suppress the fallback period.
    if not any([transcript.endswith(c) for c in [".", "!", "?", ",", ";", '"', "'"]]):
        transcript += "."
    return transcript
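# A minimal usage sketch for the `from_dict` helper above (`_Point` is a hypothetical
# dataclass, shown for illustration only):
#
#     @dataclass
#     class _Point:
#         x: int
#         tags: List[str]
#
#     from_dict(_Point, {"x": 1, "tags": ["a", "b"]})  # -> _Point(x=1, tags=['a', 'b'])
#     from_dict(_Point, {"x": 1, "y": 2})              # -> ValueError: cannot match "y" ...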
@dataclass
class AudioContent:
    audio_url: str
    raw_audio: Optional[str] = None
    offset: Optional[float] = None
    duration: Optional[float] = None
    row_id: Optional[int] = None
    type: str = "audio"


@dataclass
class TextContent:
    text: str
    type: str = "text"


@dataclass
class Message:
    role: str
    content: Union[str, AudioContent, TextContent, List[Union[str, AudioContent, TextContent]]]
    recipient: Optional[str] = None


@dataclass
class ChatMLSample:
    messages: List[Message]
    start_index: Optional[int] = None
    misc: Optional[Dict] = None
    speaker: Optional[str] = None


def prepare_chatml_sample(sample: Union[ChatMLSample, Dict], tokenizer):
    try:
        if not isinstance(sample, ChatMLSample):
            # Lightweight replacement for pd.isna
            def is_nan(x):
                if isinstance(x, float):
                    return math.isnan(x)
                if isinstance(x, np.generic):
                    return np.isnan(x)
                if isinstance(x, torch.Tensor) and x.numel() == 1:
                    return torch.isnan(x).item()
                return False

            if "speaker" in sample and is_nan(sample["speaker"]):
                sample["speaker"] = None
            if "start_index" in sample and is_nan(sample["start_index"]):
                sample["start_index"] = None
            if "content" in sample and is_nan(sample["content"]):
                sample["content"] = ""

            def convert_nan_to_none(obj):
                if isinstance(obj, np.ndarray):
                    return obj.tolist()
                elif isinstance(obj, float) and math.isnan(obj):
                    return None
                elif isinstance(obj, dict):
                    return {k: convert_nan_to_none(v) for k, v in obj.items()}
                elif isinstance(obj, (list, tuple)):
                    return [convert_nan_to_none(item) for item in obj]
                return obj

            clean_sample = convert_nan_to_none(sample)
            # Drop keys that do not correspond to a ChatMLSample field.
            val_keys = [field.name for field in fields(ChatMLSample) if field.name in clean_sample]
            clean_sample = {k: clean_sample[k] for k in val_keys}
            sample = from_dict(data_class=ChatMLSample, data=clean_sample)

        input_tokens = []
        audio_contents = []
        speaker_id = None
        if sample.speaker is not None:
            speaker_id = sample.speaker
        elif sample.misc is not None and "speaker" in sample.misc:
            speaker_id = sample.misc["speaker"]

        total_m = len(sample.messages)
        for turn_id, message in enumerate(sample.messages):
            role = message.role
            recipient = message.recipient
            content = message.content
            # Normalize the content into a list of TextContent / AudioContent.
            content_l = []
            if isinstance(content, str):
                content_l.append(TextContent(text=content))
            elif isinstance(content, (TextContent, AudioContent)):
                content_l.append(content)
            elif isinstance(content, list):
                for ele in content:
                    if isinstance(ele, str):
                        content_l.append(TextContent(text=ele))
                    else:
                        content_l.append(ele)

            if turn_id == 0:
                prefix = f"<|begin_of_text|><|start_header_id|>{role}<|end_header_id|>\n\n"
            else:
                prefix = f"<|start_header_id|>{role}<|end_header_id|>\n\n"
            eot_postfix = "<|eot_id|>"
            eom_postfix = "<|eom_id|>"
            prefix_tokens = tokenizer.encode(prefix, add_special_tokens=False)
            input_tokens.extend(prefix_tokens)

            if recipient:
                assert role == "assistant", "Recipient is only available for the assistant role."
                recipient_tokens = tokenizer.encode(f"{recipient}<|recipient|>", add_special_tokens=False)
                input_tokens.extend(recipient_tokens)

            for content in content_l:
                if content.type == "text":
                    text_tokens = tokenizer.encode(content.text, add_special_tokens=False)
                    input_tokens.extend(text_tokens)
                elif content.type == "audio":
                    audio_contents.append(content)
                    if role in ("user", "system"):
                        text_tokens = tokenizer.encode(
                            "<|audio_bos|><|AUDIO|><|audio_eos|>",
                            add_special_tokens=False,
                        )
                        input_tokens.extend(text_tokens)
                    elif role == "assistant":
                        text_tokens = tokenizer.encode(
                            "<|audio_out_bos|><|AUDIO_OUT|><|audio_eos|>",
                            add_special_tokens=False,
                        )
                        input_tokens.extend(text_tokens)

            # Use <|eom_id|> between consecutive assistant turns, <|eot_id|> otherwise.
            next_id = turn_id + 1
            if role == "assistant" and next_id != total_m and sample.messages[next_id].role == "assistant":
                postfix_tokens = tokenizer.encode(eom_postfix, add_special_tokens=False)
            else:
                postfix_tokens = tokenizer.encode(eot_postfix, add_special_tokens=False)
            input_tokens.extend(postfix_tokens)
        return input_tokens, audio_contents, speaker_id
    except Exception:
        return None, None, None
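# Illustrative call (any HF-style tokenizer with `encode(..., add_special_tokens=False)`
# works; the dict below is a hypothetical sample):
#
#     tokens, audio_contents, speaker = prepare_chatml_sample(
#         {"messages": [{"role": "user", "content": "Hello"},
#                       {"role": "assistant", "content": "Hi!"}]},
#         tokenizer,
#     )
#
# On any parsing/tokenization error the function deliberately returns (None, None, None).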
@dataclass
class HiggsAudioBatchInput:
    input_ids: torch.LongTensor  # shape (bsz, seq_len)
    attention_mask: torch.Tensor  # shape (bsz, seq_len)
    audio_out_ids: Optional[torch.LongTensor]  # shape (num_codebooks, audio_out_total_length)
    audio_out_ids_start: Optional[torch.LongTensor]  # shape (num_audio_out,)
    audio_out_ids_start_group_loc: Optional[torch.LongTensor]  # shape (num_audio_out,); batch index of the sample each audio-out segment belongs to
    audio_in_ids: Optional[torch.LongTensor]  # shape (num_codebooks, audio_in_total_length)
    audio_in_ids_start: Optional[torch.LongTensor]  # shape (num_audio_in,)
    label_ids: Optional[torch.LongTensor]  # shape (bsz, seq_len)
    label_audio_ids: Optional[torch.LongTensor]  # shape (num_codebooks, audio_out_total_length)
    reward: Optional[torch.Tensor] = None  # shape (num_rewards,); one entry per sample that carries a reward


@dataclass
class ChatMLDatasetSample:
    input_ids: torch.LongTensor  # (seq_len,) input text tokens
    label_ids: torch.LongTensor  # (seq_len,) label ids
    audio_ids_concat: torch.LongTensor  # (num_codebooks, audio_seq_len) concatenated audio tokens
    audio_ids_start: torch.LongTensor  # (num_audios,) start index of each audio segment in `audio_ids_concat`
    audio_waveforms_concat: torch.Tensor  # (total_wv_length,) concatenated audio waveforms
    audio_waveforms_start: torch.LongTensor  # (num_audios,) start index of each waveform in `audio_waveforms_concat`
    audio_sample_rate: torch.Tensor  # (num_audios,) sampling rate per audio waveform
    audio_speaker_indices: torch.LongTensor  # (num_audios,) speaker indices per audio; -1 = unknown
    audio_label_ids_concat: Optional[torch.LongTensor] = None  # (num_codebooks, audio_seq_len) optional audio token labels
    reward: Optional[float] = None  # optional scalar reward

    def num_audios(self):
        return max(len(self.audio_waveforms_start), len(self.audio_ids_start))

    def get_audio_codes(self, idx):
        code_start = self.audio_ids_start[idx]
        if idx < len(self.audio_ids_start) - 1:
            code_end = self.audio_ids_start[idx + 1]
        else:
            code_end = self.audio_ids_concat.shape[-1]
        return self.audio_ids_concat[:, code_start:code_end]

    def get_audio_codes_labels(self, idx):
        if self.audio_label_ids_concat is None:
            return None
        code_start = self.audio_ids_start[idx]
        if idx < len(self.audio_ids_start) - 1:
            code_end = self.audio_ids_start[idx + 1]
        else:
            code_end = self.audio_ids_concat.shape[-1]
        return self.audio_label_ids_concat[:, code_start:code_end]

    def get_wv(self, idx):
        wv_start = self.audio_waveforms_start[idx]
        sr = self.audio_sample_rate[idx]
        if idx < len(self.audio_waveforms_start) - 1:
            wv_end = self.audio_waveforms_start[idx + 1]
        else:
            wv_end = self.audio_waveforms_concat.shape[-1]
        return self.audio_waveforms_concat[wv_start:wv_end], sr
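# The audio fields use a "ragged concatenation" layout. For example (hypothetical
# values), two audio segments of 3 and 2 frames over 2 codebooks are stored as
#
#     audio_ids_concat -> shape (2, 5)
#     audio_ids_start  -> tensor([0, 3])
#
# and `get_audio_codes(1)` returns `audio_ids_concat[:, 3:5]`.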
class HiggsAudioSampleCollator:
    def __init__(
        self,
        audio_in_token_id,
        audio_out_token_id,
        pad_token_id,
        audio_stream_bos_id,
        audio_stream_eos_id,
        round_to=8,
        pad_left=False,
        return_audio_in_tokens=True,
        audio_num_codebooks=None,
        use_delay_pattern=False,
        disable_audio_codes_transform=False,
        add_new_bos_eos_for_long_chunk=True,
        mask_audio_out_token_label=True,
        chunk_size_samples=None,
    ):
        self.round_to = round_to
        self.pad_left = pad_left
        self.audio_in_token_id = audio_in_token_id
        self.audio_out_token_id = audio_out_token_id
        self.audio_stream_bos_id = audio_stream_bos_id
        self.audio_stream_eos_id = audio_stream_eos_id
        self.pad_token_id = pad_token_id
        self.return_audio_in_tokens = return_audio_in_tokens
        self.audio_num_codebooks = audio_num_codebooks
        self.use_delay_pattern = use_delay_pattern
        self.disable_audio_codes_transform = disable_audio_codes_transform
        self.add_new_bos_eos_for_long_chunk = add_new_bos_eos_for_long_chunk
        self.mask_audio_out_token_label = mask_audio_out_token_label
        # Number of waveform samples per chunk, required by
        # `_process_and_duplicate_audio_tokens` when splitting long audio.
        self.chunk_size_samples = chunk_size_samples

    def _process_and_duplicate_audio_tokens(
        self,
        input_ids: torch.Tensor,
        audio_idx: int,
        wv: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], int]:
        # Duplicate the 3-token window `<bos> <AUDIO> <eos>` around `audio_idx` once per
        # waveform chunk, so a long audio maps to several placeholder groups. Assumes
        # `self.chunk_size_samples` is set.
        total_samples = len(wv)
        num_chunks = math.ceil(total_samples / self.chunk_size_samples)
        if num_chunks <= 1:
            return input_ids, labels, 1
        audio_token_seq = input_ids[audio_idx - 1 : audio_idx + 2]
        duplicated_sequence = audio_token_seq.repeat(num_chunks)
        new_input_ids = torch.cat([input_ids[: audio_idx - 1], duplicated_sequence, input_ids[audio_idx + 2 :]])
        new_labels = None
        if labels is not None:
            label_seq = labels[audio_idx - 1 : audio_idx + 2]
            duplicated_labels = label_seq.repeat(num_chunks)
            new_labels = torch.cat([labels[: audio_idx - 1], duplicated_labels, labels[audio_idx + 2 :]])
        return new_input_ids, new_labels, num_chunks

    def __call__(self, batch: List[ChatMLDatasetSample]):
        label_ids = None
        label_audio_ids = None
        return_labels = not all(ele.label_ids is None for ele in batch)

        processed_batch = batch
        # Max sequence length of the batch, rounded up to a multiple of `round_to`.
        max_seq_length = _ceil_to_nearest(max([len(sample.input_ids) for sample in processed_batch]), self.round_to)

        # Gather the audio-in and audio-out code segments of every sample.
        audio_in_ids_l = []
        audio_out_ids_l = []
        audio_out_ids_group_loc_l = []
        audio_in_label_ids_l = None
        audio_out_label_ids_l = None
        reward_l = []
        if return_labels:
            audio_out_no_train_flag = []  # Whether each audio-out segment should be excluded from training.
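        # How the placeholder enumeration below works (illustrative example): for
        # input_ids [t, AUDIO_IN, t, AUDIO_OUT, AUDIO_IN], the cumulative count assigns
        # audio_in_ids = [0, 2] and audio_out_ids = [1]; these indices select segments
        # from the sample's concatenated codes via get_audio_codes().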
        # Process the audio inputs and outputs
        for i in range(len(processed_batch)):
            audio_in_mask = processed_batch[i].input_ids == self.audio_in_token_id
            audio_out_mask = processed_batch[i].input_ids == self.audio_out_token_id
            # Enumerate all audio placeholder tokens (in and out) in order of appearance.
            audio_ids = torch.ones_like(processed_batch[i].input_ids)
            audio_ids[audio_in_mask ^ audio_out_mask] = torch.cumsum(audio_ids[audio_in_mask ^ audio_out_mask], 0) - 1
            audio_in_ids = audio_ids[audio_in_mask]
            audio_out_ids = audio_ids[audio_out_mask]

            if return_labels:
                audio_out_no_train_flag.append(processed_batch[i].label_ids[audio_out_mask] < 0)
                if self.mask_audio_out_token_label:
                    processed_batch[i].label_ids[audio_out_mask] = -100

            if self.return_audio_in_tokens:
                audio_in_ids_l.extend(
                    [processed_batch[i].get_audio_codes(idx)[: self.audio_num_codebooks, :] for idx in audio_in_ids]
                )
                if processed_batch[i].audio_label_ids_concat is not None:
                    if audio_in_label_ids_l is None:
                        audio_in_label_ids_l = []
                    audio_in_label_ids_l.extend(
                        [
                            processed_batch[i].get_audio_codes_labels(idx)[: self.audio_num_codebooks, :]
                            for idx in audio_in_ids
                        ]
                    )
            audio_out_ids_l.extend(
                [processed_batch[i].get_audio_codes(idx)[: self.audio_num_codebooks, :] for idx in audio_out_ids]
            )
            # One group-location entry per audio-out segment, matching the
            # (num_audio_out,) shape documented on HiggsAudioBatchInput.
            audio_out_ids_group_loc_l.extend([i] * len(audio_out_ids))
            if processed_batch[i].reward is not None:
                reward_l.append(processed_batch[i].reward)
            if processed_batch[i].audio_label_ids_concat is not None:
                if audio_out_label_ids_l is None:
                    audio_out_label_ids_l = []
                audio_out_label_ids_l.extend(
                    [
                        processed_batch[i].get_audio_codes_labels(idx)[: self.audio_num_codebooks, :]
                        for idx in audio_out_ids
                    ]
                )

        if return_labels:
            audio_out_no_train_flag = torch.cat(audio_out_no_train_flag, dim=0)
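        # Assemble the concatenated audio-in stream. Each segment is optionally wrapped
        # with stream BOS/EOS codes and run through the delay pattern; the start offsets
        # must therefore be computed from the segments' final (transformed) lengths.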
        if len(audio_in_ids_l) > 0:
            # Keep the per-segment list comprehensions: a fully batched, padded
            # implementation caused problems in the original version.
            if self.disable_audio_codes_transform:
                segments = audio_in_ids_l
            else:
                segments = [
                    torch.cat(
                        [
                            torch.full((seg.shape[0], 1), self.audio_stream_bos_id, dtype=torch.long),
                            seg,
                            torch.full((seg.shape[0], 1), self.audio_stream_eos_id, dtype=torch.long),
                        ],
                        dim=1,
                    )
                    for seg in audio_in_ids_l
                ]
                if self.use_delay_pattern:
                    segments = [
                        build_delay_pattern_mask(
                            seg.unsqueeze(0),
                            bos_token_id=self.audio_stream_bos_id,
                            pad_token_id=self.audio_stream_eos_id,
                        )[0].squeeze(0)
                        for seg in segments
                    ]
            lengths = [seg.shape[1] for seg in segments]
            audio_in_ids_start = torch.cumsum(torch.tensor([0] + lengths[:-1], dtype=torch.long), dim=0)
            audio_in_ids = torch.cat(segments, dim=1).long()
        else:
            audio_in_ids = torch.zeros((0, 0), dtype=torch.long)
            audio_in_ids_start = torch.zeros(0, dtype=torch.long)

        audio_out_ids_start_group_loc = None
        if len(audio_out_ids_l) > 0:
            new_audio_out_ids_l = []
            label_audio_ids_l = []
            for idx, ele in enumerate(audio_out_ids_l):
                if self.disable_audio_codes_transform:
                    audio_codes = ele
                    if return_labels:
                        label_audio_ids = audio_out_label_ids_l[idx]
                else:
                    audio_codes = torch.cat(
                        [
                            torch.full((ele.shape[0], 1), self.audio_stream_bos_id, dtype=torch.long),
                            ele,
                            torch.full((ele.shape[0], 1), self.audio_stream_eos_id, dtype=torch.long),
                        ],
                        dim=1,
                    )
                    if return_labels:
                        # Never train on the stream BOS position; keep the EOS as a target.
                        label_audio_ids = torch.cat(
                            [
                                torch.full((ele.shape[0], 1), -100, dtype=torch.long),
                                ele,
                                torch.full((ele.shape[0], 1), self.audio_stream_eos_id, dtype=torch.long),
                            ],
                            dim=1,
                        )
                    if self.use_delay_pattern:
                        audio_codes = build_delay_pattern_mask(
                            audio_codes.unsqueeze(0),
                            bos_token_id=self.audio_stream_bos_id,
                            pad_token_id=self.audio_stream_eos_id,
                        )[0].squeeze(0)
                        if return_labels:
                            label_audio_ids = build_delay_pattern_mask(
                                label_audio_ids.unsqueeze(0),
                                bos_token_id=-100,
                                pad_token_id=-100,
                            )[0].squeeze(0)
                new_audio_out_ids_l.append(audio_codes)
                if return_labels:
                    if audio_out_no_train_flag[idx]:
                        label_audio_ids[:] = -100
                    label_audio_ids_l.append(label_audio_ids)
            audio_out_ids = torch.cat(new_audio_out_ids_l, dim=1).long()
            if return_labels:
                label_audio_ids = torch.cat(label_audio_ids_l, dim=1).long()
            audio_out_ids_start = torch.cumsum(
                torch.tensor([0] + [audio_codes.shape[1] for audio_codes in new_audio_out_ids_l[:-1]], dtype=torch.long),
                dim=0,
            )
            audio_out_ids_start_group_loc = torch.tensor(audio_out_ids_group_loc_l, dtype=torch.long)
        else:
            audio_out_ids = torch.zeros((0, 0), dtype=torch.long)
            audio_out_ids_start = torch.zeros(0, dtype=torch.long)
            if return_labels:
                label_audio_ids = torch.zeros((0, 0), dtype=torch.long)

        reward = torch.tensor(reward_l, dtype=torch.float32)

        def pad_sequence(seq, length, pad_value):
            # Pad (or left-pad) a 1-D sequence to `length` with `pad_value`.
            if self.pad_left:
                padding = (length - len(seq), 0)
            else:
                padding = (0, length - len(seq))
            return F.pad(seq, padding, value=pad_value)

        input_ids = torch.stack(
            [pad_sequence(ele.input_ids, max_seq_length, self.pad_token_id) for ele in processed_batch]
        )
        if return_labels:
            label_ids = torch.stack([pad_sequence(ele.label_ids, max_seq_length, -100) for ele in processed_batch])
        attention_mask = torch.stack(
            [pad_sequence(torch.ones_like(ele.input_ids), max_seq_length, 0) for ele in processed_batch]
        )

        if not self.return_audio_in_tokens:
            audio_in_ids = None
            audio_in_ids_start = None

        if self.audio_num_codebooks is not None:
            # Keep only the first `audio_num_codebooks` codebook rows.
            if audio_in_ids is not None:
                audio_in_ids = audio_in_ids[: self.audio_num_codebooks]
            if audio_out_ids is not None:
                audio_out_ids = audio_out_ids[: self.audio_num_codebooks]
            if label_audio_ids is not None:
                label_audio_ids = label_audio_ids[: self.audio_num_codebooks]
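        # At this point input_ids / attention_mask / label_ids are dense
        # (bsz, max_seq_length) tensors, while the audio streams stay ragged:
        # concatenated along time, with the *_start tensors holding per-segment offsets.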
        return HiggsAudioBatchInput(
            input_ids=input_ids,
            attention_mask=attention_mask,
            audio_out_ids=audio_out_ids,
            audio_out_ids_start=audio_out_ids_start,
            audio_out_ids_start_group_loc=audio_out_ids_start_group_loc,
            audio_in_ids=audio_in_ids,
            audio_in_ids_start=audio_in_ids_start,
            label_ids=label_ids,
            label_audio_ids=label_audio_ids,
            reward=reward,
        )
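

if __name__ == "__main__":
    # Minimal smoke test for the collator. The token ids below (100/101/1024/1025) are
    # hypothetical placeholders, not values from any real tokenizer or codec config.
    AUDIO_IN, AUDIO_OUT, PAD = 100, 101, 0
    sample = ChatMLDatasetSample(
        input_ids=torch.tensor([1, 2, AUDIO_OUT, 3], dtype=torch.long),
        label_ids=torch.tensor([-100, -100, AUDIO_OUT, 3], dtype=torch.long),
        audio_ids_concat=torch.arange(8, dtype=torch.long).reshape(2, 4),  # 2 codebooks, 4 frames
        audio_ids_start=torch.tensor([0], dtype=torch.long),
        audio_waveforms_concat=torch.zeros(0),
        audio_waveforms_start=torch.zeros(0, dtype=torch.long),
        audio_sample_rate=torch.zeros(0),
        audio_speaker_indices=torch.zeros(0, dtype=torch.long),
    )
    collator = HiggsAudioSampleCollator(
        audio_in_token_id=AUDIO_IN,
        audio_out_token_id=AUDIO_OUT,
        pad_token_id=PAD,
        audio_stream_bos_id=1024,
        audio_stream_eos_id=1025,
    )
    batch = collator([sample])
    # The single audio-out segment of 4 frames becomes 6 after BOS/EOS wrapping.
    print(batch.input_ids.shape)      # torch.Size([1, 8])
    print(batch.audio_out_ids.shape)  # torch.Size([2, 6])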