mirror of https://github.com/comfyanonymous/ComfyUI.git

styling fixes

parent 6e9335d638
commit 57c15f970c
@@ -19,27 +19,27 @@ def estimate_autoregressive_vram(
max_seq_len: int,
batch_size: int = 1,
dtype = torch.float16,
intermediate_factor: float = 4.0,
device = torch.device('cuda')
) -> bool:

dtype_size = torch.finfo(dtype).bits // 8
kv_cache_bytes = num_layers * max_seq_len * hidden_dim * 2 * batch_size * dtype_size

# we only calculate hidden states in cuda graphs, so we don't care about the output logits
input_bytes = output_bytes = batch_size * max_seq_len * hidden_dim * dtype_size

# rough calculation for activation sizes
intermediate_bytes = intermediate_factor * output_bytes

total_estimated = kv_cache_bytes + input_bytes + output_bytes + intermediate_bytes

# get vram info
free_vram = get_free_memory(device)
minimum_vram = minimum_inference_memory()

enough_vram = free_vram - minimum_vram >= total_estimated

return enough_vram

class TopKLogits:
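For a sense of scale, the kv_cache_bytes term above dominates the estimate. A quick worked example of the same arithmetic (the layer count, hidden size, and sequence length below are hypothetical, chosen only to illustrate the formula):

import torch

# Hypothetical dimensions, for illustration only.
num_layers, hidden_dim, max_seq_len, batch_size = 32, 4096, 8192, 1
dtype_size = torch.finfo(torch.float16).bits // 8  # 2 bytes per element

# factor of 2 accounts for keys and values, matching kv_cache_bytes above
kv_cache_bytes = num_layers * max_seq_len * hidden_dim * 2 * batch_size * dtype_size
print(kv_cache_bytes / 2**30)  # 4.0 GiB for this configuration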
@@ -64,7 +64,7 @@ class TemperatureLogitsWarper:
def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor:
scores_processed = scores / self.temperature
return scores_processed

class TopPLogitsWarper:
def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
top_p = float(top_p)
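The warpers in this hunk follow the usual logits-processor shape: take a scores tensor, return a filtered copy. As a generic sketch of what a top-p (nucleus) filter does with the filter_value and min_tokens_to_keep parameters seen in the signature above (illustrative only, not the body from the commit):

import torch

def top_p_filter(scores: torch.Tensor, top_p: float = 0.9,
                 filter_value: float = -float("inf"), min_tokens_to_keep: int = 1) -> torch.Tensor:
    # Sort ascending; the leading entries are the low-probability tail.
    sorted_logits, sorted_indices = torch.sort(scores, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Drop tokens whose cumulative probability mass stays below 1 - top_p ...
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    # ... but always keep at least min_tokens_to_keep of the most likely tokens.
    sorted_indices_to_remove[..., -min_tokens_to_keep:] = False
    indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
    return scores.masked_fill(indices_to_remove, filter_value)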
@@ -175,7 +175,7 @@ class GenerationConfig:

config_dict = {key: value for key, value in config_dict.items() if value is not None}
valid_fields = {f.name for f in fields(cls)}

filtered_args = {k: v for k, v in {**config_dict, **kwargs}.items() if k in valid_fields}

generation_config = cls(**filtered_args)

@@ -216,7 +216,7 @@ class AutoRegressiveGeneration:
self.model.cache_config = self.cache_config

self.kv_caches = {

length: StaticCache(
config=self.cache_config,
max_batch_size = self.cache_config.max_batch,

@@ -234,8 +234,8 @@ class AutoRegressiveGeneration:

# cuda graphs only help if input shapes are constant
if (
device == "cuda"
and hasattr(model, "capture_model")
and self.model.cache_implementation == "static"
and self.model.use_kv_buckets
and enough_vram

@@ -247,7 +247,7 @@ class AutoRegressiveGeneration:
@torch.inference_mode()
def generate(self, input_ids: Optional[torch.LongTensor] = None, max_new_length: int = 1024, min_new_length = 0,
top_k: int = 50, top_p: float = 1.0, temperature: float = 1.0, do_sample: bool = False, seed = None, **kwargs):

if seed is not None:
torch_generator = torch.Generator(device = input_ids.device).manual_seed(seed)
else:

@@ -335,7 +335,7 @@ class AutoRegressiveGeneration:
# TODO: have a default self._sample fn and a default check if the model supports autoregGen or not
if not hasattr(self.model, "_sample"):
raise ValueError("Model doesn't support AutoRegressive Generation!")

self._prepare_kv_caches()

result = self.model._sample(

@@ -347,7 +347,7 @@ class AutoRegressiveGeneration:
)

return result

def _prepare_kv_caches(self):
for kv_cache in self.kv_caches.values():
kv_cache.reset()

@@ -357,13 +357,13 @@ class AutoRegressiveGeneration:
return GenerationSampling.BEAM_SAMPLING
else:
return GenerationSampling.GREEDY_SEARCH

def _prepare_generated_length(
self,
generation_config: GenerationConfig,
input_ids_length,
):

""" max_length = user_input_id_tokens + generation_max_length """

if generation_config.max_new_length is not None:

@@ -374,11 +374,11 @@ class AutoRegressiveGeneration:
generation_config.min_length = generation_config.min_new_length + input_ids_length

return generation_config

def _get_cache(
self, cache_implementation: str, batch_size: int, max_cache_len: int, device: torch.device, model_kwargs
) -> Cache:

assert cache_implementation == "static", f"Only 'static' cache is supported, got {cache_implementation}"

cache_cls: Cache = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation]

@@ -412,7 +412,7 @@ class AutoRegressiveGeneration:

return self.model._cache


def _prepare_cache_for_generation(
self,
generation_config: GenerationConfig,

@@ -466,7 +466,7 @@ class AutoRegressiveGeneration:
model_kwargs = generation_config.update(**kwargs)

return generation_config, model_kwargs

def _validate_generated_length(self, generation_config: GenerationConfig, input_ids_length):
"""Performs validation related to the resulting generated length"""

@@ -498,7 +498,7 @@ class AutoRegressiveGeneration:
f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix,
UserWarning,
)

def _expand_inputs_for_generation(
self,
expand_size: int = 1,

@@ -526,13 +526,13 @@ class AutoRegressiveGeneration:
model_kwargs = _expand_dict_for_generation(model_kwargs)

return input_ids, model_kwargs

def _prepare_special_tokens(
self,
generation_config: GenerationConfig,
device: Optional[Union[torch.device, str]] = None,
):

def _tensor_or_none(token, device=None):
if token is None:
return token

@@ -564,7 +564,7 @@ class AutoRegressiveGeneration:
generation_config: GenerationConfig,
model_kwargs: dict[str, Any],
) -> torch.LongTensor:

pad_token_id = generation_config._pad_token_tensor
eos_token_id = generation_config._eos_token_tensor

@@ -593,12 +593,12 @@ class AutoRegressiveGeneration:
attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask
)
return attention_mask

def auto_sample(node, patcher, input_ids, max_new_length=1024, min_new_length=0, top_k=50, top_p=1.0, temperature=1.0, do_sample = False, seed=None, **kwargs):
# to work with BaseModel
if hasattr(patcher, "model") and hasattr(patcher.model, "diffusion_model"):
model = patcher.model.diffusion_model

if node._cached_autoregressive_sampler is None or node._cached_autoregressive_sampler.model is not model:
if model.device != patcher.load_device:
model = model.to(patcher.load_device, dtype=model.dtype)

@@ -610,7 +610,7 @@ def auto_sample(node, patcher, input_ids, max_new_length=1024, min_new_length=0,
kwargs.update({k: v for k, v in input_ids.items() if k != "input_ids"})
else:
main_input_ids = input_ids

device = node._cached_autoregressive_sampler.device

main_input_ids = main_input_ids.to(device)
@@ -22,7 +22,7 @@ class CUDAGraphRunner(nn.Module):

def capture(self, *args, **kwargs):
assert self._graph is None

for _ in range(_NUM_WARMUP_ITERS):
self.model(*args, **kwargs)

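The warmup loop above runs the model eagerly a few times before recording, which is the standard CUDA-graph recipe: shapes and buffer addresses must stay fixed, and replays then launch the whole captured graph at once. A generic sketch of the capture-and-replay pattern (illustrative only; the runner in this diff wraps it with its own buffer management):

import torch

model = torch.nn.Linear(128, 128).cuda()
static_input = torch.zeros(1, 128, device="cuda")  # graphs need fixed shapes and addresses

# warm up eagerly first (as capture() above does), then record the graph
for _ in range(3):
    model(static_input)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = model(static_input)

# replay: copy new data into the captured input buffer, then launch the whole graph
static_input.copy_(torch.randn(1, 128, device="cuda"))
graph.replay()
result = static_output.clone()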
@@ -125,7 +125,7 @@ class IIRfilter(object):
@property
def b_and_a(self):
return self.generate_coefficients()

class Meter(torch.nn.Module):

def __init__(

@@ -227,7 +227,7 @@ class Meter(torch.nn.Module):
return unfolded

def integrated_loudness(self, data: torch.Tensor):

if not torch.is_tensor(data):
data = torch.from_numpy(data).float()
else:

@@ -291,10 +291,10 @@ class Meter(torch.nn.Module):

def loudness(
audio_data, sample_rate: int, target_loudness: int, filter_class: str = "K-weighting", block_size: float = 0.400, **kwargs
):
MIN_LOUDNESS = -70
device = audio_data.device

original_length = audio_data.shape[-1]
signal_duration = original_length / sample_rate

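The loudness helpers here follow the usual LUFS recipe: measure integrated loudness with a K-weighting filter over roughly 0.4 s blocks, then apply a uniform gain toward the target. A minimal sketch of the gain step (function name and numbers are illustrative; the measurement itself is done by the Meter above):

import torch

def apply_loudness_gain(waveform: torch.Tensor, measured_lufs: float, target_lufs: float) -> torch.Tensor:
    # LUFS are on a dB scale, so the correction is a simple multiplicative gain.
    gain_db = target_lufs - measured_lufs
    return waveform * (10.0 ** (gain_db / 20.0))

# e.g. a clip measured at -30 LUFS pushed to the -23 LUFS broadcast target gets +7 dB of gain.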
@@ -19,8 +19,8 @@ from collections import defaultdict, OrderedDict
from typing import Optional, Tuple, Union, List

class GenerationMode(Enum):
TEXT = 0
AUDIO_INIT = 1
AUDIO_IN_PROGRESS = 2

def _ignore_causal_mask_sdpa(

@@ -413,7 +413,7 @@ class HiggsAudioModel(nn.Module):

def __init__(self, device = None, dtype = None, operations = None, **kwargs):
super().__init__()

self.padding_idx = kwargs["pad_token_id"]
self.audio_in_token_idx = kwargs["audio_in_token_idx"]
self.audio_out_token_idx = kwargs["audio_out_token_idx"]

@@ -439,7 +439,7 @@ class HiggsAudioModel(nn.Module):

self.audio_out_bos_token_id = 128013
self.audio_eos_token_id = 128012

text_config = kwargs["text_config"]
llama_config = Llama2Config(num_attention_heads = text_config["num_attention_heads"],
num_key_value_heads = text_config["num_key_value_heads"],

@@ -616,7 +616,7 @@ class HiggsAudioModel(nn.Module):
next_audio_tokens = None

return next_tokens, next_audio_tokens

def _update_causal_mask(
self,
attention_mask: torch.Tensor,

@@ -677,7 +677,7 @@ class HiggsAudioModel(nn.Module):
causal_mask = causal_mask.mul(~torch.all(causal_mask == min_dtype, dim=-1, keepdim=True))

return causal_mask

def _embed_audio_ids(self, audio_ids):
codebook_shift = (
torch.arange(self.config["audio_num_codebooks"], device=audio_ids.device) * self.audio_codebook_size

@@ -712,7 +712,7 @@ class HiggsAudioModel(nn.Module):
)
audio_attention_mask = attention_mask.masked_fill(no_audio_out_mask, min_dtype)
return fast_forward_attention_mask, audio_attention_mask

def _forward_core(
self,
hidden_states: torch.Tensor,

@@ -728,7 +728,7 @@ class HiggsAudioModel(nn.Module):
is_using_cuda_graph: Optional[bool] = False,
):

position_id_offset = cache_position[0] if use_cache else 0
position_embeddings = self.rotary_emb(hidden_states, position_ids + position_id_offset)

for decoder_layer in self.layers:

@@ -927,7 +927,7 @@ class HiggsAudioModel(nn.Module):
)

return ret

def _update_model_kwargs_for_generation(
self,
outputs,

@@ -956,13 +956,13 @@ class HiggsAudioModel(nn.Module):
)

return model_kwargs

def _copy_kv_cache(self, from_cache: Cache, to_cache: Cache):
from_cache_size = from_cache.get_max_cache_shape()
assert to_cache.get_max_cache_shape() >= from_cache_size, (
f"The target cache size {to_cache.get_max_cache_shape()} is smaller than the source cache size {from_cache_size}."
)

n_layers = self.num_hidden_layers

for i in range(n_layers):

@@ -977,7 +977,7 @@ class HiggsAudioModel(nn.Module):
self.cache_config.head_dim),
device=self.device, dtype=self.dtype
)

if getattr(to_layer, "values", None) is None:
to_layer.values = torch.zeros(
(self.cache_config.max_batch, self.cache_config.num_key_value_heads,

@@ -1011,7 +1011,7 @@ class HiggsAudioModel(nn.Module):
f"The current sequence length {current_sequence_length} is larger than "
f"all past key values buckets {past_key_values_buckets.keys()}."
)

def _sample(
self,
input_ids: torch.LongTensor,

@@ -1020,7 +1020,7 @@ class HiggsAudioModel(nn.Module):
past_key_values_buckets: Optional[OrderedDict[int, Cache]],
**model_kwargs,
):

# code supports only non-mixed batchs

audio_out_bos_token_id = generation_config.generation_kwargs.get("audio_out_bos_token_id", None)

@@ -1069,7 +1069,7 @@ class HiggsAudioModel(nn.Module):

while not this_peer_finished:
eos_token_tensor = torch.tensor([self.config["text_config"]["eos_token_id"]], device=input_ids.device)

if input_ids[0][-1] == audio_out_bos_token_id:
generation_mode = GenerationMode.AUDIO_INIT
elif input_ids[0][-1] == self.audio_out_token_idx:

@@ -1211,7 +1211,7 @@ class HiggsAudioModel(nn.Module):
pbar.update(pbar.total - pbar.current)

return audio_sequences

@torch.inference_mode()
def generate(
self,

@@ -1222,7 +1222,7 @@ class HiggsAudioModel(nn.Module):
generation_functions = None,
**kwargs,
):

if generation_config is None:
generation_config = GenerationConfig()

@@ -1,4 +1,4 @@
from typing import (
Dict, List, Optional, Union, Tuple, MutableMapping, Any, Mapping, Collection, get_type_hints, get_args, get_origin
)

@@ -153,13 +153,13 @@ def from_dict(data_class, data):
value = _build_value(type_=field_type, data=data[key])
except Exception as error:
raise ValueError(error)

if not is_instance(value, field_type):
raise ValueError((
f'wrong value type for field "{field.name}" - should be "{field_type}" '
f'instead of value "{value}" of type "{type(value)}"'
))

init_values[field.name] = value

instance = data_class(**init_values)

@@ -273,7 +273,7 @@ def prepare_chatml_sample(sample: Union[ChatMLSample, Dict], tokenizer):
try:
if not isinstance(sample, ChatMLSample):

# replacing pd.isna
def is_nan(x):
if isinstance(x, float):
return math.isnan(x)

@@ -282,7 +282,7 @@ def prepare_chatml_sample(sample: Union[ChatMLSample, Dict], tokenizer):
if isinstance(x, torch.Tensor) and x.numel() == 1:
return torch.isnan(x).item()
return False

if "speaker" in sample and is_nan(sample["speaker"]):
sample["speaker"] = None
if "start_index" in sample and is_nan(sample["start_index"]):

@@ -489,7 +489,7 @@ class HiggsAudioSampleCollator:
def _process_and_duplicate_audio_tokens(
self, input_ids: torch.Tensor, audio_idx: int, wv: torch.Tensor, labels: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, int]:

total_samples = len(wv)
num_chunks = math.ceil(total_samples / self.chunk_size_samples)

@@ -583,15 +583,15 @@ class HiggsAudioSampleCollator:
audio_out_no_train_flag = torch.cat(audio_out_no_train_flag, dim=0)

if len(audio_in_ids_l) > 0:

# I tried to remove the for-loop in original implementation
# but to do batching with padding caused problem so I turned it into a list compre.
lengths = [seg.shape[1] for seg in audio_in_ids_l]
aug_lengths = [l + 2 for l in lengths]
audio_in_ids_start = torch.cumsum(
torch.tensor([0] + aug_lengths[:-1], dtype=torch.long), dim=0
)

if self.disable_audio_codes_transform:
audio_in_ids = torch.cat(audio_in_ids_l, dim=1).long()
else:
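The offset bookkeeping in that hunk is easier to see with toy numbers; the "+ 2" appears to leave room for a bos/eos pair around each audio segment. An illustration with made-up lengths:

import torch

# Toy segment lengths, for illustration only.
lengths = [5, 3, 7]
aug_lengths = [l + 2 for l in lengths]              # [7, 5, 9]
audio_in_ids_start = torch.cumsum(
    torch.tensor([0] + aug_lengths[:-1], dtype=torch.long), dim=0
)
# tensor([0, 7, 12]): where each augmented segment begins once everything is concatenated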
@@ -607,7 +607,7 @@ class HiggsAudioSampleCollator:
if self.use_delay_pattern:
with_tokens = [
build_delay_pattern_mask(
tok.unsqueeze(0),
bos_token_id=self.audio_stream_bos_id,
pad_token_id=self.audio_stream_eos_id
)[0]

@@ -160,7 +160,7 @@ class DACDecoder(nn.Module):

def forward(self, x):
return self.model(x)

class Conv1d1x1:
def __new__(cls, in_channels, out_channels, bias=True, device=None, dtype=None, operations=None):
operations = operations or nn

@@ -168,7 +168,7 @@ class Conv1d1x1:
in_channels, out_channels, kernel_size=1,
bias=bias, device=device, dtype=dtype
)

class Conv1d(nn.Module):
def __init__(
self,

@@ -203,7 +203,7 @@ class Conv1d(nn.Module):
def forward(self, x):
x = self.conv(x)
return x

class ConvTranspose1d(nn.Module):
def __init__(
self,

@@ -237,7 +237,7 @@ class ConvTranspose1d(nn.Module):
def forward(self, x):
x = self.deconv(x)
return x

class ResidualUnit(nn.Module):
def __init__(
self,

@@ -283,7 +283,7 @@ class EncoderBlock(nn.Module):
self.conv = Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size = kernel_size,
stride=stride,
bias=bias,
device = device, dtype = dtype, operations = operations

@@ -435,7 +435,7 @@ class Decoder(nn.Module):
x = self.conv_blocks[i](x)
x = self.conv2(x)
return x

class HiggsAudioFeatureExtractor(nn.Module):
def __init__(self, sampling_rate=16000):
super().__init__()

@@ -493,7 +493,7 @@ class EuclideanCodebook(nn.Module):
embed = self.embed.t()
if x.dtype != embed.dtype:
x = x.to(embed.dtype)

dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + embed.pow(2).sum(0, keepdim=True))
embed_ind = dist.max(dim=-1).indices
return embed_ind
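The dist expression above is the expanded form of the negative squared Euclidean distance, -(||x||^2 - 2 x.e + ||e||^2), so taking its argmax selects the nearest codebook entry. A small self-contained check (the shapes and names here are illustrative, not the module's own tensors):

import torch

x = torch.randn(5, 8)        # flattened inputs [N, D]
embed = torch.randn(8, 16)   # transposed codebook [D, K], as with self.embed.t() above
dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + embed.pow(2).sum(0, keepdim=True))
nearest = dist.max(dim=-1).indices
# Agrees with a direct nearest-neighbour search over the codebook rows:
assert torch.equal(nearest, torch.cdist(x, embed.t()).argmin(dim=-1))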
@@ -520,14 +520,14 @@ class EuclideanCodebook(nn.Module):
return quantize

def forward(self, x):
orig_shape = x.shape # [B, T, D]
flat = x.view(-1, x.shape[-1]) # [B*T, D]

embed_ind = self.quantize(flat)
embed_ind = self.postprocess_emb(embed_ind, orig_shape)
# now embed_ind has shape [B, T]

quantize = self.dequantize(embed_ind)
# quantize: [B, T, D]

return quantize, embed_ind

@@ -636,9 +636,9 @@ class ResidualVectorQuantization(nn.Module):
quantized = F.embedding(embed_id, codebook_weight).transpose(1, 2) # (B, D, T)
quantized = F.linear(quantized.transpose(1, 2), proj_weight, proj_biases).transpose(1, 2)
return quantized

codebook_weights = torch.stack([q._codebook.embed for q in self.layers]) # (n_codebooks, vocab_size, D)
proj_weights = torch.stack([q.project_out.weight for q in self.layers])

quantized = vmap(decode_one)(codebook_weights, proj_weights, q_indices, biases)

@@ -705,7 +705,7 @@ class ResidualVectorQuantizer(nn.Module):
"""Decode the given codes to the quantized representation."""
quantized = self.vq.decode(codes)
return quantized

class HiggsAudioTokenizer(nn.Module):
def __init__(
self,

@@ -786,7 +786,7 @@ class HiggsAudioTokenizer(nn.Module):
x = x[:, 0, :]
x = F.pad(x, (160, 160))
target = self.semantic_model(x, output_hidden_states=True).hidden_states
target = torch.stack(target, dim=1)

target = target.mean(1)

@@ -413,7 +413,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["qkv_bias"] = False
dit_config["guidance_cond_proj_dim"] = None#f"{key_prefix}t_embedder.cond_proj.weight" in state_dict_keys
return dit_config

if "{}layers.27.audio_post_attention_layernorm.weight".format(key_prefix) in state_dict_keys:

autoregressive_config = {}

@@ -1307,8 +1307,8 @@ class Higgsv2(supported_models_base.BASE):

def get_model(self, state_dict, prefix="", device=None):
out = model_base.Higgsv2(self, device=device)
return out

def clip_target(self, state_dict = {}):
return supported_models_base.ClipTarget(comfy.text_encoders.higgsv2.DummyTokenizer, comfy.text_encoders.higgsv2.HiggsTokenizer)

@@ -13,7 +13,7 @@ class DummyTokenizer:
def revert_delay_pattern_vectorized(data: torch.Tensor) -> torch.Tensor:
num_codebooks, total_len = data.shape
seq_len = total_len - num_codebooks + 1

col_idx = torch.arange(seq_len, device=data.device)[None, :] \
+ torch.arange(num_codebooks, device=data.device)[:, None]
out = data[torch.arange(num_codebooks)[:, None], col_idx]
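revert_delay_pattern_vectorized undoes the per-codebook shift of a delay pattern: codebook k reads positions k through k + seq_len - 1. A toy example, with values made up purely to show the indexing:

import torch

# Hypothetical input: 3 codebooks with the delay pattern applied, so codebook k is shifted right by k.
# total_len = 5, hence the reverted sequence length is 5 - 3 + 1 = 3.
data = torch.tensor([
    [10, 11, 12,  0,  0],   # codebook 0: valid tokens at positions 0..2
    [ 0, 20, 21, 22,  0],   # codebook 1: shifted right by 1
    [ 0,  0, 30, 31, 32],   # codebook 2: shifted right by 2
])
num_codebooks, total_len = data.shape
seq_len = total_len - num_codebooks + 1
col_idx = torch.arange(seq_len)[None, :] + torch.arange(num_codebooks)[:, None]
out = data[torch.arange(num_codebooks)[:, None], col_idx]
# out == tensor([[10, 11, 12], [20, 21, 22], [30, 31, 32]])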
@@ -27,7 +27,7 @@ class HiggsTokenizer(nn.Module):
self.device = device
self.dtypes = [torch.float32]

here = os.path.dirname(__file__)
tokenizer_path = os.path.join(here, "higgs_text_tokenizer")

self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

@@ -65,16 +65,16 @@ class HiggsTokenizer(nn.Module):
# due to instability issues, I had to convert the audio tokenizer to float32, avoiding outputing nans
self.audio_tokenizer = self.audio_tokenizer.to(self.dtype)
torch.cuda.synchronize()

for audio in audio_tokens:
vq_code = revert_delay_pattern_vectorized(audio).clip(0, self.audio_codebook_size - 1)[:, 1:-1]
wv_numpy = self.audio_tokenizer.decode(vq_code.unsqueeze(0))[0, 0]
outputs.append(wv_numpy)

return (None, {"waveform": torch.stack(outputs, dim = 0).unsqueeze(1), "sample_rate": self.audio_tokenizer.sample_rate}) # audio only

def load_state_dict(self, sd, strict = False):
return self.audio_tokenizer.load_state_dict(sd, strict = strict)

def state_dict(self):
return self.audio_tokenizer.state_dict()

@@ -244,10 +244,10 @@ class Attention(nn.Module):

output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True)
out = self.o_proj(output)

if past_key_value is not None:
return out, past_key_value

return out

class MLP(nn.Module):

@@ -9,19 +9,17 @@ import folder_paths
import os
import io
import json
-import base64
import random
import hashlib
import numpy as np
import node_helpers
-from io import BytesIO
from comfy.cli_args import args
from comfy.comfy_types import IO
from comfy.comfy_types import FileLocator
from dataclasses import asdict
from comfy.ldm.higgsv2.loudness import loudness
from comfy.ldm.higgsv2.preprocess import (
-prepare_chatml_sample, Message, ChatMLSample, ChatMLDatasetSample, AudioContent, TextContent, transcript_normalize
+prepare_chatml_sample, Message, ChatMLSample, ChatMLDatasetSample, AudioContent, transcript_normalize
)

AUDIO_PLACEHOLDER_TOKEN = "<|__AUDIO_PLACEHOLDER__|>"

@@ -29,7 +27,7 @@ AUDIO_PLACEHOLDER_TOKEN = "<|__AUDIO_PLACEHOLDER__|>"
MULTISPEAKER_DEFAULT_SYSTEM_MESSAGE = """You are an AI assistant designed to convert text into speech.
If the user's message includes a [SPEAKER*] tag, do not read out the tag and generate speech for the following text, using the specified voice.
If no speaker tag is present, select a suitable voice on your own."""

class LoudnessNormalization:

CATEGORY = "audio"

@@ -42,7 +40,7 @@ class LoudnessNormalization:
"block_size": ("FLOAT", {"default": 0.400, "min": 0.1, "max": 1.0, "step": 0.05}),
"loudness_threshold": ("FLOAT", {"default": -23.0, "min": -70.0, "max": 0.0, "step": 0.5,
"tooltip": "Target loudness in LUFS. Common values are -23.0 (broadcast), -14.0 (streaming)."})}}

def normalize(self, audio, loudness_threshold, block_size):
sampling_rate = audio["sample_rate"]
waveform = audio["waveform"]

@@ -75,10 +73,10 @@ def prepare_chatml_input(

if audio_content.raw_audio.device != next(clip.audio_tokenizer.parameters()).device:
audio_content.raw_audio = audio_content.raw_audio.to(next(clip.audio_tokenizer.parameters()).device)

audio_ids = clip.audio_tokenizer.encode(audio_content.raw_audio, sampling_rate)
audio_ids_l.append(audio_ids.squeeze(0))

if len(audio_ids_l) > 0:
audio_ids_start = torch.tensor(
np.cumsum(np.array([0] + [audio_ids.shape[1] for audio_ids in audio_ids_l])),

@@ -115,18 +113,18 @@ def prepare_chatml_input(
def postprocess_chatml(text: str) -> str:
speakers = set(re.findall(r'\[SPEAKER\d+\]', text))
skip_recon = True

if len(speakers) > 1:
parts = text.split('<|eot_id|>')

# keep the first <|eot_id|> and the last one
first_eot = parts[0] + '<|eot_id|>'
middle_parts = ''.join(parts[1:-1])
last_eot = '<|eot_id|>' + parts[-1]

text = first_eot + middle_parts + last_eot
skip_recon = False

return text, skip_recon

class CreateChatMLSample:

@@ -165,30 +163,30 @@ class CreateChatMLSample:

if audio is not None:
clip.load_model()

if hasattr(clip, "cond_stage_model"):
clip = clip.cond_stage_model

text = transcript_normalize(text)

messages = []
lines = text.splitlines()
sampling_rate = False
current_role = None
collecting_system = False
system_buffer = []

for line in lines:
line = line.strip()
if not line:
continue

# system start
if line.lower().startswith("system:"):
collecting_system = True
system_buffer.append(line[len("system:"):].strip())
continue

# while collecting system prompt
if collecting_system:
system_buffer.append(line)

@@ -198,7 +196,7 @@ class CreateChatMLSample:
system_buffer = []
collecting_system = False
continue

# speaker lines SPEAKER-0: text
match = re.match(r"SPEAKER-(\d+):\s*(.*)", line, re.IGNORECASE)
if match:
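The SPEAKER parsing above expects transcript lines of the form "SPEAKER-<id>: text". A minimal illustration of what the regex captures (the example string is hypothetical):

import re

match = re.match(r"SPEAKER-(\d+):\s*(.*)", "SPEAKER-0: Hello there.", re.IGNORECASE)
speaker_id, speech = match.group(1), match.group(2)
# speaker_id == "0", speech == "Hello there."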
@@ -245,12 +243,12 @@ class CreateChatMLSample:
chat_ml_sample,
clip.tokenizer,
)

if audio is None:
audio_contents = None
out = prepare_chatml_input(clip, input_tokens, audio_contents, sampling_rate = sampling_rate)
return (out,)

class EmptyLatentAudio:
def __init__(self):
self.device = comfy.model_management.intermediate_device()

@@ -612,7 +610,7 @@ NODE_CLASS_MAPPINGS = {
"PreviewAudio": PreviewAudio,
"ConditioningStableAudio": ConditioningStableAudio,
"LoudnessNormalization": LoudnessNormalization,
-"CreateChatMLSample": CreateChatMLSample
+"CreateChatMLSample": CreateChatMLSample,
"RecordAudio": RecordAudio,
}

@@ -626,6 +624,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"SaveAudioMP3": "Save Audio (MP3)",
"SaveAudioOpus": "Save Audio (Opus)",
"LoudnessNormalization": "Loudness Normalization",
-"CreateChatMLSample": "Create ChatML Sample"
+"CreateChatMLSample": "Create ChatML Sample",
"RecordAudio": "Record Audio",
}

nodes.py

@@ -1571,7 +1571,7 @@ class AutoRegressiveGeneration:
"do_sample": ("BOOLEAN", {"default": False, "tooltip": "Add randomness in decoding the tokens."}),
}
}

RETURN_TYPES = ("TOKENS",)
FUNCTION = "generate"

@@ -1582,7 +1582,7 @@ class AutoRegressiveGeneration:

def generate(self, model, input_ids, seed, max_new_length, min_new_length, top_k, top_p, temperature, do_sample):
return (auto_sample(self, model, input_ids, max_new_length, min_new_length, top_k, top_p, temperature, do_sample, seed = seed),)

class DecodeTokens:
@classmethod
def INPUT_TYPES(s):

@@ -1591,7 +1591,7 @@ class DecodeTokens:
"clip": (IO.CLIP, {"tooltip": "The model used for generation."}),
"tokens": ("TOKENS", ),}
}

FUNCTION = "decode"
CATEGORY = "conditioning"
RETURN_TYPES = ("TEXT", "AUDIO")