mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-16 01:00:49 +08:00
styling fixes
This commit is contained in:
parent
6e9335d638
commit
57c15f970c
@ -19,27 +19,27 @@ def estimate_autoregressive_vram(
|
||||
max_seq_len: int,
|
||||
batch_size: int = 1,
|
||||
dtype = torch.float16,
|
||||
intermediate_factor: float = 4.0,
|
||||
intermediate_factor: float = 4.0,
|
||||
device = torch.device('cuda')
|
||||
) -> bool:
|
||||
|
||||
|
||||
dtype_size = torch.finfo(dtype).bits // 8
|
||||
kv_cache_bytes = num_layers * max_seq_len * hidden_dim * 2 * batch_size * dtype_size
|
||||
|
||||
# we only calculate hidden states in cuda graphs, so we don't care about the output logits
|
||||
input_bytes = output_bytes = batch_size * max_seq_len * hidden_dim * dtype_size
|
||||
|
||||
# we only calculate hidden states in cuda graphs, so we don't care about the output logits
|
||||
input_bytes = output_bytes = batch_size * max_seq_len * hidden_dim * dtype_size
|
||||
|
||||
# rough calculation for activation sizes
|
||||
intermediate_bytes = intermediate_factor * output_bytes
|
||||
|
||||
|
||||
total_estimated = kv_cache_bytes + input_bytes + output_bytes + intermediate_bytes
|
||||
|
||||
|
||||
# get vram info
|
||||
free_vram = get_free_memory(device)
|
||||
minimum_vram = minimum_inference_memory()
|
||||
|
||||
|
||||
enough_vram = free_vram - minimum_vram >= total_estimated
|
||||
|
||||
|
||||
return enough_vram
|
||||
|
||||
class TopKLogits:
|
||||
@ -64,7 +64,7 @@ class TemperatureLogitsWarper:
|
||||
def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
scores_processed = scores / self.temperature
|
||||
return scores_processed
|
||||
|
||||
|
||||
class TopPLogitsWarper:
|
||||
def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
||||
top_p = float(top_p)
|
||||
@ -175,7 +175,7 @@ class GenerationConfig:
|
||||
|
||||
config_dict = {key: value for key, value in config_dict.items() if value is not None}
|
||||
valid_fields = {f.name for f in fields(cls)}
|
||||
|
||||
|
||||
filtered_args = {k: v for k, v in {**config_dict, **kwargs}.items() if k in valid_fields}
|
||||
|
||||
generation_config = cls(**filtered_args)
|
||||
@ -216,7 +216,7 @@ class AutoRegressiveGeneration:
|
||||
self.model.cache_config = self.cache_config
|
||||
|
||||
self.kv_caches = {
|
||||
|
||||
|
||||
length: StaticCache(
|
||||
config=self.cache_config,
|
||||
max_batch_size = self.cache_config.max_batch,
|
||||
@ -234,8 +234,8 @@ class AutoRegressiveGeneration:
|
||||
|
||||
# cuda graphs only help if input shapes are constant
|
||||
if (
|
||||
device == "cuda"
|
||||
and hasattr(model, "capture_model")
|
||||
device == "cuda"
|
||||
and hasattr(model, "capture_model")
|
||||
and self.model.cache_implementation == "static"
|
||||
and self.model.use_kv_buckets
|
||||
and enough_vram
|
||||
@ -247,7 +247,7 @@ class AutoRegressiveGeneration:
|
||||
@torch.inference_mode()
|
||||
def generate(self, input_ids: Optional[torch.LongTensor] = None, max_new_length: int = 1024, min_new_length = 0,
|
||||
top_k: int = 50, top_p: float = 1.0, temperature: float = 1.0, do_sample: bool = False, seed = None, **kwargs):
|
||||
|
||||
|
||||
if seed is not None:
|
||||
torch_generator = torch.Generator(device = input_ids.device).manual_seed(seed)
|
||||
else:
|
||||
@ -335,7 +335,7 @@ class AutoRegressiveGeneration:
|
||||
# TODO: have a default self._sample fn and a default check if the model supports autoregGen or not
|
||||
if not hasattr(self.model, "_sample"):
|
||||
raise ValueError("Model doesn't support AutoRegressive Generation!")
|
||||
|
||||
|
||||
self._prepare_kv_caches()
|
||||
|
||||
result = self.model._sample(
|
||||
@ -347,7 +347,7 @@ class AutoRegressiveGeneration:
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _prepare_kv_caches(self):
|
||||
for kv_cache in self.kv_caches.values():
|
||||
kv_cache.reset()
|
||||
@ -357,13 +357,13 @@ class AutoRegressiveGeneration:
|
||||
return GenerationSampling.BEAM_SAMPLING
|
||||
else:
|
||||
return GenerationSampling.GREEDY_SEARCH
|
||||
|
||||
|
||||
def _prepare_generated_length(
|
||||
self,
|
||||
generation_config: GenerationConfig,
|
||||
input_ids_length,
|
||||
):
|
||||
|
||||
|
||||
""" max_length = user_input_id_tokens + generation_max_length """
|
||||
|
||||
if generation_config.max_new_length is not None:
|
||||
@ -374,11 +374,11 @@ class AutoRegressiveGeneration:
|
||||
generation_config.min_length = generation_config.min_new_length + input_ids_length
|
||||
|
||||
return generation_config
|
||||
|
||||
|
||||
def _get_cache(
|
||||
self, cache_implementation: str, batch_size: int, max_cache_len: int, device: torch.device, model_kwargs
|
||||
) -> Cache:
|
||||
|
||||
|
||||
assert cache_implementation == "static", f"Only 'static' cache is supported, got {cache_implementation}"
|
||||
|
||||
cache_cls: Cache = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation]
|
||||
@ -412,7 +412,7 @@ class AutoRegressiveGeneration:
|
||||
|
||||
return self.model._cache
|
||||
|
||||
|
||||
|
||||
def _prepare_cache_for_generation(
|
||||
self,
|
||||
generation_config: GenerationConfig,
|
||||
@ -466,7 +466,7 @@ class AutoRegressiveGeneration:
|
||||
model_kwargs = generation_config.update(**kwargs)
|
||||
|
||||
return generation_config, model_kwargs
|
||||
|
||||
|
||||
def _validate_generated_length(self, generation_config: GenerationConfig, input_ids_length):
|
||||
"""Performs validation related to the resulting generated length"""
|
||||
|
||||
@ -498,7 +498,7 @@ class AutoRegressiveGeneration:
|
||||
f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix,
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
|
||||
def _expand_inputs_for_generation(
|
||||
self,
|
||||
expand_size: int = 1,
|
||||
@ -526,13 +526,13 @@ class AutoRegressiveGeneration:
|
||||
model_kwargs = _expand_dict_for_generation(model_kwargs)
|
||||
|
||||
return input_ids, model_kwargs
|
||||
|
||||
|
||||
def _prepare_special_tokens(
|
||||
self,
|
||||
generation_config: GenerationConfig,
|
||||
device: Optional[Union[torch.device, str]] = None,
|
||||
):
|
||||
|
||||
|
||||
def _tensor_or_none(token, device=None):
|
||||
if token is None:
|
||||
return token
|
||||
@ -564,7 +564,7 @@ class AutoRegressiveGeneration:
|
||||
generation_config: GenerationConfig,
|
||||
model_kwargs: dict[str, Any],
|
||||
) -> torch.LongTensor:
|
||||
|
||||
|
||||
pad_token_id = generation_config._pad_token_tensor
|
||||
eos_token_id = generation_config._eos_token_tensor
|
||||
|
||||
@ -593,12 +593,12 @@ class AutoRegressiveGeneration:
|
||||
attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask
|
||||
)
|
||||
return attention_mask
|
||||
|
||||
|
||||
def auto_sample(node, patcher, input_ids, max_new_length=1024, min_new_length=0, top_k=50, top_p=1.0, temperature=1.0, do_sample = False, seed=None, **kwargs):
|
||||
# to work with BaseModel
|
||||
if hasattr(patcher, "model") and hasattr(patcher.model, "diffusion_model"):
|
||||
model = patcher.model.diffusion_model
|
||||
|
||||
|
||||
if node._cached_autoregressive_sampler is None or node._cached_autoregressive_sampler.model is not model:
|
||||
if model.device != patcher.load_device:
|
||||
model = model.to(patcher.load_device, dtype=model.dtype)
|
||||
@ -610,7 +610,7 @@ def auto_sample(node, patcher, input_ids, max_new_length=1024, min_new_length=0,
|
||||
kwargs.update({k: v for k, v in input_ids.items() if k != "input_ids"})
|
||||
else:
|
||||
main_input_ids = input_ids
|
||||
|
||||
|
||||
device = node._cached_autoregressive_sampler.device
|
||||
|
||||
main_input_ids = main_input_ids.to(device)
|
||||
|
||||
@ -22,7 +22,7 @@ class CUDAGraphRunner(nn.Module):
|
||||
|
||||
def capture(self, *args, **kwargs):
|
||||
assert self._graph is None
|
||||
|
||||
|
||||
for _ in range(_NUM_WARMUP_ITERS):
|
||||
self.model(*args, **kwargs)
|
||||
|
||||
|
||||
@ -125,7 +125,7 @@ class IIRfilter(object):
|
||||
@property
|
||||
def b_and_a(self):
|
||||
return self.generate_coefficients()
|
||||
|
||||
|
||||
class Meter(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
@ -227,7 +227,7 @@ class Meter(torch.nn.Module):
|
||||
return unfolded
|
||||
|
||||
def integrated_loudness(self, data: torch.Tensor):
|
||||
|
||||
|
||||
if not torch.is_tensor(data):
|
||||
data = torch.from_numpy(data).float()
|
||||
else:
|
||||
@ -291,10 +291,10 @@ class Meter(torch.nn.Module):
|
||||
|
||||
def loudness(
|
||||
audio_data, sample_rate: int, target_loudness: int, filter_class: str = "K-weighting", block_size: float = 0.400, **kwargs
|
||||
):
|
||||
):
|
||||
MIN_LOUDNESS = -70
|
||||
device = audio_data.device
|
||||
|
||||
|
||||
original_length = audio_data.shape[-1]
|
||||
signal_duration = original_length / sample_rate
|
||||
|
||||
|
||||
@ -19,8 +19,8 @@ from collections import defaultdict, OrderedDict
|
||||
from typing import Optional, Tuple, Union, List
|
||||
|
||||
class GenerationMode(Enum):
|
||||
TEXT = 0
|
||||
AUDIO_INIT = 1
|
||||
TEXT = 0
|
||||
AUDIO_INIT = 1
|
||||
AUDIO_IN_PROGRESS = 2
|
||||
|
||||
def _ignore_causal_mask_sdpa(
|
||||
@ -413,7 +413,7 @@ class HiggsAudioModel(nn.Module):
|
||||
|
||||
def __init__(self, device = None, dtype = None, operations = None, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
|
||||
self.padding_idx = kwargs["pad_token_id"]
|
||||
self.audio_in_token_idx = kwargs["audio_in_token_idx"]
|
||||
self.audio_out_token_idx = kwargs["audio_out_token_idx"]
|
||||
@ -439,7 +439,7 @@ class HiggsAudioModel(nn.Module):
|
||||
|
||||
self.audio_out_bos_token_id = 128013
|
||||
self.audio_eos_token_id = 128012
|
||||
|
||||
|
||||
text_config = kwargs["text_config"]
|
||||
llama_config = Llama2Config(num_attention_heads = text_config["num_attention_heads"],
|
||||
num_key_value_heads = text_config["num_key_value_heads"],
|
||||
@ -616,7 +616,7 @@ class HiggsAudioModel(nn.Module):
|
||||
next_audio_tokens = None
|
||||
|
||||
return next_tokens, next_audio_tokens
|
||||
|
||||
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: torch.Tensor,
|
||||
@ -677,7 +677,7 @@ class HiggsAudioModel(nn.Module):
|
||||
causal_mask = causal_mask.mul(~torch.all(causal_mask == min_dtype, dim=-1, keepdim=True))
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
def _embed_audio_ids(self, audio_ids):
|
||||
codebook_shift = (
|
||||
torch.arange(self.config["audio_num_codebooks"], device=audio_ids.device) * self.audio_codebook_size
|
||||
@ -712,7 +712,7 @@ class HiggsAudioModel(nn.Module):
|
||||
)
|
||||
audio_attention_mask = attention_mask.masked_fill(no_audio_out_mask, min_dtype)
|
||||
return fast_forward_attention_mask, audio_attention_mask
|
||||
|
||||
|
||||
def _forward_core(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -728,7 +728,7 @@ class HiggsAudioModel(nn.Module):
|
||||
is_using_cuda_graph: Optional[bool] = False,
|
||||
):
|
||||
|
||||
position_id_offset = cache_position[0] if use_cache else 0
|
||||
position_id_offset = cache_position[0] if use_cache else 0
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids + position_id_offset)
|
||||
|
||||
for decoder_layer in self.layers:
|
||||
@ -927,7 +927,7 @@ class HiggsAudioModel(nn.Module):
|
||||
)
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
def _update_model_kwargs_for_generation(
|
||||
self,
|
||||
outputs,
|
||||
@ -956,13 +956,13 @@ class HiggsAudioModel(nn.Module):
|
||||
)
|
||||
|
||||
return model_kwargs
|
||||
|
||||
|
||||
def _copy_kv_cache(self, from_cache: Cache, to_cache: Cache):
|
||||
from_cache_size = from_cache.get_max_cache_shape()
|
||||
assert to_cache.get_max_cache_shape() >= from_cache_size, (
|
||||
f"The target cache size {to_cache.get_max_cache_shape()} is smaller than the source cache size {from_cache_size}."
|
||||
)
|
||||
|
||||
|
||||
n_layers = self.num_hidden_layers
|
||||
|
||||
for i in range(n_layers):
|
||||
@ -977,7 +977,7 @@ class HiggsAudioModel(nn.Module):
|
||||
self.cache_config.head_dim),
|
||||
device=self.device, dtype=self.dtype
|
||||
)
|
||||
|
||||
|
||||
if getattr(to_layer, "values", None) is None:
|
||||
to_layer.values = torch.zeros(
|
||||
(self.cache_config.max_batch, self.cache_config.num_key_value_heads,
|
||||
@ -1011,7 +1011,7 @@ class HiggsAudioModel(nn.Module):
|
||||
f"The current sequence length {current_sequence_length} is larger than "
|
||||
f"all past key values buckets {past_key_values_buckets.keys()}."
|
||||
)
|
||||
|
||||
|
||||
def _sample(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
@ -1020,7 +1020,7 @@ class HiggsAudioModel(nn.Module):
|
||||
past_key_values_buckets: Optional[OrderedDict[int, Cache]],
|
||||
**model_kwargs,
|
||||
):
|
||||
|
||||
|
||||
# code supports only non-mixed batchs
|
||||
|
||||
audio_out_bos_token_id = generation_config.generation_kwargs.get("audio_out_bos_token_id", None)
|
||||
@ -1069,7 +1069,7 @@ class HiggsAudioModel(nn.Module):
|
||||
|
||||
while not this_peer_finished:
|
||||
eos_token_tensor = torch.tensor([self.config["text_config"]["eos_token_id"]], device=input_ids.device)
|
||||
|
||||
|
||||
if input_ids[0][-1] == audio_out_bos_token_id:
|
||||
generation_mode = GenerationMode.AUDIO_INIT
|
||||
elif input_ids[0][-1] == self.audio_out_token_idx:
|
||||
@ -1211,7 +1211,7 @@ class HiggsAudioModel(nn.Module):
|
||||
pbar.update(pbar.total - pbar.current)
|
||||
|
||||
return audio_sequences
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def generate(
|
||||
self,
|
||||
@ -1222,7 +1222,7 @@ class HiggsAudioModel(nn.Module):
|
||||
generation_functions = None,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
|
||||
if generation_config is None:
|
||||
generation_config = GenerationConfig()
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import (
|
||||
from typing import (
|
||||
Dict, List, Optional, Union, Tuple, MutableMapping, Any, Mapping, Collection, get_type_hints, get_args, get_origin
|
||||
)
|
||||
|
||||
@ -153,13 +153,13 @@ def from_dict(data_class, data):
|
||||
value = _build_value(type_=field_type, data=data[key])
|
||||
except Exception as error:
|
||||
raise ValueError(error)
|
||||
|
||||
|
||||
if not is_instance(value, field_type):
|
||||
raise ValueError((
|
||||
f'wrong value type for field "{field.name}" - should be "{field_type}" '
|
||||
f'instead of value "{value}" of type "{type(value)}"'
|
||||
))
|
||||
|
||||
|
||||
init_values[field.name] = value
|
||||
|
||||
instance = data_class(**init_values)
|
||||
@ -273,7 +273,7 @@ def prepare_chatml_sample(sample: Union[ChatMLSample, Dict], tokenizer):
|
||||
try:
|
||||
if not isinstance(sample, ChatMLSample):
|
||||
|
||||
# replacing pd.isna
|
||||
# replacing pd.isna
|
||||
def is_nan(x):
|
||||
if isinstance(x, float):
|
||||
return math.isnan(x)
|
||||
@ -282,7 +282,7 @@ def prepare_chatml_sample(sample: Union[ChatMLSample, Dict], tokenizer):
|
||||
if isinstance(x, torch.Tensor) and x.numel() == 1:
|
||||
return torch.isnan(x).item()
|
||||
return False
|
||||
|
||||
|
||||
if "speaker" in sample and is_nan(sample["speaker"]):
|
||||
sample["speaker"] = None
|
||||
if "start_index" in sample and is_nan(sample["start_index"]):
|
||||
@ -489,7 +489,7 @@ class HiggsAudioSampleCollator:
|
||||
def _process_and_duplicate_audio_tokens(
|
||||
self, input_ids: torch.Tensor, audio_idx: int, wv: torch.Tensor, labels: Optional[torch.Tensor] = None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, int]:
|
||||
|
||||
|
||||
total_samples = len(wv)
|
||||
num_chunks = math.ceil(total_samples / self.chunk_size_samples)
|
||||
|
||||
@ -583,15 +583,15 @@ class HiggsAudioSampleCollator:
|
||||
audio_out_no_train_flag = torch.cat(audio_out_no_train_flag, dim=0)
|
||||
|
||||
if len(audio_in_ids_l) > 0:
|
||||
|
||||
|
||||
# I tried to remove the for-loop in original implementation
|
||||
# but to do batching with padding caused problem so I turned it into a list compre.
|
||||
lengths = [seg.shape[1] for seg in audio_in_ids_l]
|
||||
aug_lengths = [l + 2 for l in lengths]
|
||||
lengths = [seg.shape[1] for seg in audio_in_ids_l]
|
||||
aug_lengths = [l + 2 for l in lengths]
|
||||
audio_in_ids_start = torch.cumsum(
|
||||
torch.tensor([0] + aug_lengths[:-1], dtype=torch.long), dim=0
|
||||
)
|
||||
|
||||
|
||||
if self.disable_audio_codes_transform:
|
||||
audio_in_ids = torch.cat(audio_in_ids_l, dim=1).long()
|
||||
else:
|
||||
@ -607,7 +607,7 @@ class HiggsAudioSampleCollator:
|
||||
if self.use_delay_pattern:
|
||||
with_tokens = [
|
||||
build_delay_pattern_mask(
|
||||
tok.unsqueeze(0),
|
||||
tok.unsqueeze(0),
|
||||
bos_token_id=self.audio_stream_bos_id,
|
||||
pad_token_id=self.audio_stream_eos_id
|
||||
)[0]
|
||||
|
||||
@ -160,7 +160,7 @@ class DACDecoder(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
return self.model(x)
|
||||
|
||||
|
||||
class Conv1d1x1:
|
||||
def __new__(cls, in_channels, out_channels, bias=True, device=None, dtype=None, operations=None):
|
||||
operations = operations or nn
|
||||
@ -168,7 +168,7 @@ class Conv1d1x1:
|
||||
in_channels, out_channels, kernel_size=1,
|
||||
bias=bias, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
|
||||
class Conv1d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@ -203,7 +203,7 @@ class Conv1d(nn.Module):
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class ConvTranspose1d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@ -237,7 +237,7 @@ class ConvTranspose1d(nn.Module):
|
||||
def forward(self, x):
|
||||
x = self.deconv(x)
|
||||
return x
|
||||
|
||||
|
||||
class ResidualUnit(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@ -283,7 +283,7 @@ class EncoderBlock(nn.Module):
|
||||
self.conv = Conv1d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size = kernel_size,
|
||||
kernel_size = kernel_size,
|
||||
stride=stride,
|
||||
bias=bias,
|
||||
device = device, dtype = dtype, operations = operations
|
||||
@ -435,7 +435,7 @@ class Decoder(nn.Module):
|
||||
x = self.conv_blocks[i](x)
|
||||
x = self.conv2(x)
|
||||
return x
|
||||
|
||||
|
||||
class HiggsAudioFeatureExtractor(nn.Module):
|
||||
def __init__(self, sampling_rate=16000):
|
||||
super().__init__()
|
||||
@ -493,7 +493,7 @@ class EuclideanCodebook(nn.Module):
|
||||
embed = self.embed.t()
|
||||
if x.dtype != embed.dtype:
|
||||
x = x.to(embed.dtype)
|
||||
|
||||
|
||||
dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + embed.pow(2).sum(0, keepdim=True))
|
||||
embed_ind = dist.max(dim=-1).indices
|
||||
return embed_ind
|
||||
@ -520,14 +520,14 @@ class EuclideanCodebook(nn.Module):
|
||||
return quantize
|
||||
|
||||
def forward(self, x):
|
||||
orig_shape = x.shape # [B, T, D]
|
||||
orig_shape = x.shape # [B, T, D]
|
||||
flat = x.view(-1, x.shape[-1]) # [B*T, D]
|
||||
|
||||
embed_ind = self.quantize(flat)
|
||||
embed_ind = self.postprocess_emb(embed_ind, orig_shape)
|
||||
embed_ind = self.quantize(flat)
|
||||
embed_ind = self.postprocess_emb(embed_ind, orig_shape)
|
||||
# now embed_ind has shape [B, T]
|
||||
|
||||
quantize = self.dequantize(embed_ind)
|
||||
quantize = self.dequantize(embed_ind)
|
||||
# quantize: [B, T, D]
|
||||
|
||||
return quantize, embed_ind
|
||||
@ -636,9 +636,9 @@ class ResidualVectorQuantization(nn.Module):
|
||||
quantized = F.embedding(embed_id, codebook_weight).transpose(1, 2) # (B, D, T)
|
||||
quantized = F.linear(quantized.transpose(1, 2), proj_weight, proj_biases).transpose(1, 2)
|
||||
return quantized
|
||||
|
||||
|
||||
codebook_weights = torch.stack([q._codebook.embed for q in self.layers]) # (n_codebooks, vocab_size, D)
|
||||
proj_weights = torch.stack([q.project_out.weight for q in self.layers])
|
||||
proj_weights = torch.stack([q.project_out.weight for q in self.layers])
|
||||
|
||||
quantized = vmap(decode_one)(codebook_weights, proj_weights, q_indices, biases)
|
||||
|
||||
@ -705,7 +705,7 @@ class ResidualVectorQuantizer(nn.Module):
|
||||
"""Decode the given codes to the quantized representation."""
|
||||
quantized = self.vq.decode(codes)
|
||||
return quantized
|
||||
|
||||
|
||||
class HiggsAudioTokenizer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@ -786,7 +786,7 @@ class HiggsAudioTokenizer(nn.Module):
|
||||
x = x[:, 0, :]
|
||||
x = F.pad(x, (160, 160))
|
||||
target = self.semantic_model(x, output_hidden_states=True).hidden_states
|
||||
target = torch.stack(target, dim=1)
|
||||
target = torch.stack(target, dim=1)
|
||||
|
||||
target = target.mean(1)
|
||||
|
||||
|
||||
@ -413,7 +413,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["qkv_bias"] = False
|
||||
dit_config["guidance_cond_proj_dim"] = None#f"{key_prefix}t_embedder.cond_proj.weight" in state_dict_keys
|
||||
return dit_config
|
||||
|
||||
|
||||
if "{}layers.27.audio_post_attention_layernorm.weight".format(key_prefix) in state_dict_keys:
|
||||
|
||||
autoregressive_config = {}
|
||||
|
||||
@ -1307,8 +1307,8 @@ class Higgsv2(supported_models_base.BASE):
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.Higgsv2(self, device=device)
|
||||
return out
|
||||
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict = {}):
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.higgsv2.DummyTokenizer, comfy.text_encoders.higgsv2.HiggsTokenizer)
|
||||
|
||||
|
||||
@ -13,7 +13,7 @@ class DummyTokenizer:
|
||||
def revert_delay_pattern_vectorized(data: torch.Tensor) -> torch.Tensor:
|
||||
num_codebooks, total_len = data.shape
|
||||
seq_len = total_len - num_codebooks + 1
|
||||
|
||||
|
||||
col_idx = torch.arange(seq_len, device=data.device)[None, :] \
|
||||
+ torch.arange(num_codebooks, device=data.device)[:, None]
|
||||
out = data[torch.arange(num_codebooks)[:, None], col_idx]
|
||||
@ -27,7 +27,7 @@ class HiggsTokenizer(nn.Module):
|
||||
self.device = device
|
||||
self.dtypes = [torch.float32]
|
||||
|
||||
here = os.path.dirname(__file__)
|
||||
here = os.path.dirname(__file__)
|
||||
tokenizer_path = os.path.join(here, "higgs_text_tokenizer")
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
||||
@ -65,16 +65,16 @@ class HiggsTokenizer(nn.Module):
|
||||
# due to instability issues, I had to convert the audio tokenizer to float32, avoiding outputing nans
|
||||
self.audio_tokenizer = self.audio_tokenizer.to(self.dtype)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
for audio in audio_tokens:
|
||||
vq_code = revert_delay_pattern_vectorized(audio).clip(0, self.audio_codebook_size - 1)[:, 1:-1]
|
||||
wv_numpy = self.audio_tokenizer.decode(vq_code.unsqueeze(0))[0, 0]
|
||||
outputs.append(wv_numpy)
|
||||
|
||||
return (None, {"waveform": torch.stack(outputs, dim = 0).unsqueeze(1), "sample_rate": self.audio_tokenizer.sample_rate}) # audio only
|
||||
|
||||
|
||||
def load_state_dict(self, sd, strict = False):
|
||||
return self.audio_tokenizer.load_state_dict(sd, strict = strict)
|
||||
|
||||
|
||||
def state_dict(self):
|
||||
return self.audio_tokenizer.state_dict()
|
||||
|
||||
@ -244,10 +244,10 @@ class Attention(nn.Module):
|
||||
|
||||
output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True)
|
||||
out = self.o_proj(output)
|
||||
|
||||
|
||||
if past_key_value is not None:
|
||||
return out, past_key_value
|
||||
|
||||
|
||||
return out
|
||||
|
||||
class MLP(nn.Module):
|
||||
|
||||
@ -9,19 +9,17 @@ import folder_paths
|
||||
import os
|
||||
import io
|
||||
import json
|
||||
import base64
|
||||
import random
|
||||
import hashlib
|
||||
import numpy as np
|
||||
import node_helpers
|
||||
from io import BytesIO
|
||||
from comfy.cli_args import args
|
||||
from comfy.comfy_types import IO
|
||||
from comfy.comfy_types import FileLocator
|
||||
from dataclasses import asdict
|
||||
from comfy.ldm.higgsv2.loudness import loudness
|
||||
from comfy.ldm.higgsv2.preprocess import (
|
||||
prepare_chatml_sample, Message, ChatMLSample, ChatMLDatasetSample, AudioContent, TextContent, transcript_normalize
|
||||
prepare_chatml_sample, Message, ChatMLSample, ChatMLDatasetSample, AudioContent, transcript_normalize
|
||||
)
|
||||
|
||||
AUDIO_PLACEHOLDER_TOKEN = "<|__AUDIO_PLACEHOLDER__|>"
|
||||
@ -29,7 +27,7 @@ AUDIO_PLACEHOLDER_TOKEN = "<|__AUDIO_PLACEHOLDER__|>"
|
||||
MULTISPEAKER_DEFAULT_SYSTEM_MESSAGE = """You are an AI assistant designed to convert text into speech.
|
||||
If the user's message includes a [SPEAKER*] tag, do not read out the tag and generate speech for the following text, using the specified voice.
|
||||
If no speaker tag is present, select a suitable voice on your own."""
|
||||
|
||||
|
||||
class LoudnessNormalization:
|
||||
|
||||
CATEGORY = "audio"
|
||||
@ -42,7 +40,7 @@ class LoudnessNormalization:
|
||||
"block_size": ("FLOAT", {"default": 0.400, "min": 0.1, "max": 1.0, "step": 0.05}),
|
||||
"loudness_threshold": ("FLOAT", {"default": -23.0, "min": -70.0, "max": 0.0, "step": 0.5,
|
||||
"tooltip": "Target loudness in LUFS. Common values are -23.0 (broadcast), -14.0 (streaming)."})}}
|
||||
|
||||
|
||||
def normalize(self, audio, loudness_threshold, block_size):
|
||||
sampling_rate = audio["sample_rate"]
|
||||
waveform = audio["waveform"]
|
||||
@ -75,10 +73,10 @@ def prepare_chatml_input(
|
||||
|
||||
if audio_content.raw_audio.device != next(clip.audio_tokenizer.parameters()).device:
|
||||
audio_content.raw_audio = audio_content.raw_audio.to(next(clip.audio_tokenizer.parameters()).device)
|
||||
|
||||
|
||||
audio_ids = clip.audio_tokenizer.encode(audio_content.raw_audio, sampling_rate)
|
||||
audio_ids_l.append(audio_ids.squeeze(0))
|
||||
|
||||
|
||||
if len(audio_ids_l) > 0:
|
||||
audio_ids_start = torch.tensor(
|
||||
np.cumsum(np.array([0] + [audio_ids.shape[1] for audio_ids in audio_ids_l])),
|
||||
@ -115,18 +113,18 @@ def prepare_chatml_input(
|
||||
def postprocess_chatml(text: str) -> str:
|
||||
speakers = set(re.findall(r'\[SPEAKER\d+\]', text))
|
||||
skip_recon = True
|
||||
|
||||
|
||||
if len(speakers) > 1:
|
||||
parts = text.split('<|eot_id|>')
|
||||
|
||||
|
||||
# keep the first <|eot_id|> and the last one
|
||||
first_eot = parts[0] + '<|eot_id|>'
|
||||
middle_parts = ''.join(parts[1:-1])
|
||||
middle_parts = ''.join(parts[1:-1])
|
||||
last_eot = '<|eot_id|>' + parts[-1]
|
||||
|
||||
|
||||
text = first_eot + middle_parts + last_eot
|
||||
skip_recon = False
|
||||
|
||||
|
||||
return text, skip_recon
|
||||
|
||||
class CreateChatMLSample:
|
||||
@ -165,30 +163,30 @@ class CreateChatMLSample:
|
||||
|
||||
if audio is not None:
|
||||
clip.load_model()
|
||||
|
||||
|
||||
if hasattr(clip, "cond_stage_model"):
|
||||
clip = clip.cond_stage_model
|
||||
|
||||
text = transcript_normalize(text)
|
||||
|
||||
|
||||
messages = []
|
||||
lines = text.splitlines()
|
||||
sampling_rate = False
|
||||
current_role = None
|
||||
collecting_system = False
|
||||
system_buffer = []
|
||||
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
|
||||
# system start
|
||||
if line.lower().startswith("system:"):
|
||||
collecting_system = True
|
||||
system_buffer.append(line[len("system:"):].strip())
|
||||
continue
|
||||
|
||||
|
||||
# while collecting system prompt
|
||||
if collecting_system:
|
||||
system_buffer.append(line)
|
||||
@ -198,7 +196,7 @@ class CreateChatMLSample:
|
||||
system_buffer = []
|
||||
collecting_system = False
|
||||
continue
|
||||
|
||||
|
||||
# speaker lines SPEAKER-0: text
|
||||
match = re.match(r"SPEAKER-(\d+):\s*(.*)", line, re.IGNORECASE)
|
||||
if match:
|
||||
@ -245,12 +243,12 @@ class CreateChatMLSample:
|
||||
chat_ml_sample,
|
||||
clip.tokenizer,
|
||||
)
|
||||
|
||||
|
||||
if audio is None:
|
||||
audio_contents = None
|
||||
out = prepare_chatml_input(clip, input_tokens, audio_contents, sampling_rate = sampling_rate)
|
||||
return (out,)
|
||||
|
||||
|
||||
class EmptyLatentAudio:
|
||||
def __init__(self):
|
||||
self.device = comfy.model_management.intermediate_device()
|
||||
@ -612,7 +610,7 @@ NODE_CLASS_MAPPINGS = {
|
||||
"PreviewAudio": PreviewAudio,
|
||||
"ConditioningStableAudio": ConditioningStableAudio,
|
||||
"LoudnessNormalization": LoudnessNormalization,
|
||||
"CreateChatMLSample": CreateChatMLSample
|
||||
"CreateChatMLSample": CreateChatMLSample,
|
||||
"RecordAudio": RecordAudio,
|
||||
}
|
||||
|
||||
@ -626,6 +624,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"SaveAudioMP3": "Save Audio (MP3)",
|
||||
"SaveAudioOpus": "Save Audio (Opus)",
|
||||
"LoudnessNormalization": "Loudness Normalization",
|
||||
"CreateChatMLSample": "Create ChatML Sample"
|
||||
"CreateChatMLSample": "Create ChatML Sample",
|
||||
"RecordAudio": "Record Audio",
|
||||
}
|
||||
|
||||
6
nodes.py
6
nodes.py
@ -1571,7 +1571,7 @@ class AutoRegressiveGeneration:
|
||||
"do_sample": ("BOOLEAN", {"default": False, "tooltip": "Add randomness in decoding the tokens."}),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
RETURN_TYPES = ("TOKENS",)
|
||||
FUNCTION = "generate"
|
||||
|
||||
@ -1582,7 +1582,7 @@ class AutoRegressiveGeneration:
|
||||
|
||||
def generate(self, model, input_ids, seed, max_new_length, min_new_length, top_k, top_p, temperature, do_sample):
|
||||
return (auto_sample(self, model, input_ids, max_new_length, min_new_length, top_k, top_p, temperature, do_sample, seed = seed),)
|
||||
|
||||
|
||||
class DecodeTokens:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -1591,7 +1591,7 @@ class DecodeTokens:
|
||||
"clip": (IO.CLIP, {"tooltip": "The model used for generation."}),
|
||||
"tokens": ("TOKENS", ),}
|
||||
}
|
||||
|
||||
|
||||
FUNCTION = "decode"
|
||||
CATEGORY = "conditioning"
|
||||
RETURN_TYPES = ("TEXT", "AUDIO")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user