From 57c15f970c84ab9b0752eaa5d5e9bd5933bbd0aa Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Sat, 6 Sep 2025 01:17:04 +0300 Subject: [PATCH] styling fixes --- comfy/autoregressive_sampling.py | 60 +++++++++++++------------- comfy/ldm/higgsv2/cuda_graph_runner.py | 2 +- comfy/ldm/higgsv2/loudness.py | 8 ++-- comfy/ldm/higgsv2/model.py | 34 +++++++-------- comfy/ldm/higgsv2/preprocess.py | 22 +++++----- comfy/ldm/higgsv2/tokenizer.py | 30 ++++++------- comfy/model_detection.py | 2 +- comfy/supported_models.py | 4 +- comfy/text_encoders/higgsv2.py | 10 ++--- comfy/text_encoders/llama.py | 4 +- comfy_extras/nodes_audio.py | 42 +++++++++--------- nodes.py | 6 +-- 12 files changed, 111 insertions(+), 113 deletions(-) diff --git a/comfy/autoregressive_sampling.py b/comfy/autoregressive_sampling.py index 54885bf42..424907fb6 100644 --- a/comfy/autoregressive_sampling.py +++ b/comfy/autoregressive_sampling.py @@ -19,27 +19,27 @@ def estimate_autoregressive_vram( max_seq_len: int, batch_size: int = 1, dtype = torch.float16, - intermediate_factor: float = 4.0, + intermediate_factor: float = 4.0, device = torch.device('cuda') ) -> bool: - + dtype_size = torch.finfo(dtype).bits // 8 kv_cache_bytes = num_layers * max_seq_len * hidden_dim * 2 * batch_size * dtype_size - # we only calculate hidden states in cuda graphs, so we don't care about the output logits - input_bytes = output_bytes = batch_size * max_seq_len * hidden_dim * dtype_size - + # we only calculate hidden states in cuda graphs, so we don't care about the output logits + input_bytes = output_bytes = batch_size * max_seq_len * hidden_dim * dtype_size + # rough calculation for activation sizes intermediate_bytes = intermediate_factor * output_bytes - + total_estimated = kv_cache_bytes + input_bytes + output_bytes + intermediate_bytes - + # get vram info free_vram = get_free_memory(device) minimum_vram = minimum_inference_memory() - + enough_vram = free_vram - minimum_vram >= total_estimated - + return enough_vram class TopKLogits: @@ -64,7 +64,7 @@ class TemperatureLogitsWarper: def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor: scores_processed = scores / self.temperature return scores_processed - + class TopPLogitsWarper: def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): top_p = float(top_p) @@ -175,7 +175,7 @@ class GenerationConfig: config_dict = {key: value for key, value in config_dict.items() if value is not None} valid_fields = {f.name for f in fields(cls)} - + filtered_args = {k: v for k, v in {**config_dict, **kwargs}.items() if k in valid_fields} generation_config = cls(**filtered_args) @@ -216,7 +216,7 @@ class AutoRegressiveGeneration: self.model.cache_config = self.cache_config self.kv_caches = { - + length: StaticCache( config=self.cache_config, max_batch_size = self.cache_config.max_batch, @@ -234,8 +234,8 @@ class AutoRegressiveGeneration: # cuda graphs only help if input shapes are constant if ( - device == "cuda" - and hasattr(model, "capture_model") + device == "cuda" + and hasattr(model, "capture_model") and self.model.cache_implementation == "static" and self.model.use_kv_buckets and enough_vram @@ -247,7 +247,7 @@ class AutoRegressiveGeneration: @torch.inference_mode() def generate(self, input_ids: Optional[torch.LongTensor] = None, max_new_length: int = 1024, min_new_length = 0, top_k: int = 50, top_p: float = 1.0, temperature: float = 1.0, do_sample: bool = False, seed = None, **kwargs): - + if seed 
is not None: torch_generator = torch.Generator(device = input_ids.device).manual_seed(seed) else: @@ -335,7 +335,7 @@ class AutoRegressiveGeneration: # TODO: have a default self._sample fn and a default check if the model supports autoregGen or not if not hasattr(self.model, "_sample"): raise ValueError("Model doesn't support AutoRegressive Generation!") - + self._prepare_kv_caches() result = self.model._sample( @@ -347,7 +347,7 @@ class AutoRegressiveGeneration: ) return result - + def _prepare_kv_caches(self): for kv_cache in self.kv_caches.values(): kv_cache.reset() @@ -357,13 +357,13 @@ class AutoRegressiveGeneration: return GenerationSampling.BEAM_SAMPLING else: return GenerationSampling.GREEDY_SEARCH - + def _prepare_generated_length( self, generation_config: GenerationConfig, input_ids_length, ): - + """ max_length = user_input_id_tokens + generation_max_length """ if generation_config.max_new_length is not None: @@ -374,11 +374,11 @@ class AutoRegressiveGeneration: generation_config.min_length = generation_config.min_new_length + input_ids_length return generation_config - + def _get_cache( self, cache_implementation: str, batch_size: int, max_cache_len: int, device: torch.device, model_kwargs ) -> Cache: - + assert cache_implementation == "static", f"Only 'static' cache is supported, got {cache_implementation}" cache_cls: Cache = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation] @@ -412,7 +412,7 @@ class AutoRegressiveGeneration: return self.model._cache - + def _prepare_cache_for_generation( self, generation_config: GenerationConfig, @@ -466,7 +466,7 @@ class AutoRegressiveGeneration: model_kwargs = generation_config.update(**kwargs) return generation_config, model_kwargs - + def _validate_generated_length(self, generation_config: GenerationConfig, input_ids_length): """Performs validation related to the resulting generated length""" @@ -498,7 +498,7 @@ class AutoRegressiveGeneration: f" the maximum possible length ({generation_config.max_length})." 
+ min_length_error_suffix, UserWarning, ) - + def _expand_inputs_for_generation( self, expand_size: int = 1, @@ -526,13 +526,13 @@ class AutoRegressiveGeneration: model_kwargs = _expand_dict_for_generation(model_kwargs) return input_ids, model_kwargs - + def _prepare_special_tokens( self, generation_config: GenerationConfig, device: Optional[Union[torch.device, str]] = None, ): - + def _tensor_or_none(token, device=None): if token is None: return token @@ -564,7 +564,7 @@ class AutoRegressiveGeneration: generation_config: GenerationConfig, model_kwargs: dict[str, Any], ) -> torch.LongTensor: - + pad_token_id = generation_config._pad_token_tensor eos_token_id = generation_config._eos_token_tensor @@ -593,12 +593,12 @@ class AutoRegressiveGeneration: attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask ) return attention_mask - + def auto_sample(node, patcher, input_ids, max_new_length=1024, min_new_length=0, top_k=50, top_p=1.0, temperature=1.0, do_sample = False, seed=None, **kwargs): # to work with BaseModel if hasattr(patcher, "model") and hasattr(patcher.model, "diffusion_model"): model = patcher.model.diffusion_model - + if node._cached_autoregressive_sampler is None or node._cached_autoregressive_sampler.model is not model: if model.device != patcher.load_device: model = model.to(patcher.load_device, dtype=model.dtype) @@ -610,7 +610,7 @@ def auto_sample(node, patcher, input_ids, max_new_length=1024, min_new_length=0, kwargs.update({k: v for k, v in input_ids.items() if k != "input_ids"}) else: main_input_ids = input_ids - + device = node._cached_autoregressive_sampler.device main_input_ids = main_input_ids.to(device) diff --git a/comfy/ldm/higgsv2/cuda_graph_runner.py b/comfy/ldm/higgsv2/cuda_graph_runner.py index bc89edbe2..e86f034c7 100644 --- a/comfy/ldm/higgsv2/cuda_graph_runner.py +++ b/comfy/ldm/higgsv2/cuda_graph_runner.py @@ -22,7 +22,7 @@ class CUDAGraphRunner(nn.Module): def capture(self, *args, **kwargs): assert self._graph is None - + for _ in range(_NUM_WARMUP_ITERS): self.model(*args, **kwargs) diff --git a/comfy/ldm/higgsv2/loudness.py b/comfy/ldm/higgsv2/loudness.py index 53a06d201..ac1850c61 100644 --- a/comfy/ldm/higgsv2/loudness.py +++ b/comfy/ldm/higgsv2/loudness.py @@ -125,7 +125,7 @@ class IIRfilter(object): @property def b_and_a(self): return self.generate_coefficients() - + class Meter(torch.nn.Module): def __init__( @@ -227,7 +227,7 @@ class Meter(torch.nn.Module): return unfolded def integrated_loudness(self, data: torch.Tensor): - + if not torch.is_tensor(data): data = torch.from_numpy(data).float() else: @@ -291,10 +291,10 @@ class Meter(torch.nn.Module): def loudness( audio_data, sample_rate: int, target_loudness: int, filter_class: str = "K-weighting", block_size: float = 0.400, **kwargs -): +): MIN_LOUDNESS = -70 device = audio_data.device - + original_length = audio_data.shape[-1] signal_duration = original_length / sample_rate diff --git a/comfy/ldm/higgsv2/model.py b/comfy/ldm/higgsv2/model.py index cb79b5df3..16180b657 100644 --- a/comfy/ldm/higgsv2/model.py +++ b/comfy/ldm/higgsv2/model.py @@ -19,8 +19,8 @@ from collections import defaultdict, OrderedDict from typing import Optional, Tuple, Union, List class GenerationMode(Enum): - TEXT = 0 - AUDIO_INIT = 1 + TEXT = 0 + AUDIO_INIT = 1 AUDIO_IN_PROGRESS = 2 def _ignore_causal_mask_sdpa( @@ -413,7 +413,7 @@ class HiggsAudioModel(nn.Module): def __init__(self, device = None, dtype = None, operations = None, **kwargs): super().__init__() - + 
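# --- Illustrative sketch (not part of the patch): how the TopKLogits, TopPLogitsWarper
# and TemperatureLogitsWarper classes defined in comfy/autoregressive_sampling.py combine
# when do_sample=True. The standalone helper below and its name are hypothetical; the
# real generate() path wires the warpers through GenerationConfig.
import torch

def sample_next_token(logits, top_k=50, top_p=1.0, temperature=1.0, generator=None):
    # logits: [batch, vocab]; apply temperature, then top-k, then nucleus (top-p) filtering
    scores = logits / temperature
    if top_k > 0:
        top_k = min(top_k, scores.size(-1))
        kth_best = torch.topk(scores, top_k).values[..., -1, None]
        scores = scores.masked_fill(scores < kth_best, -float("inf"))
    if top_p < 1.0:
        sorted_scores, sorted_idx = torch.sort(scores, descending=False)
        cumulative = sorted_scores.softmax(dim=-1).cumsum(dim=-1)
        remove_sorted = cumulative <= (1.0 - top_p)
        remove_sorted[..., -1] = False  # always keep at least the most probable token
        remove = remove_sorted.scatter(-1, sorted_idx, remove_sorted)
        scores = scores.masked_fill(remove, -float("inf"))
    probs = scores.softmax(dim=-1)
    return torch.multinomial(probs, num_samples=1, generator=generator)
# A fixed torch.Generator seeded as in generate(seed=...) makes the draw reproducible.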
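# --- Illustrative sketch (hypothetical class, simplified buffer handling): the
# capture/replay pattern behind CUDAGraphRunner.capture in
# comfy/ldm/higgsv2/cuda_graph_runner.py, which warms the model up eagerly before
# recording a CUDA graph and later replays it with new data copied into the captured
# input buffer.
import torch

class GraphRunnerSketch(torch.nn.Module):
    def __init__(self, model, num_warmup_iters=2):
        super().__init__()
        self.model = model
        self.num_warmup_iters = num_warmup_iters
        self._graph = None

    def capture(self, static_input):
        assert self._graph is None
        for _ in range(self.num_warmup_iters):  # warmup before capture
            self.model(static_input)
        torch.cuda.synchronize()
        self._graph = torch.cuda.CUDAGraph()
        self._static_input = static_input
        with torch.cuda.graph(self._graph):
            self._static_output = self.model(static_input)

    def forward(self, new_input):
        self._static_input.copy_(new_input)  # reuse the captured buffer
        self._graph.replay()
        return self._static_output
# Replay only pays off when input shapes stay constant, which is why the sampler gates
# graph capture on the static cache implementation and KV buckets.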
self.padding_idx = kwargs["pad_token_id"] self.audio_in_token_idx = kwargs["audio_in_token_idx"] self.audio_out_token_idx = kwargs["audio_out_token_idx"] @@ -439,7 +439,7 @@ class HiggsAudioModel(nn.Module): self.audio_out_bos_token_id = 128013 self.audio_eos_token_id = 128012 - + text_config = kwargs["text_config"] llama_config = Llama2Config(num_attention_heads = text_config["num_attention_heads"], num_key_value_heads = text_config["num_key_value_heads"], @@ -616,7 +616,7 @@ class HiggsAudioModel(nn.Module): next_audio_tokens = None return next_tokens, next_audio_tokens - + def _update_causal_mask( self, attention_mask: torch.Tensor, @@ -677,7 +677,7 @@ class HiggsAudioModel(nn.Module): causal_mask = causal_mask.mul(~torch.all(causal_mask == min_dtype, dim=-1, keepdim=True)) return causal_mask - + def _embed_audio_ids(self, audio_ids): codebook_shift = ( torch.arange(self.config["audio_num_codebooks"], device=audio_ids.device) * self.audio_codebook_size @@ -712,7 +712,7 @@ class HiggsAudioModel(nn.Module): ) audio_attention_mask = attention_mask.masked_fill(no_audio_out_mask, min_dtype) return fast_forward_attention_mask, audio_attention_mask - + def _forward_core( self, hidden_states: torch.Tensor, @@ -728,7 +728,7 @@ class HiggsAudioModel(nn.Module): is_using_cuda_graph: Optional[bool] = False, ): - position_id_offset = cache_position[0] if use_cache else 0 + position_id_offset = cache_position[0] if use_cache else 0 position_embeddings = self.rotary_emb(hidden_states, position_ids + position_id_offset) for decoder_layer in self.layers: @@ -927,7 +927,7 @@ class HiggsAudioModel(nn.Module): ) return ret - + def _update_model_kwargs_for_generation( self, outputs, @@ -956,13 +956,13 @@ class HiggsAudioModel(nn.Module): ) return model_kwargs - + def _copy_kv_cache(self, from_cache: Cache, to_cache: Cache): from_cache_size = from_cache.get_max_cache_shape() assert to_cache.get_max_cache_shape() >= from_cache_size, ( f"The target cache size {to_cache.get_max_cache_shape()} is smaller than the source cache size {from_cache_size}." ) - + n_layers = self.num_hidden_layers for i in range(n_layers): @@ -977,7 +977,7 @@ class HiggsAudioModel(nn.Module): self.cache_config.head_dim), device=self.device, dtype=self.dtype ) - + if getattr(to_layer, "values", None) is None: to_layer.values = torch.zeros( (self.cache_config.max_batch, self.cache_config.num_key_value_heads, @@ -1011,7 +1011,7 @@ class HiggsAudioModel(nn.Module): f"The current sequence length {current_sequence_length} is larger than " f"all past key values buckets {past_key_values_buckets.keys()}." 
) - + def _sample( self, input_ids: torch.LongTensor, @@ -1020,7 +1020,7 @@ class HiggsAudioModel(nn.Module): past_key_values_buckets: Optional[OrderedDict[int, Cache]], **model_kwargs, ): - + # code supports only non-mixed batchs audio_out_bos_token_id = generation_config.generation_kwargs.get("audio_out_bos_token_id", None) @@ -1069,7 +1069,7 @@ class HiggsAudioModel(nn.Module): while not this_peer_finished: eos_token_tensor = torch.tensor([self.config["text_config"]["eos_token_id"]], device=input_ids.device) - + if input_ids[0][-1] == audio_out_bos_token_id: generation_mode = GenerationMode.AUDIO_INIT elif input_ids[0][-1] == self.audio_out_token_idx: @@ -1211,7 +1211,7 @@ class HiggsAudioModel(nn.Module): pbar.update(pbar.total - pbar.current) return audio_sequences - + @torch.inference_mode() def generate( self, @@ -1222,7 +1222,7 @@ class HiggsAudioModel(nn.Module): generation_functions = None, **kwargs, ): - + if generation_config is None: generation_config = GenerationConfig() diff --git a/comfy/ldm/higgsv2/preprocess.py b/comfy/ldm/higgsv2/preprocess.py index 6a8ab4957..219b5b374 100644 --- a/comfy/ldm/higgsv2/preprocess.py +++ b/comfy/ldm/higgsv2/preprocess.py @@ -1,4 +1,4 @@ -from typing import ( +from typing import ( Dict, List, Optional, Union, Tuple, MutableMapping, Any, Mapping, Collection, get_type_hints, get_args, get_origin ) @@ -153,13 +153,13 @@ def from_dict(data_class, data): value = _build_value(type_=field_type, data=data[key]) except Exception as error: raise ValueError(error) - + if not is_instance(value, field_type): raise ValueError(( f'wrong value type for field "{field.name}" - should be "{field_type}" ' f'instead of value "{value}" of type "{type(value)}"' )) - + init_values[field.name] = value instance = data_class(**init_values) @@ -273,7 +273,7 @@ def prepare_chatml_sample(sample: Union[ChatMLSample, Dict], tokenizer): try: if not isinstance(sample, ChatMLSample): - # replacing pd.isna + # replacing pd.isna def is_nan(x): if isinstance(x, float): return math.isnan(x) @@ -282,7 +282,7 @@ def prepare_chatml_sample(sample: Union[ChatMLSample, Dict], tokenizer): if isinstance(x, torch.Tensor) and x.numel() == 1: return torch.isnan(x).item() return False - + if "speaker" in sample and is_nan(sample["speaker"]): sample["speaker"] = None if "start_index" in sample and is_nan(sample["start_index"]): @@ -489,7 +489,7 @@ class HiggsAudioSampleCollator: def _process_and_duplicate_audio_tokens( self, input_ids: torch.Tensor, audio_idx: int, wv: torch.Tensor, labels: Optional[torch.Tensor] = None ) -> Tuple[torch.Tensor, torch.Tensor, int]: - + total_samples = len(wv) num_chunks = math.ceil(total_samples / self.chunk_size_samples) @@ -583,15 +583,15 @@ class HiggsAudioSampleCollator: audio_out_no_train_flag = torch.cat(audio_out_no_train_flag, dim=0) if len(audio_in_ids_l) > 0: - + # I tried to remove the for-loop in original implementation # but to do batching with padding caused problem so I turned it into a list compre. 
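# --- Illustrative sketch (hypothetical helper): picking the smallest static KV-cache
# bucket that still fits the current sequence length, matching the bucket error raised
# above. When generation outgrows a bucket, the existing keys/values are migrated with
# _copy_kv_cache before continuing.
def select_kv_bucket(past_key_values_buckets, current_sequence_length):
    # past_key_values_buckets: OrderedDict mapping max_cache_len -> StaticCache,
    # assumed ordered from smallest to largest bucket
    for bucket_len, cache in past_key_values_buckets.items():
        if current_sequence_length <= bucket_len:
            return bucket_len, cache
    raise ValueError(
        f"The current sequence length {current_sequence_length} is larger than "
        f"all past key values buckets {list(past_key_values_buckets.keys())}."
    )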
- lengths = [seg.shape[1] for seg in audio_in_ids_l] - aug_lengths = [l + 2 for l in lengths] + lengths = [seg.shape[1] for seg in audio_in_ids_l] + aug_lengths = [l + 2 for l in lengths] audio_in_ids_start = torch.cumsum( torch.tensor([0] + aug_lengths[:-1], dtype=torch.long), dim=0 ) - + if self.disable_audio_codes_transform: audio_in_ids = torch.cat(audio_in_ids_l, dim=1).long() else: @@ -607,7 +607,7 @@ class HiggsAudioSampleCollator: if self.use_delay_pattern: with_tokens = [ build_delay_pattern_mask( - tok.unsqueeze(0), + tok.unsqueeze(0), bos_token_id=self.audio_stream_bos_id, pad_token_id=self.audio_stream_eos_id )[0] diff --git a/comfy/ldm/higgsv2/tokenizer.py b/comfy/ldm/higgsv2/tokenizer.py index c4277291c..d467d23ae 100644 --- a/comfy/ldm/higgsv2/tokenizer.py +++ b/comfy/ldm/higgsv2/tokenizer.py @@ -160,7 +160,7 @@ class DACDecoder(nn.Module): def forward(self, x): return self.model(x) - + class Conv1d1x1: def __new__(cls, in_channels, out_channels, bias=True, device=None, dtype=None, operations=None): operations = operations or nn @@ -168,7 +168,7 @@ class Conv1d1x1: in_channels, out_channels, kernel_size=1, bias=bias, device=device, dtype=dtype ) - + class Conv1d(nn.Module): def __init__( self, @@ -203,7 +203,7 @@ class Conv1d(nn.Module): def forward(self, x): x = self.conv(x) return x - + class ConvTranspose1d(nn.Module): def __init__( self, @@ -237,7 +237,7 @@ class ConvTranspose1d(nn.Module): def forward(self, x): x = self.deconv(x) return x - + class ResidualUnit(nn.Module): def __init__( self, @@ -283,7 +283,7 @@ class EncoderBlock(nn.Module): self.conv = Conv1d( in_channels=in_channels, out_channels=out_channels, - kernel_size = kernel_size, + kernel_size = kernel_size, stride=stride, bias=bias, device = device, dtype = dtype, operations = operations @@ -435,7 +435,7 @@ class Decoder(nn.Module): x = self.conv_blocks[i](x) x = self.conv2(x) return x - + class HiggsAudioFeatureExtractor(nn.Module): def __init__(self, sampling_rate=16000): super().__init__() @@ -493,7 +493,7 @@ class EuclideanCodebook(nn.Module): embed = self.embed.t() if x.dtype != embed.dtype: x = x.to(embed.dtype) - + dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + embed.pow(2).sum(0, keepdim=True)) embed_ind = dist.max(dim=-1).indices return embed_ind @@ -520,14 +520,14 @@ class EuclideanCodebook(nn.Module): return quantize def forward(self, x): - orig_shape = x.shape # [B, T, D] + orig_shape = x.shape # [B, T, D] flat = x.view(-1, x.shape[-1]) # [B*T, D] - embed_ind = self.quantize(flat) - embed_ind = self.postprocess_emb(embed_ind, orig_shape) + embed_ind = self.quantize(flat) + embed_ind = self.postprocess_emb(embed_ind, orig_shape) # now embed_ind has shape [B, T] - quantize = self.dequantize(embed_ind) + quantize = self.dequantize(embed_ind) # quantize: [B, T, D] return quantize, embed_ind @@ -636,9 +636,9 @@ class ResidualVectorQuantization(nn.Module): quantized = F.embedding(embed_id, codebook_weight).transpose(1, 2) # (B, D, T) quantized = F.linear(quantized.transpose(1, 2), proj_weight, proj_biases).transpose(1, 2) return quantized - + codebook_weights = torch.stack([q._codebook.embed for q in self.layers]) # (n_codebooks, vocab_size, D) - proj_weights = torch.stack([q.project_out.weight for q in self.layers]) + proj_weights = torch.stack([q.project_out.weight for q in self.layers]) quantized = vmap(decode_one)(codebook_weights, proj_weights, q_indices, biases) @@ -705,7 +705,7 @@ class ResidualVectorQuantizer(nn.Module): """Decode the given codes to the quantized representation.""" 
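# --- Illustrative sketch (hypothetical helpers, simplified versus the vmap-based decode
# and project_out layers above): residual vector quantization as used by the audio
# tokenizer's EuclideanCodebook / ResidualVectorQuantization. Each codebook quantizes the
# residual left by the previous ones, and decoding sums the per-codebook embeddings.
import torch
import torch.nn.functional as F

def rvq_encode(x, codebooks):
    # x: [B, T, D]; codebooks: list of [vocab, D] embedding tables
    indices, residual = [], x
    for embed in codebooks:
        flat = residual.reshape(-1, residual.shape[-1])             # [B*T, D]
        dist = (flat.pow(2).sum(1, keepdim=True)
                - 2 * flat @ embed.t()
                + embed.pow(2).sum(1)[None, :])                     # squared distances
        idx = dist.argmin(dim=-1).reshape(residual.shape[:-1])      # [B, T]
        residual = residual - F.embedding(idx, embed)               # subtract the quantized part
        indices.append(idx)
    return torch.stack(indices)                                     # [n_codebooks, B, T]

def rvq_decode(indices, codebooks):
    return sum(F.embedding(idx, embed) for idx, embed in zip(indices, codebooks))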
quantized = self.vq.decode(codes) return quantized - + class HiggsAudioTokenizer(nn.Module): def __init__( self, @@ -786,7 +786,7 @@ class HiggsAudioTokenizer(nn.Module): x = x[:, 0, :] x = F.pad(x, (160, 160)) target = self.semantic_model(x, output_hidden_states=True).hidden_states - target = torch.stack(target, dim=1) + target = torch.stack(target, dim=1) target = target.mean(1) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 079aae87b..52f9995fe 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -413,7 +413,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["qkv_bias"] = False dit_config["guidance_cond_proj_dim"] = None#f"{key_prefix}t_embedder.cond_proj.weight" in state_dict_keys return dit_config - + if "{}layers.27.audio_post_attention_layernorm.weight".format(key_prefix) in state_dict_keys: autoregressive_config = {} diff --git a/comfy/supported_models.py b/comfy/supported_models.py index b403642f4..25e798f4b 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1307,8 +1307,8 @@ class Higgsv2(supported_models_base.BASE): def get_model(self, state_dict, prefix="", device=None): out = model_base.Higgsv2(self, device=device) - return out - + return out + def clip_target(self, state_dict = {}): return supported_models_base.ClipTarget(comfy.text_encoders.higgsv2.DummyTokenizer, comfy.text_encoders.higgsv2.HiggsTokenizer) diff --git a/comfy/text_encoders/higgsv2.py b/comfy/text_encoders/higgsv2.py index 4e187b9fc..b861ac0e5 100644 --- a/comfy/text_encoders/higgsv2.py +++ b/comfy/text_encoders/higgsv2.py @@ -13,7 +13,7 @@ class DummyTokenizer: def revert_delay_pattern_vectorized(data: torch.Tensor) -> torch.Tensor: num_codebooks, total_len = data.shape seq_len = total_len - num_codebooks + 1 - + col_idx = torch.arange(seq_len, device=data.device)[None, :] \ + torch.arange(num_codebooks, device=data.device)[:, None] out = data[torch.arange(num_codebooks)[:, None], col_idx] @@ -27,7 +27,7 @@ class HiggsTokenizer(nn.Module): self.device = device self.dtypes = [torch.float32] - here = os.path.dirname(__file__) + here = os.path.dirname(__file__) tokenizer_path = os.path.join(here, "higgs_text_tokenizer") self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) @@ -65,16 +65,16 @@ class HiggsTokenizer(nn.Module): # due to instability issues, I had to convert the audio tokenizer to float32, avoiding outputing nans self.audio_tokenizer = self.audio_tokenizer.to(self.dtype) torch.cuda.synchronize() - + for audio in audio_tokens: vq_code = revert_delay_pattern_vectorized(audio).clip(0, self.audio_codebook_size - 1)[:, 1:-1] wv_numpy = self.audio_tokenizer.decode(vq_code.unsqueeze(0))[0, 0] outputs.append(wv_numpy) return (None, {"waveform": torch.stack(outputs, dim = 0).unsqueeze(1), "sample_rate": self.audio_tokenizer.sample_rate}) # audio only - + def load_state_dict(self, sd, strict = False): return self.audio_tokenizer.load_state_dict(sd, strict = strict) - + def state_dict(self): return self.audio_tokenizer.state_dict() diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 19aa5118a..76de4736e 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -244,10 +244,10 @@ class Attention(nn.Module): output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True) out = self.o_proj(output) - + if past_key_value is not None: return out, past_key_value - + return out class MLP(nn.Module): diff --git 
a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index 6bedf543b..8651b6a57 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -9,19 +9,17 @@ import folder_paths import os import io import json -import base64 import random import hashlib import numpy as np import node_helpers -from io import BytesIO from comfy.cli_args import args from comfy.comfy_types import IO from comfy.comfy_types import FileLocator from dataclasses import asdict from comfy.ldm.higgsv2.loudness import loudness from comfy.ldm.higgsv2.preprocess import ( - prepare_chatml_sample, Message, ChatMLSample, ChatMLDatasetSample, AudioContent, TextContent, transcript_normalize + prepare_chatml_sample, Message, ChatMLSample, ChatMLDatasetSample, AudioContent, transcript_normalize ) AUDIO_PLACEHOLDER_TOKEN = "<|__AUDIO_PLACEHOLDER__|>" @@ -29,7 +27,7 @@ AUDIO_PLACEHOLDER_TOKEN = "<|__AUDIO_PLACEHOLDER__|>" MULTISPEAKER_DEFAULT_SYSTEM_MESSAGE = """You are an AI assistant designed to convert text into speech. If the user's message includes a [SPEAKER*] tag, do not read out the tag and generate speech for the following text, using the specified voice. If no speaker tag is present, select a suitable voice on your own.""" - + class LoudnessNormalization: CATEGORY = "audio" @@ -42,7 +40,7 @@ class LoudnessNormalization: "block_size": ("FLOAT", {"default": 0.400, "min": 0.1, "max": 1.0, "step": 0.05}), "loudness_threshold": ("FLOAT", {"default": -23.0, "min": -70.0, "max": 0.0, "step": 0.5, "tooltip": "Target loudness in LUFS. Common values are -23.0 (broadcast), -14.0 (streaming)."})}} - + def normalize(self, audio, loudness_threshold, block_size): sampling_rate = audio["sample_rate"] waveform = audio["waveform"] @@ -75,10 +73,10 @@ def prepare_chatml_input( if audio_content.raw_audio.device != next(clip.audio_tokenizer.parameters()).device: audio_content.raw_audio = audio_content.raw_audio.to(next(clip.audio_tokenizer.parameters()).device) - + audio_ids = clip.audio_tokenizer.encode(audio_content.raw_audio, sampling_rate) audio_ids_l.append(audio_ids.squeeze(0)) - + if len(audio_ids_l) > 0: audio_ids_start = torch.tensor( np.cumsum(np.array([0] + [audio_ids.shape[1] for audio_ids in audio_ids_l])), @@ -115,18 +113,18 @@ def prepare_chatml_input( def postprocess_chatml(text: str) -> str: speakers = set(re.findall(r'\[SPEAKER\d+\]', text)) skip_recon = True - + if len(speakers) > 1: parts = text.split('<|eot_id|>') - + # keep the first <|eot_id|> and the last one first_eot = parts[0] + '<|eot_id|>' - middle_parts = ''.join(parts[1:-1]) + middle_parts = ''.join(parts[1:-1]) last_eot = '<|eot_id|>' + parts[-1] - + text = first_eot + middle_parts + last_eot skip_recon = False - + return text, skip_recon class CreateChatMLSample: @@ -165,30 +163,30 @@ class CreateChatMLSample: if audio is not None: clip.load_model() - + if hasattr(clip, "cond_stage_model"): clip = clip.cond_stage_model text = transcript_normalize(text) - + messages = [] lines = text.splitlines() sampling_rate = False current_role = None collecting_system = False system_buffer = [] - + for line in lines: line = line.strip() if not line: continue - + # system start if line.lower().startswith("system:"): collecting_system = True system_buffer.append(line[len("system:"):].strip()) continue - + # while collecting system prompt if collecting_system: system_buffer.append(line) @@ -198,7 +196,7 @@ class CreateChatMLSample: system_buffer = [] collecting_system = False continue - + # speaker lines SPEAKER-0: text match = 
re.match(r"SPEAKER-(\d+):\s*(.*)", line, re.IGNORECASE) if match: @@ -245,12 +243,12 @@ class CreateChatMLSample: chat_ml_sample, clip.tokenizer, ) - + if audio is None: audio_contents = None out = prepare_chatml_input(clip, input_tokens, audio_contents, sampling_rate = sampling_rate) return (out,) - + class EmptyLatentAudio: def __init__(self): self.device = comfy.model_management.intermediate_device() @@ -612,7 +610,7 @@ NODE_CLASS_MAPPINGS = { "PreviewAudio": PreviewAudio, "ConditioningStableAudio": ConditioningStableAudio, "LoudnessNormalization": LoudnessNormalization, - "CreateChatMLSample": CreateChatMLSample + "CreateChatMLSample": CreateChatMLSample, "RecordAudio": RecordAudio, } @@ -626,6 +624,6 @@ NODE_DISPLAY_NAME_MAPPINGS = { "SaveAudioMP3": "Save Audio (MP3)", "SaveAudioOpus": "Save Audio (Opus)", "LoudnessNormalization": "Loudness Normalization", - "CreateChatMLSample": "Create ChatML Sample" + "CreateChatMLSample": "Create ChatML Sample", "RecordAudio": "Record Audio", } diff --git a/nodes.py b/nodes.py index 95eb838b8..024c31ac6 100644 --- a/nodes.py +++ b/nodes.py @@ -1571,7 +1571,7 @@ class AutoRegressiveGeneration: "do_sample": ("BOOLEAN", {"default": False, "tooltip": "Add randomness in decoding the tokens."}), } } - + RETURN_TYPES = ("TOKENS",) FUNCTION = "generate" @@ -1582,7 +1582,7 @@ class AutoRegressiveGeneration: def generate(self, model, input_ids, seed, max_new_length, min_new_length, top_k, top_p, temperature, do_sample): return (auto_sample(self, model, input_ids, max_new_length, min_new_length, top_k, top_p, temperature, do_sample, seed = seed),) - + class DecodeTokens: @classmethod def INPUT_TYPES(s): @@ -1591,7 +1591,7 @@ class DecodeTokens: "clip": (IO.CLIP, {"tooltip": "The model used for generation."}), "tokens": ("TOKENS", ),} } - + FUNCTION = "decode" CATEGORY = "conditioning" RETURN_TYPES = ("TEXT", "AUDIO")