styling fixes

Yousef Rafat 2025-09-06 01:17:04 +03:00
parent 6e9335d638
commit 57c15f970c
12 changed files with 111 additions and 113 deletions

View File

@@ -19,27 +19,27 @@ def estimate_autoregressive_vram(
    max_seq_len: int,
    batch_size: int = 1,
    dtype = torch.float16,
    intermediate_factor: float = 4.0,
    device = torch.device('cuda')
) -> bool:
    dtype_size = torch.finfo(dtype).bits // 8
    kv_cache_bytes = num_layers * max_seq_len * hidden_dim * 2 * batch_size * dtype_size
    # we only calculate hidden states in cuda graphs, so we don't care about the output logits
    input_bytes = output_bytes = batch_size * max_seq_len * hidden_dim * dtype_size
    # rough calculation for activation sizes
    intermediate_bytes = intermediate_factor * output_bytes
    total_estimated = kv_cache_bytes + input_bytes + output_bytes + intermediate_bytes
    # get vram info
    free_vram = get_free_memory(device)
    minimum_vram = minimum_inference_memory()
    enough_vram = free_vram - minimum_vram >= total_estimated
    return enough_vram
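To put the KV-cache term above in perspective, here is a quick back-of-the-envelope check of the same formula. The model dimensions are hypothetical (Llama-7B-like), chosen only for illustration:

```python
import torch

# hypothetical dimensions, for illustration only
num_layers, hidden_dim, max_seq_len, batch_size = 32, 4096, 8192, 1
dtype_size = torch.finfo(torch.float16).bits // 8  # 2 bytes per element

# same formula as in estimate_autoregressive_vram: keys + values for every layer
kv_cache_bytes = num_layers * max_seq_len * hidden_dim * 2 * batch_size * dtype_size
print(kv_cache_bytes / 2**30)  # 4.0 -> a static fp16 KV cache for 8k tokens needs ~4 GiB
```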
class TopKLogits:

@@ -64,7 +64,7 @@ class TemperatureLogitsWarper:
    def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor:
        scores_processed = scores / self.temperature
        return scores_processed

class TopPLogitsWarper:
    def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        top_p = float(top_p)
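The hunk above only shows the warper's constructor. For context, this is a minimal sketch of the nucleus (top-p) filtering such a warper typically applies in its `__call__`, written in the common Hugging Face style; it is illustrative, not code from this repository:

```python
import torch

def top_p_filter(scores: torch.Tensor, top_p: float,
                 filter_value: float = -float("inf"), min_tokens_to_keep: int = 1) -> torch.Tensor:
    # sort ascending so the cumulative probability of the tokens to drop stays below (1 - top_p)
    sorted_logits, sorted_indices = torch.sort(scores, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    # always keep at least min_tokens_to_keep of the highest-probability tokens
    sorted_indices_to_remove[..., -min_tokens_to_keep:] = False
    # map the removal mask back to the original vocabulary order
    indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
    return scores.masked_fill(indices_to_remove, filter_value)
```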
@@ -175,7 +175,7 @@ class GenerationConfig:
        config_dict = {key: value for key, value in config_dict.items() if value is not None}
        valid_fields = {f.name for f in fields(cls)}
        filtered_args = {k: v for k, v in {**config_dict, **kwargs}.items() if k in valid_fields}
        generation_config = cls(**filtered_args)

@@ -216,7 +216,7 @@ class AutoRegressiveGeneration:
        self.model.cache_config = self.cache_config
        self.kv_caches = {
            length: StaticCache(
                config=self.cache_config,
                max_batch_size = self.cache_config.max_batch,

@@ -234,8 +234,8 @@ class AutoRegressiveGeneration:
        # cuda graphs only help if input shapes are constant
        if (
            device == "cuda"
            and hasattr(model, "capture_model")
            and self.model.cache_implementation == "static"
            and self.model.use_kv_buckets
            and enough_vram
@@ -247,7 +247,7 @@ class AutoRegressiveGeneration:
    @torch.inference_mode()
    def generate(self, input_ids: Optional[torch.LongTensor] = None, max_new_length: int = 1024, min_new_length = 0,
                 top_k: int = 50, top_p: float = 1.0, temperature: float = 1.0, do_sample: bool = False, seed = None, **kwargs):
        if seed is not None:
            torch_generator = torch.Generator(device = input_ids.device).manual_seed(seed)
        else:

@@ -335,7 +335,7 @@ class AutoRegressiveGeneration:
        # TODO: have a default self._sample fn and a default check if the model supports autoregGen or not
        if not hasattr(self.model, "_sample"):
            raise ValueError("Model doesn't support AutoRegressive Generation!")

        self._prepare_kv_caches()
        result = self.model._sample(

@@ -347,7 +347,7 @@ class AutoRegressiveGeneration:
        )
        return result

    def _prepare_kv_caches(self):
        for kv_cache in self.kv_caches.values():
            kv_cache.reset()
@@ -357,13 +357,13 @@ class AutoRegressiveGeneration:
            return GenerationSampling.BEAM_SAMPLING
        else:
            return GenerationSampling.GREEDY_SEARCH

    def _prepare_generated_length(
        self,
        generation_config: GenerationConfig,
        input_ids_length,
    ):
        """ max_length = user_input_id_tokens + generation_max_length """
        if generation_config.max_new_length is not None:

@@ -374,11 +374,11 @@ class AutoRegressiveGeneration:
            generation_config.min_length = generation_config.min_new_length + input_ids_length
        return generation_config

    def _get_cache(
        self, cache_implementation: str, batch_size: int, max_cache_len: int, device: torch.device, model_kwargs
    ) -> Cache:
        assert cache_implementation == "static", f"Only 'static' cache is supported, got {cache_implementation}"
        cache_cls: Cache = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation]

@@ -412,7 +412,7 @@ class AutoRegressiveGeneration:
        return self.model._cache

    def _prepare_cache_for_generation(
        self,
        generation_config: GenerationConfig,
@@ -466,7 +466,7 @@ class AutoRegressiveGeneration:
        model_kwargs = generation_config.update(**kwargs)
        return generation_config, model_kwargs

    def _validate_generated_length(self, generation_config: GenerationConfig, input_ids_length):
        """Performs validation related to the resulting generated length"""

@@ -498,7 +498,7 @@ class AutoRegressiveGeneration:
                f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix,
                UserWarning,
            )

    def _expand_inputs_for_generation(
        self,
        expand_size: int = 1,

@@ -526,13 +526,13 @@ class AutoRegressiveGeneration:
        model_kwargs = _expand_dict_for_generation(model_kwargs)
        return input_ids, model_kwargs

    def _prepare_special_tokens(
        self,
        generation_config: GenerationConfig,
        device: Optional[Union[torch.device, str]] = None,
    ):
        def _tensor_or_none(token, device=None):
            if token is None:
                return token
@@ -564,7 +564,7 @@ class AutoRegressiveGeneration:
        generation_config: GenerationConfig,
        model_kwargs: dict[str, Any],
    ) -> torch.LongTensor:
        pad_token_id = generation_config._pad_token_tensor
        eos_token_id = generation_config._eos_token_tensor

@@ -593,12 +593,12 @@ class AutoRegressiveGeneration:
            attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask
        )
        return attention_mask

def auto_sample(node, patcher, input_ids, max_new_length=1024, min_new_length=0, top_k=50, top_p=1.0, temperature=1.0, do_sample = False, seed=None, **kwargs):
    # to work with BaseModel
    if hasattr(patcher, "model") and hasattr(patcher.model, "diffusion_model"):
        model = patcher.model.diffusion_model

    if node._cached_autoregressive_sampler is None or node._cached_autoregressive_sampler.model is not model:
        if model.device != patcher.load_device:
            model = model.to(patcher.load_device, dtype=model.dtype)

@@ -610,7 +610,7 @@ def auto_sample(node, patcher, input_ids, max_new_length=1024, min_new_length=0,
        kwargs.update({k: v for k, v in input_ids.items() if k != "input_ids"})
    else:
        main_input_ids = input_ids

    device = node._cached_autoregressive_sampler.device
    main_input_ids = main_input_ids.to(device)

View File

@@ -22,7 +22,7 @@ class CUDAGraphRunner(nn.Module):
    def capture(self, *args, **kwargs):
        assert self._graph is None
        for _ in range(_NUM_WARMUP_ITERS):
            self.model(*args, **kwargs)
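The warm-up loop above is the standard prelude to capturing a CUDA graph. As a point of reference, this is roughly what the capture/replay cycle looks like with PyTorch's public API; it is a self-contained sketch with a toy model, not the repository's CUDAGraphRunner:

```python
import torch

# toy module and static buffers; CUDA graphs require fixed shapes and fixed tensor addresses
model = torch.nn.Linear(16, 16).cuda()
static_input = torch.randn(8, 16, device="cuda")

# warm up on a side stream, as in capture() above
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        model(static_input)
torch.cuda.current_stream().wait_stream(s)

# capture a single forward pass into the graph
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = model(static_input)

# replay: copy new data into the captured input buffer and rerun the recorded kernels
static_input.copy_(torch.randn(8, 16, device="cuda"))
graph.replay()
print(static_output.shape)  # torch.Size([8, 16])
```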

View File

@@ -125,7 +125,7 @@ class IIRfilter(object):
    @property
    def b_and_a(self):
        return self.generate_coefficients()

class Meter(torch.nn.Module):
    def __init__(

@@ -227,7 +227,7 @@ class Meter(torch.nn.Module):
        return unfolded

    def integrated_loudness(self, data: torch.Tensor):
        if not torch.is_tensor(data):
            data = torch.from_numpy(data).float()
        else:

@@ -291,10 +291,10 @@ class Meter(torch.nn.Module):
def loudness(
    audio_data, sample_rate: int, target_loudness: int, filter_class: str = "K-weighting", block_size: float = 0.400, **kwargs
):
    MIN_LOUDNESS = -70
    device = audio_data.device
    original_length = audio_data.shape[-1]
    signal_duration = original_length / sample_rate

View File

@@ -19,8 +19,8 @@ from collections import defaultdict, OrderedDict
from typing import Optional, Tuple, Union, List

class GenerationMode(Enum):
    TEXT = 0
    AUDIO_INIT = 1
    AUDIO_IN_PROGRESS = 2

def _ignore_causal_mask_sdpa(

@@ -413,7 +413,7 @@ class HiggsAudioModel(nn.Module):
    def __init__(self, device = None, dtype = None, operations = None, **kwargs):
        super().__init__()
        self.padding_idx = kwargs["pad_token_id"]
        self.audio_in_token_idx = kwargs["audio_in_token_idx"]
        self.audio_out_token_idx = kwargs["audio_out_token_idx"]

@@ -439,7 +439,7 @@ class HiggsAudioModel(nn.Module):
        self.audio_out_bos_token_id = 128013
        self.audio_eos_token_id = 128012

        text_config = kwargs["text_config"]
        llama_config = Llama2Config(num_attention_heads = text_config["num_attention_heads"],
                                    num_key_value_heads = text_config["num_key_value_heads"],
@@ -616,7 +616,7 @@ class HiggsAudioModel(nn.Module):
            next_audio_tokens = None
        return next_tokens, next_audio_tokens

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,

@@ -677,7 +677,7 @@ class HiggsAudioModel(nn.Module):
            causal_mask = causal_mask.mul(~torch.all(causal_mask == min_dtype, dim=-1, keepdim=True))
        return causal_mask

    def _embed_audio_ids(self, audio_ids):
        codebook_shift = (
            torch.arange(self.config["audio_num_codebooks"], device=audio_ids.device) * self.audio_codebook_size

@@ -712,7 +712,7 @@ class HiggsAudioModel(nn.Module):
        )
        audio_attention_mask = attention_mask.masked_fill(no_audio_out_mask, min_dtype)
        return fast_forward_attention_mask, audio_attention_mask

    def _forward_core(
        self,
        hidden_states: torch.Tensor,

@@ -728,7 +728,7 @@ class HiggsAudioModel(nn.Module):
        is_using_cuda_graph: Optional[bool] = False,
    ):
        position_id_offset = cache_position[0] if use_cache else 0
        position_embeddings = self.rotary_emb(hidden_states, position_ids + position_id_offset)

        for decoder_layer in self.layers:

@@ -927,7 +927,7 @@ class HiggsAudioModel(nn.Module):
        )
        return ret

    def _update_model_kwargs_for_generation(
        self,
        outputs,
@@ -956,13 +956,13 @@ class HiggsAudioModel(nn.Module):
        )
        return model_kwargs

    def _copy_kv_cache(self, from_cache: Cache, to_cache: Cache):
        from_cache_size = from_cache.get_max_cache_shape()
        assert to_cache.get_max_cache_shape() >= from_cache_size, (
            f"The target cache size {to_cache.get_max_cache_shape()} is smaller than the source cache size {from_cache_size}."
        )
        n_layers = self.num_hidden_layers
        for i in range(n_layers):

@@ -977,7 +977,7 @@ class HiggsAudioModel(nn.Module):
                        self.cache_config.head_dim),
                    device=self.device, dtype=self.dtype
                )
            if getattr(to_layer, "values", None) is None:
                to_layer.values = torch.zeros(
                    (self.cache_config.max_batch, self.cache_config.num_key_value_heads,

@@ -1011,7 +1011,7 @@ class HiggsAudioModel(nn.Module):
            f"The current sequence length {current_sequence_length} is larger than "
            f"all past key values buckets {past_key_values_buckets.keys()}."
        )

    def _sample(
        self,
        input_ids: torch.LongTensor,
@@ -1020,7 +1020,7 @@ class HiggsAudioModel(nn.Module):
        past_key_values_buckets: Optional[OrderedDict[int, Cache]],
        **model_kwargs,
    ):
        # code supports only non-mixed batches
        audio_out_bos_token_id = generation_config.generation_kwargs.get("audio_out_bos_token_id", None)

@@ -1069,7 +1069,7 @@ class HiggsAudioModel(nn.Module):
        while not this_peer_finished:
            eos_token_tensor = torch.tensor([self.config["text_config"]["eos_token_id"]], device=input_ids.device)

            if input_ids[0][-1] == audio_out_bos_token_id:
                generation_mode = GenerationMode.AUDIO_INIT
            elif input_ids[0][-1] == self.audio_out_token_idx:

@@ -1211,7 +1211,7 @@ class HiggsAudioModel(nn.Module):
        pbar.update(pbar.total - pbar.current)
        return audio_sequences

    @torch.inference_mode()
    def generate(
        self,

@@ -1222,7 +1222,7 @@ class HiggsAudioModel(nn.Module):
        generation_functions = None,
        **kwargs,
    ):
        if generation_config is None:
            generation_config = GenerationConfig()

View File

@@ -1,4 +1,4 @@
from typing import (
    Dict, List, Optional, Union, Tuple, MutableMapping, Any, Mapping, Collection, get_type_hints, get_args, get_origin
)

@@ -153,13 +153,13 @@ def from_dict(data_class, data):
            value = _build_value(type_=field_type, data=data[key])
        except Exception as error:
            raise ValueError(error)

        if not is_instance(value, field_type):
            raise ValueError((
                f'wrong value type for field "{field.name}" - should be "{field_type}" '
                f'instead of value "{value}" of type "{type(value)}"'
            ))

        init_values[field.name] = value
    instance = data_class(**init_values)

@@ -273,7 +273,7 @@ def prepare_chatml_sample(sample: Union[ChatMLSample, Dict], tokenizer):
    try:
        if not isinstance(sample, ChatMLSample):
            # replacing pd.isna
            def is_nan(x):
                if isinstance(x, float):
                    return math.isnan(x)

@@ -282,7 +282,7 @@ def prepare_chatml_sample(sample: Union[ChatMLSample, Dict], tokenizer):
                if isinstance(x, torch.Tensor) and x.numel() == 1:
                    return torch.isnan(x).item()
                return False

            if "speaker" in sample and is_nan(sample["speaker"]):
                sample["speaker"] = None
            if "start_index" in sample and is_nan(sample["start_index"]):

@@ -489,7 +489,7 @@ class HiggsAudioSampleCollator:
    def _process_and_duplicate_audio_tokens(
        self, input_ids: torch.Tensor, audio_idx: int, wv: torch.Tensor, labels: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, int]:
        total_samples = len(wv)
        num_chunks = math.ceil(total_samples / self.chunk_size_samples)
@@ -583,15 +583,15 @@ class HiggsAudioSampleCollator:
        audio_out_no_train_flag = torch.cat(audio_out_no_train_flag, dim=0)

        if len(audio_in_ids_l) > 0:
            # I tried to remove the for-loop from the original implementation,
            # but batching with padding caused problems, so I turned it into a list comprehension.
            lengths = [seg.shape[1] for seg in audio_in_ids_l]
            aug_lengths = [l + 2 for l in lengths]
            audio_in_ids_start = torch.cumsum(
                torch.tensor([0] + aug_lengths[:-1], dtype=torch.long), dim=0
            )

            if self.disable_audio_codes_transform:
                audio_in_ids = torch.cat(audio_in_ids_l, dim=1).long()
            else:

@@ -607,7 +607,7 @@ class HiggsAudioSampleCollator:
            if self.use_delay_pattern:
                with_tokens = [
                    build_delay_pattern_mask(
                        tok.unsqueeze(0),
                        bos_token_id=self.audio_stream_bos_id,
                        pad_token_id=self.audio_stream_eos_id
                    )[0]

View File

@@ -160,7 +160,7 @@ class DACDecoder(nn.Module):
    def forward(self, x):
        return self.model(x)

class Conv1d1x1:
    def __new__(cls, in_channels, out_channels, bias=True, device=None, dtype=None, operations=None):
        operations = operations or nn

@@ -168,7 +168,7 @@ class Conv1d1x1:
            in_channels, out_channels, kernel_size=1,
            bias=bias, device=device, dtype=dtype
        )

class Conv1d(nn.Module):
    def __init__(
        self,

@@ -203,7 +203,7 @@ class Conv1d(nn.Module):
    def forward(self, x):
        x = self.conv(x)
        return x

class ConvTranspose1d(nn.Module):
    def __init__(
        self,

@@ -237,7 +237,7 @@ class ConvTranspose1d(nn.Module):
    def forward(self, x):
        x = self.deconv(x)
        return x

class ResidualUnit(nn.Module):
    def __init__(
        self,
@@ -283,7 +283,7 @@ class EncoderBlock(nn.Module):
        self.conv = Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size = kernel_size,
            stride=stride,
            bias=bias,
            device = device, dtype = dtype, operations = operations

@@ -435,7 +435,7 @@ class Decoder(nn.Module):
            x = self.conv_blocks[i](x)
        x = self.conv2(x)
        return x

class HiggsAudioFeatureExtractor(nn.Module):
    def __init__(self, sampling_rate=16000):
        super().__init__()

@@ -493,7 +493,7 @@ class EuclideanCodebook(nn.Module):
        embed = self.embed.t()
        if x.dtype != embed.dtype:
            x = x.to(embed.dtype)
        dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + embed.pow(2).sum(0, keepdim=True))
        embed_ind = dist.max(dim=-1).indices
        return embed_ind
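The dist expression above is the usual expansion of the squared Euclidean distance, -(||x||^2 - 2x.e + ||e||^2) = -||x - e||^2, so taking the argmax picks the nearest codebook entry. A tiny self-contained illustration with made-up sizes (not the repository's class, just the same lookup):

```python
import torch
import torch.nn.functional as F

# hypothetical codebook: 8 codes of dimension 4; 10 input vectors to quantize
embed = torch.randn(8, 4)                  # analogous to self.embed
x = torch.randn(10, 4)

embed_t = embed.t()                        # (4, 8), like self.embed.t() above
dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed_t + embed_t.pow(2).sum(0, keepdim=True))
embed_ind = dist.max(dim=-1).indices       # (10,) index of the closest code per vector
quantized = F.embedding(embed_ind, embed)  # (10, 4) dequantized vectors
```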
@@ -520,14 +520,14 @@ class EuclideanCodebook(nn.Module):
        return quantize

    def forward(self, x):
        orig_shape = x.shape            # [B, T, D]
        flat = x.view(-1, x.shape[-1])  # [B*T, D]

        embed_ind = self.quantize(flat)
        embed_ind = self.postprocess_emb(embed_ind, orig_shape)
        # now embed_ind has shape [B, T]
        quantize = self.dequantize(embed_ind)
        # quantize: [B, T, D]
        return quantize, embed_ind

@@ -636,9 +636,9 @@ class ResidualVectorQuantization(nn.Module):
            quantized = F.embedding(embed_id, codebook_weight).transpose(1, 2)  # (B, D, T)
            quantized = F.linear(quantized.transpose(1, 2), proj_weight, proj_biases).transpose(1, 2)
            return quantized

        codebook_weights = torch.stack([q._codebook.embed for q in self.layers])  # (n_codebooks, vocab_size, D)
        proj_weights = torch.stack([q.project_out.weight for q in self.layers])
        quantized = vmap(decode_one)(codebook_weights, proj_weights, q_indices, biases)
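For readers unfamiliar with vmap: it maps a per-codebook function over the stacked codebook tensors, which is what lets the decode above run without an explicit Python loop. A minimal sketch of the same pattern with made-up shapes (assuming torch.func.vmap; this is not the repository's decode_one):

```python
import torch
import torch.nn.functional as F
from torch.func import vmap

# hypothetical sizes: 4 codebooks, vocab 16, latent dim 8, 5 codes per codebook
codebook_weights = torch.randn(4, 16, 8)
q_indices = torch.randint(0, 16, (4, 5))

def lookup_one(weight, ids):
    return F.embedding(ids, weight)  # (5, 8) lookup for a single codebook

quantized = vmap(lookup_one)(codebook_weights, q_indices)  # (4, 5, 8)
```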
@@ -705,7 +705,7 @@ class ResidualVectorQuantizer(nn.Module):
        """Decode the given codes to the quantized representation."""
        quantized = self.vq.decode(codes)
        return quantized

class HiggsAudioTokenizer(nn.Module):
    def __init__(
        self,

@@ -786,7 +786,7 @@ class HiggsAudioTokenizer(nn.Module):
        x = x[:, 0, :]
        x = F.pad(x, (160, 160))
        target = self.semantic_model(x, output_hidden_states=True).hidden_states
        target = torch.stack(target, dim=1)
        target = target.mean(1)

View File

@@ -413,7 +413,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
        dit_config["qkv_bias"] = False
        dit_config["guidance_cond_proj_dim"] = None  # f"{key_prefix}t_embedder.cond_proj.weight" in state_dict_keys
        return dit_config

    if "{}layers.27.audio_post_attention_layernorm.weight".format(key_prefix) in state_dict_keys:
        autoregressive_config = {}

View File

@@ -1307,8 +1307,8 @@ class Higgsv2(supported_models_base.BASE):
    def get_model(self, state_dict, prefix="", device=None):
        out = model_base.Higgsv2(self, device=device)
        return out

    def clip_target(self, state_dict = {}):
        return supported_models_base.ClipTarget(comfy.text_encoders.higgsv2.DummyTokenizer, comfy.text_encoders.higgsv2.HiggsTokenizer)

View File

@@ -13,7 +13,7 @@ class DummyTokenizer:
def revert_delay_pattern_vectorized(data: torch.Tensor) -> torch.Tensor:
    num_codebooks, total_len = data.shape
    seq_len = total_len - num_codebooks + 1

    col_idx = torch.arange(seq_len, device=data.device)[None, :] \
        + torch.arange(num_codebooks, device=data.device)[:, None]
    out = data[torch.arange(num_codebooks)[:, None], col_idx]
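In the delay pattern, codebook k is shifted right by k steps, and the gather above undoes that shift by reading each row starting at its own offset. A toy illustration with hand-picked values, only for clarity:

```python
import torch

# 3 codebooks over a delayed length of 5  ->  seq_len = 5 - 3 + 1 = 3
data = torch.tensor([
    [1, 2, 3, 0, 0],   # codebook 0: not delayed
    [0, 1, 2, 3, 0],   # codebook 1: delayed by one step
    [0, 0, 1, 2, 3],   # codebook 2: delayed by two steps
])
print(revert_delay_pattern_vectorized(data))
# tensor([[1, 2, 3],
#         [1, 2, 3],
#         [1, 2, 3]])
```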
@@ -27,7 +27,7 @@ class HiggsTokenizer(nn.Module):
        self.device = device
        self.dtypes = [torch.float32]

        here = os.path.dirname(__file__)
        tokenizer_path = os.path.join(here, "higgs_text_tokenizer")
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
@@ -65,16 +65,16 @@ class HiggsTokenizer(nn.Module):
            # due to instability issues, I had to convert the audio tokenizer to float32 to avoid outputting NaNs
            self.audio_tokenizer = self.audio_tokenizer.to(self.dtype)
        torch.cuda.synchronize()

        for audio in audio_tokens:
            vq_code = revert_delay_pattern_vectorized(audio).clip(0, self.audio_codebook_size - 1)[:, 1:-1]
            wv_numpy = self.audio_tokenizer.decode(vq_code.unsqueeze(0))[0, 0]
            outputs.append(wv_numpy)

        return (None, {"waveform": torch.stack(outputs, dim = 0).unsqueeze(1), "sample_rate": self.audio_tokenizer.sample_rate})  # audio only

    def load_state_dict(self, sd, strict = False):
        return self.audio_tokenizer.load_state_dict(sd, strict = strict)

    def state_dict(self):
        return self.audio_tokenizer.state_dict()

View File

@@ -244,10 +244,10 @@ class Attention(nn.Module):
        output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True)
        out = self.o_proj(output)

        if past_key_value is not None:
            return out, past_key_value
        return out

class MLP(nn.Module):

View File

@@ -9,19 +9,17 @@ import folder_paths
import os
import io
import json
import base64
import random
import hashlib
import numpy as np
import node_helpers
from io import BytesIO
from comfy.cli_args import args
from comfy.comfy_types import IO
from comfy.comfy_types import FileLocator
from dataclasses import asdict
from comfy.ldm.higgsv2.loudness import loudness
from comfy.ldm.higgsv2.preprocess import (
    prepare_chatml_sample, Message, ChatMLSample, ChatMLDatasetSample, AudioContent, transcript_normalize
)

AUDIO_PLACEHOLDER_TOKEN = "<|__AUDIO_PLACEHOLDER__|>"
@@ -29,7 +27,7 @@ AUDIO_PLACEHOLDER_TOKEN = "<|__AUDIO_PLACEHOLDER__|>"
MULTISPEAKER_DEFAULT_SYSTEM_MESSAGE = """You are an AI assistant designed to convert text into speech.
If the user's message includes a [SPEAKER*] tag, do not read out the tag and generate speech for the following text, using the specified voice.
If no speaker tag is present, select a suitable voice on your own."""

class LoudnessNormalization:
    CATEGORY = "audio"

@@ -42,7 +40,7 @@ class LoudnessNormalization:
                "block_size": ("FLOAT", {"default": 0.400, "min": 0.1, "max": 1.0, "step": 0.05}),
                "loudness_threshold": ("FLOAT", {"default": -23.0, "min": -70.0, "max": 0.0, "step": 0.5,
                    "tooltip": "Target loudness in LUFS. Common values are -23.0 (broadcast), -14.0 (streaming)."})}}

    def normalize(self, audio, loudness_threshold, block_size):
        sampling_rate = audio["sample_rate"]
        waveform = audio["waveform"]

@@ -75,10 +73,10 @@ def prepare_chatml_input(
            if audio_content.raw_audio.device != next(clip.audio_tokenizer.parameters()).device:
                audio_content.raw_audio = audio_content.raw_audio.to(next(clip.audio_tokenizer.parameters()).device)

            audio_ids = clip.audio_tokenizer.encode(audio_content.raw_audio, sampling_rate)
            audio_ids_l.append(audio_ids.squeeze(0))

    if len(audio_ids_l) > 0:
        audio_ids_start = torch.tensor(
            np.cumsum(np.array([0] + [audio_ids.shape[1] for audio_ids in audio_ids_l])),

@@ -115,18 +113,18 @@ def prepare_chatml_input(
def postprocess_chatml(text: str) -> str:
    speakers = set(re.findall(r'\[SPEAKER\d+\]', text))
    skip_recon = True

    if len(speakers) > 1:
        parts = text.split('<|eot_id|>')
        # keep the first <|eot_id|> and the last one
        first_eot = parts[0] + '<|eot_id|>'
        middle_parts = ''.join(parts[1:-1])
        last_eot = '<|eot_id|>' + parts[-1]

        text = first_eot + middle_parts + last_eot
        skip_recon = False

    return text, skip_recon
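The split/join above collapses all interior <|eot_id|> boundaries when more than one [SPEAKER*] tag is present, keeping only the first and the last one. A quick toy check of that string manipulation (made-up input):

```python
text = "sys<|eot_id|>a<|eot_id|>b<|eot_id|>c"
parts = text.split('<|eot_id|>')            # ['sys', 'a', 'b', 'c']
first_eot = parts[0] + '<|eot_id|>'
middle_parts = ''.join(parts[1:-1])         # 'ab'
last_eot = '<|eot_id|>' + parts[-1]
print(first_eot + middle_parts + last_eot)  # sys<|eot_id|>ab<|eot_id|>c
```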
class CreateChatMLSample:

@@ -165,30 +163,30 @@ class CreateChatMLSample:
        if audio is not None:
            clip.load_model()
            if hasattr(clip, "cond_stage_model"):
                clip = clip.cond_stage_model

        text = transcript_normalize(text)

        messages = []
        lines = text.splitlines()

        sampling_rate = False
        current_role = None
        collecting_system = False
        system_buffer = []

        for line in lines:
            line = line.strip()
            if not line:
                continue

            # system start
            if line.lower().startswith("system:"):
                collecting_system = True
                system_buffer.append(line[len("system:"):].strip())
                continue

            # while collecting system prompt
            if collecting_system:
                system_buffer.append(line)

@@ -198,7 +196,7 @@ class CreateChatMLSample:
                system_buffer = []
                collecting_system = False
                continue

            # speaker lines SPEAKER-0: text
            match = re.match(r"SPEAKER-(\d+):\s*(.*)", line, re.IGNORECASE)
            if match:
@@ -245,12 +243,12 @@ class CreateChatMLSample:
            chat_ml_sample,
            clip.tokenizer,
        )

        if audio is None:
            audio_contents = None

        out = prepare_chatml_input(clip, input_tokens, audio_contents, sampling_rate = sampling_rate)
        return (out,)

class EmptyLatentAudio:
    def __init__(self):
        self.device = comfy.model_management.intermediate_device()

@@ -612,7 +610,7 @@ NODE_CLASS_MAPPINGS = {
    "PreviewAudio": PreviewAudio,
    "ConditioningStableAudio": ConditioningStableAudio,
    "LoudnessNormalization": LoudnessNormalization,
    "CreateChatMLSample": CreateChatMLSample,
    "RecordAudio": RecordAudio,
}

@@ -626,6 +624,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
    "SaveAudioMP3": "Save Audio (MP3)",
    "SaveAudioOpus": "Save Audio (Opus)",
    "LoudnessNormalization": "Loudness Normalization",
    "CreateChatMLSample": "Create ChatML Sample",
    "RecordAudio": "Record Audio",
}

View File

@@ -1571,7 +1571,7 @@ class AutoRegressiveGeneration:
                "do_sample": ("BOOLEAN", {"default": False, "tooltip": "Add randomness in decoding the tokens."}),
            }
        }

    RETURN_TYPES = ("TOKENS",)
    FUNCTION = "generate"

@@ -1582,7 +1582,7 @@ class AutoRegressiveGeneration:
    def generate(self, model, input_ids, seed, max_new_length, min_new_length, top_k, top_p, temperature, do_sample):
        return (auto_sample(self, model, input_ids, max_new_length, min_new_length, top_k, top_p, temperature, do_sample, seed = seed),)

class DecodeTokens:
    @classmethod
    def INPUT_TYPES(s):

@@ -1591,7 +1591,7 @@ class DecodeTokens:
                "clip": (IO.CLIP, {"tooltip": "The model used for generation."}),
                "tokens": ("TOKENS", ),}
        }

    FUNCTION = "decode"
    CATEGORY = "conditioning"
    RETURN_TYPES = ("TEXT", "AUDIO")