# ComfyUI/comfy_extras/nodes_audio.py

from __future__ import annotations
import av
import re
import torchaudio
import torch
import comfy.model_management
import folder_paths
import os
import io
import json
import random
import hashlib
import numpy as np
import node_helpers
from comfy.cli_args import args
from comfy.comfy_types import IO
from comfy.comfy_types import FileLocator
from dataclasses import asdict
from comfy.ldm.higgsv2.loudness import loudness
from comfy.ldm.higgsv2.preprocess import (
    prepare_chatml_sample, Message, ChatMLSample, ChatMLDatasetSample, AudioContent, transcript_normalize
)

MULTISPEAKER_DEFAULT_SYSTEM_MESSAGE = """You are an AI assistant designed to convert text into speech.
If the user's message includes a [SPEAKER*] tag, do not read out the tag and generate speech for the following text, using the specified voice.
If no speaker tag is present, select a suitable voice on your own."""
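
# Hedged example of a transcript the system message above refers to (speaker
# tags and lines are illustrative only):
#   [SPEAKER0] Hello there!
#   [SPEAKER1] Hi, how are you?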

class LoudnessNormalization:
    CATEGORY = "audio"
    RETURN_TYPES = ("AUDIO",)
    FUNCTION = "normalize"

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"audio": ("AUDIO",),
                             "block_size": ("FLOAT", {"default": 0.400, "min": 0.1, "max": 1.0, "step": 0.05}),
                             "loudness_threshold": ("FLOAT", {"default": -23.0, "min": -70.0, "max": 0.0, "step": 0.5,
                                                              "tooltip": "Target loudness in LUFS. Common values are -23.0 (broadcast), -14.0 (streaming)."})}}

    def normalize(self, audio, loudness_threshold, block_size):
        sampling_rate = audio["sample_rate"]
        waveform = audio["waveform"]
        return {"waveform": loudness(waveform, sampling_rate, target_loudness=loudness_threshold, block_size=block_size), "sample_rate": sampling_rate}

def prepare_chatml_input(
    clip,
    input_tokens,
    audio_contents,
    sampling_rate,
    postfix_str: str = "",
):
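    """Collate tokenized text (and optionally audio) into model inputs.

    A hedged summary, not a spec: any raw audio is downmixed to mono, encoded
    with clip.audio_tokenizer, and concatenated with per-clip start offsets;
    the result is wrapped in a ChatMLDatasetSample, run through clip.collator
    when present, and returned as a dict of tensors moved to clip.device.
    """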
if hasattr(clip, "postfix"):
postfix_str = clip.postfix
if postfix_str:
postfix = clip.tokenizer.encode(postfix_str, add_special_tokens=False)
input_tokens.extend(postfix)
audio_ids_l = []
if audio_contents is not None:
if not hasattr(clip, "audio_tokenizer"):
raise ValueError("This model does not have an audio tokenizer. The chosen model may not support ChatML Format")
for audio_content in audio_contents:
audio_content.raw_audio = audio_content.raw_audio.squeeze(0)
if audio_content.raw_audio.shape[0] == 2:
audio_content.raw_audio = audio_content.raw_audio.mean(dim = 0, keepdim = True)
if audio_content.raw_audio.device != next(clip.audio_tokenizer.parameters()).device:
audio_content.raw_audio = audio_content.raw_audio.to(next(clip.audio_tokenizer.parameters()).device)
audio_ids = clip.audio_tokenizer.encode(audio_content.raw_audio, sampling_rate)
audio_ids_l.append(audio_ids.squeeze(0))
if len(audio_ids_l) > 0:
audio_ids_start = torch.tensor(
np.cumsum(np.array([0] + [audio_ids.shape[1] for audio_ids in audio_ids_l])),
dtype=torch.long,
device=audio_contents[0].raw_audio.device,
).to("cpu")[0:-1]
audio_ids_concat = torch.cat(audio_ids_l, dim=1).to("cpu")
else:
audio_ids_start = None
audio_ids_concat = None
sample = ChatMLDatasetSample(
input_ids=torch.LongTensor(input_tokens),
label_ids=None,
audio_ids_concat=audio_ids_concat,
audio_ids_start=audio_ids_start,
audio_waveforms_concat=None,
audio_waveforms_start=None,
audio_sample_rate=None,
audio_speaker_indices=None,
)
if hasattr(clip, "collator"):
sample.input_ids = sample.input_ids.cpu()
sample = clip.collator([sample])
inputs = asdict(sample)
for k, v in inputs.items():
if isinstance(v, torch.Tensor):
inputs[k] = v.to(clip.device)
return inputs

def postprocess_chatml(text: str) -> tuple[str, bool]:
    speakers = set(re.findall(r'\[SPEAKER\d+\]', text))
    skip_recon = True
    if len(speakers) > 1:
        parts = text.split('<|eot_id|>')
        # keep the first <|eot_id|> and the last one, dropping interior separators
        first_eot = parts[0] + '<|eot_id|>'
        middle_parts = ''.join(parts[1:-1])
        last_eot = '<|eot_id|>' + parts[-1]
        text = first_eot + middle_parts + last_eot
        skip_recon = False
    return text, skip_recon
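
# Worked example: with more than one speaker tag, interior <|eot_id|> separators
# are dropped so only the first and last remain (and skip_recon becomes False):
#   postprocess_chatml("[SPEAKER0] a<|eot_id|>[SPEAKER1] b<|eot_id|>[SPEAKER0] c<|eot_id|>")
#   -> ("[SPEAKER0] a<|eot_id|>[SPEAKER1] b[SPEAKER0] c<|eot_id|>", False)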

class CreateChatMLSample:
    def __init__(self):
        self.device = comfy.model_management.intermediate_device()

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "text": (IO.STRING, {
                    "default": "SYSTEM: " + MULTISPEAKER_DEFAULT_SYSTEM_MESSAGE + "\n\n<|scene_desc_start|>\nSPEAKER0:masculine\nSPEAKER1:feminine\n<|scene_desc_end|>",
                    "multiline": True,
                    "dynamicPrompts": True,
                    "tooltip": (
                        "The conversation to be encoded. "
                        "To register a speaker turn, start a line with SPEAKER-0: some text. "
                        "To add a system prompt, start a line with system:"
                    ),
                }),
                "clip": (IO.CLIP, {"tooltip": "The CLIP model used for tokenizing the text."}),
            },
            "optional": {
                "audio": (IO.AUDIO, {
                    "tooltip": "An audio clip to be inserted into the conversation. To register it, add [audio]",
                })
            }
        }

    RETURN_TYPES = ("TOKENS",)
    OUTPUT_TOOLTIPS = ("Turns text and audio into the ChatML format.",)
    FUNCTION = "convert_to_ml_format"
    CATEGORY = "conditioning"

    def convert_to_ml_format(self, clip, text, audio=None):
        if audio is not None:
            clip.load_model()
        if hasattr(clip, "cond_stage_model"):
            clip = clip.cond_stage_model

        text = transcript_normalize(text)
        messages = []
        lines = text.splitlines()
        sampling_rate = None
        current_role = None
        collecting_system = False
        system_buffer = []

        for line in lines:
            line = line.strip()
            if not line:
                continue
            # system start
            if line.lower().startswith("system:"):
                collecting_system = True
                system_buffer.append(line[len("system:"):].strip())
                continue
            # while collecting the system prompt
            if collecting_system:
                system_buffer.append(line)
                if "<|scene_desc_end|>" in line or "SPEAKER-" in line:
                    system_prompt = "\n".join(system_buffer)
                    messages.append(Message(role="system", content=system_prompt))
                    system_buffer = []
                    collecting_system = False
                continue
            # speaker lines: SPEAKER-0: text
            match = re.match(r"SPEAKER-(\d+):\s*(.*)", line, re.IGNORECASE)
            if match:
                speaker_id = match.group(1)
                content = match.group(2)
                current_role = f"[SPEAKER{speaker_id}] "
                messages.append(Message(role="user", content=current_role + content.strip()))
            else:
                # continuation line goes to the last speaker or instruction
                if current_role is not None and messages:
                    messages[-1].content += "\n" + line

        # no messages parsed: fall back to plain tokenization
        if not messages:
            return (clip.tokenizer(text),)

        all_text = "".join(msg.content for msg in messages if msg.role == "user")
        # postprocess to allow multi-speaker speech
        all_text, skip_recon = postprocess_chatml(all_text)
        if not skip_recon:
            lines = all_text.splitlines()
            messages = [messages[0]] if messages[0].role == "system" else []
            current_role = None
            for line in lines:
                line = line.strip()
                if not line:
                    continue
                match = re.match(r'(\[SPEAKER\d+\])\s*(.*)', line)
                if match:
                    current_role = match.group(1)
                    content = match.group(2).strip()  # only take the text after the tag
                    messages.append(Message(role="user", content=f"{current_role} {content}" if content else current_role))
                else:
                    if current_role and messages:
                        messages[-1].content += "\n" + line

        # deduplicate the messages
        for idx, m in enumerate(messages):
            double_eot = "<|eot_id|><|eot_id|>"
            if double_eot in m.content:
                cut_index = m.content.index(double_eot)
                messages[idx].content = m.content[:cut_index + (len(double_eot) // 2)]
                break

        if audio is not None:
            # for audio cloning, the first message is a transcript, the second is the audio,
            # and the third is the request of what the model should say
            waveform = audio["waveform"]
            sampling_rate = audio["sample_rate"]
            messages.insert(1, Message(
                role="assistant",
                content=AudioContent(raw_audio=waveform, audio_url="placeholder")
            ))

        chat_ml_sample = ChatMLSample(messages)
        input_tokens, audio_contents, _ = prepare_chatml_sample(
            chat_ml_sample,
            clip.tokenizer,
        )
        if audio is None:
            audio_contents = None
        out = prepare_chatml_input(clip, input_tokens, audio_contents, sampling_rate=sampling_rate)
        return (out,)
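
# A hedged sketch of the text format this node parses (roles and lines are
# illustrative; the node falls back to plain tokenization when no lines match):
#   system: You are an AI assistant designed to convert text into speech.
#   SPEAKER-0: Hello there!
#   SPEAKER-1: Hi, how are you?
# Each SPEAKER-N line becomes a user Message tagged [SPEAKERN]; an optional
# AUDIO input is inserted as an assistant message for voice cloning.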

class EmptyLatentAudio:
    def __init__(self):
        self.device = comfy.model_management.intermediate_device()

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"seconds": ("FLOAT", {"default": 47.6, "min": 1.0, "max": 1000.0, "step": 0.1}),
                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
                             }}

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "generate"
    CATEGORY = "latent/audio"

    def generate(self, seconds, batch_size):
        length = round((seconds * 44100 / 2048) / 2) * 2
        latent = torch.zeros([batch_size, 64, length], device=self.device)
        return ({"samples": latent, "type": "audio"},)

class ConditioningStableAudio:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"positive": ("CONDITIONING",),
                             "negative": ("CONDITIONING",),
                             "seconds_start": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.1}),
                             "seconds_total": ("FLOAT", {"default": 47.0, "min": 0.0, "max": 1000.0, "step": 0.1}),
                             }}

    RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
    RETURN_NAMES = ("positive", "negative")
    FUNCTION = "append"
    CATEGORY = "conditioning"

    def append(self, positive, negative, seconds_start, seconds_total):
        positive = node_helpers.conditioning_set_values(positive, {"seconds_start": seconds_start, "seconds_total": seconds_total})
        negative = node_helpers.conditioning_set_values(negative, {"seconds_start": seconds_start, "seconds_total": seconds_total})
        return (positive, negative)
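
# Note: this node only annotates both conditioning inputs with the same
# seconds_start/seconds_total window; the Stable Audio model's conditioner is
# what actually consumes these keys downstream.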

class VAEEncodeAudio:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"audio": ("AUDIO",), "vae": ("VAE",)}}

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "encode"
    CATEGORY = "latent/audio"

    def encode(self, vae, audio):
        sample_rate = audio["sample_rate"]
        if sample_rate != 44100:
            waveform = torchaudio.functional.resample(audio["waveform"], sample_rate, 44100)
        else:
            waveform = audio["waveform"]
        t = vae.encode(waveform.movedim(1, -1))
        return ({"samples": t},)

class VAEDecodeAudio:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"samples": ("LATENT",), "vae": ("VAE",)}}

    RETURN_TYPES = ("AUDIO",)
    FUNCTION = "decode"
    CATEGORY = "latent/audio"

    def decode(self, vae, samples):
        audio = vae.decode(samples["samples"]).movedim(-1, 1)
        # volume safety: attenuate a batch item only when 5x its std exceeds 1.0
        std = torch.std(audio, dim=[1, 2], keepdim=True) * 5.0
        std[std < 1.0] = 1.0
        audio /= std
        return ({"waveform": audio, "sample_rate": 44100},)

def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None, quality="128k"):
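    """Shared writer used by the Save*/Preview audio nodes below.

    Encodes each batch item with PyAV into the requested container
    (flac, mp3, or opus), embeds prompt/workflow metadata when enabled,
    writes the file, and returns the UI results list. `self` is expected
    to provide output_dir, type, and prefix_append.
    """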
    filename_prefix += self.prefix_append
    full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
    results: list[FileLocator] = []

    # Prepare the metadata dictionary
    metadata = {}
    if not args.disable_metadata:
        if prompt is not None:
            metadata["prompt"] = json.dumps(prompt)
        if extra_pnginfo is not None:
            for x in extra_pnginfo:
                metadata[x] = json.dumps(extra_pnginfo[x])

    # Opus supported sample rates
    OPUS_RATES = [8000, 12000, 16000, 24000, 48000]

    for (batch_number, waveform) in enumerate(audio["waveform"].cpu()):
        filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
        file = f"{filename_with_batch_num}_{counter:05}_.{format}"
        output_path = os.path.join(full_output_folder, file)

        # Use the original sample rate initially
        sample_rate = audio["sample_rate"]

        # Handle Opus sample rate requirements
        if format == "opus":
            if sample_rate > 48000:
                sample_rate = 48000
            elif sample_rate not in OPUS_RATES:
                # Find the next highest supported rate
                for rate in sorted(OPUS_RATES):
                    if rate > sample_rate:
                        sample_rate = rate
                        break
                if sample_rate not in OPUS_RATES:  # Fallback if still not supported
                    sample_rate = 48000

        # Resample if necessary
        if sample_rate != audio["sample_rate"]:
            waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate)

        # Create the output with the specified format
        output_buffer = io.BytesIO()
        output_container = av.open(output_buffer, mode='w', format=format)

        # Set metadata on the container
        for key, value in metadata.items():
            output_container.metadata[key] = value

        # Set up the output stream with appropriate properties
        if format == "opus":
            out_stream = output_container.add_stream("libopus", rate=sample_rate)
            if quality == "64k":
                out_stream.bit_rate = 64000
            elif quality == "96k":
                out_stream.bit_rate = 96000
            elif quality == "128k":
                out_stream.bit_rate = 128000
            elif quality == "192k":
                out_stream.bit_rate = 192000
            elif quality == "320k":
                out_stream.bit_rate = 320000
        elif format == "mp3":
            out_stream = output_container.add_stream("libmp3lame", rate=sample_rate)
            if quality == "V0":
                # TODO: I would really love to support V3 and V5, but there doesn't seem
                # to be a way to set the qscale level; the property below is a bool
                out_stream.codec_context.qscale = 1
            elif quality == "128k":
                out_stream.bit_rate = 128000
            elif quality == "320k":
                out_stream.bit_rate = 320000
        else:  # format == "flac"
            out_stream = output_container.add_stream("flac", rate=sample_rate)

        frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[0] == 1 else 'stereo')
        frame.sample_rate = sample_rate
        frame.pts = 0
        output_container.mux(out_stream.encode(frame))

        # Flush the encoder
        output_container.mux(out_stream.encode(None))

        # Close the container
        output_container.close()

        # Write the output to file
        output_buffer.seek(0)
        with open(output_path, 'wb') as f:
            f.write(output_buffer.getbuffer())

        results.append({
            "filename": file,
            "subfolder": subfolder,
            "type": self.type
        })
        counter += 1

    return {"ui": {"audio": results}}

class SaveAudio:
    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()
        self.type = "output"
        self.prefix_append = ""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"audio": ("AUDIO",),
                             "filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
                             },
                "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
                }

    RETURN_TYPES = ()
    FUNCTION = "save_flac"
    OUTPUT_NODE = True
    CATEGORY = "audio"

    def save_flac(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None):
        return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo)

class SaveAudioMP3:
    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()
        self.type = "output"
        self.prefix_append = ""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"audio": ("AUDIO",),
                             "filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
                             "quality": (["V0", "128k", "320k"], {"default": "V0"}),
                             },
                "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
                }

    RETURN_TYPES = ()
    FUNCTION = "save_mp3"
    OUTPUT_NODE = True
    CATEGORY = "audio"

    def save_mp3(self, audio, filename_prefix="ComfyUI", format="mp3", prompt=None, extra_pnginfo=None, quality="128k"):
        return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality)

class SaveAudioOpus:
    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()
        self.type = "output"
        self.prefix_append = ""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"audio": ("AUDIO",),
                             "filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
                             "quality": (["64k", "96k", "128k", "192k", "320k"], {"default": "128k"}),
                             },
                "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
                }

    RETURN_TYPES = ()
    FUNCTION = "save_opus"
    OUTPUT_NODE = True
    CATEGORY = "audio"

    def save_opus(self, audio, filename_prefix="ComfyUI", format="opus", prompt=None, extra_pnginfo=None, quality="128k"):
        return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality)

class PreviewAudio(SaveAudio):
    def __init__(self):
        self.output_dir = folder_paths.get_temp_directory()
        self.type = "temp"
        self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstuvwxyz") for x in range(5))

    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                {"audio": ("AUDIO",), },
                "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
                }

def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
    """Convert audio to float 32 bits PCM format."""
    if wav.dtype.is_floating_point:
        return wav
    elif wav.dtype == torch.int16:
        return wav.float() / (2 ** 15)
    elif wav.dtype == torch.int32:
        return wav.float() / (2 ** 31)
    raise ValueError(f"Unsupported wav dtype: {wav.dtype}")

def load(filepath: str) -> tuple[torch.Tensor, int]:
    with av.open(filepath) as af:
        if not af.streams.audio:
            raise ValueError("No audio stream found in the file.")

        stream = af.streams.audio[0]
        sr = stream.codec_context.sample_rate
        n_channels = stream.channels

        frames = []
        length = 0
        for frame in af.decode(streams=stream.index):
            buf = torch.from_numpy(frame.to_ndarray())
            if buf.shape[0] != n_channels:
                buf = buf.view(-1, n_channels).t()
            frames.append(buf)
            length += buf.shape[1]

        if not frames:
            raise ValueError("No audio frames decoded.")

        wav = torch.cat(frames, dim=1)
        wav = f32_pcm(wav)
        return wav, sr
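
# Usage sketch (the path is hypothetical):
#   wav, sr = load("input/example.flac")  # wav: (channels, samples) float32; sr: int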

class LoadAudio:
    @classmethod
    def INPUT_TYPES(s):
        input_dir = folder_paths.get_input_directory()
        files = folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"])
        return {"required": {"audio": (sorted(files), {"audio_upload": True})}}

    CATEGORY = "audio"
    RETURN_TYPES = ("AUDIO",)
    FUNCTION = "load"

    def load(self, audio):
        audio_path = folder_paths.get_annotated_filepath(audio)
        waveform, sample_rate = load(audio_path)
        audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
        return (audio,)

    @classmethod
    def IS_CHANGED(s, audio):
        audio_path = folder_paths.get_annotated_filepath(audio)
        m = hashlib.sha256()
        with open(audio_path, 'rb') as f:
            m.update(f.read())
        return m.digest().hex()

    @classmethod
    def VALIDATE_INPUTS(s, audio):
        if not folder_paths.exists_annotated_filepath(audio):
            return "Invalid audio file: {}".format(audio)
        return True

class RecordAudio:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"audio": ("AUDIO_RECORD", {})}}

    CATEGORY = "audio"
    RETURN_TYPES = ("AUDIO",)
    FUNCTION = "load"

    def load(self, audio):
        audio_path = folder_paths.get_annotated_filepath(audio)
        waveform, sample_rate = torchaudio.load(audio_path)
        audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
        return (audio,)

NODE_CLASS_MAPPINGS = {
    "EmptyLatentAudio": EmptyLatentAudio,
    "VAEEncodeAudio": VAEEncodeAudio,
    "VAEDecodeAudio": VAEDecodeAudio,
    "SaveAudio": SaveAudio,
    "SaveAudioMP3": SaveAudioMP3,
    "SaveAudioOpus": SaveAudioOpus,
    "LoadAudio": LoadAudio,
    "PreviewAudio": PreviewAudio,
    "ConditioningStableAudio": ConditioningStableAudio,
    "LoudnessNormalization": LoudnessNormalization,
    "CreateChatMLSample": CreateChatMLSample,
    "RecordAudio": RecordAudio,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "EmptyLatentAudio": "Empty Latent Audio",
    "VAEEncodeAudio": "VAE Encode Audio",
    "VAEDecodeAudio": "VAE Decode Audio",
    "PreviewAudio": "Preview Audio",
    "LoadAudio": "Load Audio",
    "SaveAudio": "Save Audio (FLAC)",
    "SaveAudioMP3": "Save Audio (MP3)",
    "SaveAudioOpus": "Save Audio (Opus)",
    "LoudnessNormalization": "Loudness Normalization",
    "CreateChatMLSample": "Create ChatML Sample",
    "RecordAudio": "Record Audio",
}