Bug fixes + some new features

This commit is contained in:
Yousef Rafat 2025-09-09 01:07:36 +03:00
parent 233e4415a1
commit 64124224e7
5 changed files with 126 additions and 26 deletions

View File

@ -1,5 +1,6 @@
from __future__ import annotations
import math
import copy
import torch
import inspect
@ -56,9 +57,9 @@ class TopKLogits:
class TemperatureLogitsWarper:
def __init__(self, temperature: float):
if not (temperature > 0):
raise ValueError(f"`temperature` (={temperature}) must be positive temperature > 0")
raise ValueError(f"`temperature` (={temperature}) must be a positive number > 0")
self.temperature = temperature
def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor:
@ -86,10 +87,30 @@ class TopPLogitsWarper:
scores_processed = scores.masked_fill(indices_to_remove, self.filter_value)
return scores_processed
class MinLengthLogitsProcessor:
def __init__(self, min_length: int, eos_token_id: torch.Tensor):
self.min_length = min_length
self.eos_token_id = eos_token_id
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
if input_ids is None:
return scores
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id)
scores_processed = scores.clone()
if input_ids.shape[-1] < self.min_length:
scores_processed = torch.where(eos_token_mask, -math.inf, scores)
return scores_processed
def get_logits_processing(config: GenerationConfig):
# TODO: add support for beam search with diversity penalty
logits_processors = []
if config._eos_token_tensor is not None and config.min_length > 1:
logits_processors.append(MinLengthLogitsProcessor(config.min_length, config._eos_token_tensor))
if config.top_k is not None and config.top_k != 0:
logits_processors.append(TopKLogits(config.top_k))
@ -101,28 +122,59 @@ def get_logits_processing(config: GenerationConfig):
return logits_processors
def apply_logits_processing(logits, logits_processing_list, **kwargs):
def apply_logits_processing(input_ids, logits, logits_processing_list, **kwargs):
for process in logits_processing_list:
func_args = inspect.signature(process.__call__).parameters
if not all(arg in kwargs for arg in list(func_args.keys())[1:]):
if not all(arg in kwargs for arg in list(func_args.keys())[3:]):
raise ValueError(
f"Make sure that all the required parameters: {list(func_args.keys())} for "
f"{process.__class__} are passed to the logits processor."
)
logits = process(logits, **kwargs)
if "input_ids" in func_args:
logits = process(input_ids, logits)
else:
logits = process(logits, **kwargs)
return logits
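For reference, a minimal usage sketch of the updated dispatch (not part of the commit; the min_length, eos id, and logits below are made-up values): processors whose __call__ takes input_ids, such as MinLengthLogitsProcessor, now receive them, while score-only warpers are still called through kwargs.

import torch

processors = [
    MinLengthLogitsProcessor(min_length=4, eos_token_id=torch.tensor([2])),
    TemperatureLogitsWarper(temperature=0.7),
]
input_ids = torch.tensor([[0, 1]])   # only 2 tokens generated so far
logits = torch.zeros(1, 5)           # batch of 1, vocab of 5
logits = apply_logits_processing(input_ids, logits, processors)
# logits[0, 2] is now -inf: EOS stays suppressed until min_length is reached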
def check_stopping_criteria(input_ids: torch.Tensor, max_length: int, eos_token):
def check_stopping_strings(input_ids: torch.Tensor, stop_strings: list) -> torch.BoolTensor:
# stop_strings must be a list of stop token-id sequences: List[List[int]]
device = input_ids.device
batch_size, seq_len = input_ids.shape
finished = torch.zeros(batch_size, dtype = torch.bool, device = device)
for b in range(batch_size):
row = input_ids[b]
# check each stop token sequence
for stop_ids in stop_strings:
n = len(stop_ids)
if n == 0 or n > seq_len:
continue
# compare tail of the generated ids to the stop sequence
if torch.all(row[-n:] == torch.tensor(stop_ids, device = device, dtype = row.dtype)):
finished[b] = True
break
return finished
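A quick illustration of the tail match (token ids are invented): only rows whose generated ids end with one of the stop sequences are marked finished.

import torch

ids = torch.tensor([[1, 2, 3, 4],
                    [9, 9, 5, 6]])
print(check_stopping_strings(ids, [[5, 6], [7]]))   # tensor([False,  True])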
def check_stopping_criteria(input_ids: torch.Tensor, max_length: int, eos_token, stop_strings: tuple = None):
device = input_ids.device
if not isinstance(eos_token, torch.Tensor):
eos_token = torch.tensor(eos_token, device=input_ids.device)
eos_token = torch.tensor(eos_token, device = device)
max_len_done = input_ids.shape[1] >= max_length
eos_done = torch.isin(input_ids[:, -1], eos_token)
# finished either by length or eos
finished_mask = max_len_done | eos_done
if stop_strings is not None:
stop_done = check_stopping_strings(input_ids, stop_strings)
else:
stop_done = torch.zeros(input_ids.size(0), dtype=torch.bool, device=device)
# finished either by length, eos, or stop strings
finished_mask = max_len_done | eos_done | stop_done
unfinished_mask = ~finished_mask
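A hedged usage sketch of the combined criteria (ids and token values are invented; per the call site further down, the function is assumed to return the finished and unfinished masks):

import torch

ids = torch.tensor([[1, 2, 128009],   # ends with a stop string -> finished
                    [1, 2, 3]])       # no EOS, no stop string, under max_length
finished, unfinished = check_stopping_criteria(
    ids, max_length=10, eos_token=128001, stop_strings=[[128009], [128001]]
)
# finished -> tensor([ True, False]), unfinished -> tensor([False,  True])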

View File

@ -500,6 +500,8 @@ class HiggsAudioModel(nn.Module):
torch.ones(kwargs["audio_num_codebooks"]) / kwargs["audio_num_codebooks"]
)
self.stop_strings = [[128009], [128001]]
def _sample_audio_tokens(
self,
audio_logits: torch.Tensor,
@ -520,7 +522,7 @@ class HiggsAudioModel(nn.Module):
audio_eos_token_id = generation_config.generation_kwargs.get("audio_eos_token_id", None)
next_audio_token_logits = audio_logits.clone()[-1, :, :].float().to(device)
next_audio_token_scores = apply_logits_processing(next_audio_token_logits, logits_processing_list)
next_audio_token_scores = apply_logits_processing(None, next_audio_token_logits, logits_processing_list)
if do_sample:
probs = nn.functional.softmax(next_audio_token_scores, dim = -1)
@ -588,6 +590,9 @@ class HiggsAudioModel(nn.Module):
logits_processing_list,
device: torch.device,
generation_mode: GenerationMode,
torch_generator,
is_using_cuda_graphs,
do_sample = False,
) -> torch.Tensor:
"""Sample text tokens from the logits"""
@ -595,7 +600,7 @@ class HiggsAudioModel(nn.Module):
next_token_logits = next_token_logits.to(input_ids.device)
# pre-process distribution
next_token_scores = apply_logits_processing(next_token_logits, logits_processing_list)
next_token_scores = apply_logits_processing(input_ids, next_token_logits, logits_processing_list)
if generation_mode == GenerationMode.AUDIO_INIT:
# We have seen the audio bos token, so we should start generating audio tokens
@ -612,7 +617,17 @@ class HiggsAudioModel(nn.Module):
device=device,
)
else:
next_tokens = torch.argmax(next_token_scores, dim=-1)
if do_sample:
probs = nn.functional.softmax(next_token_scores, dim = -1)
# same sampling logic as for audio tokens
if not is_using_cuda_graphs:
next_tokens = torch.multinomial(probs, num_samples = 1, generator=torch_generator).squeeze(1)
else:
next_tokens = categorical_sample(probs, generator = torch_generator)
else:
next_tokens = torch.argmax(next_token_scores, dim=-1)
next_audio_tokens = None
return next_tokens, next_audio_tokens
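For reference, the plain (non CUDA-graph) sampling branch above can be reproduced standalone as below; the logits and seed are made up, and categorical_sample is the repo's graph-safe helper, not shown here.

import torch
import torch.nn as nn

gen = torch.Generator().manual_seed(0)              # assumed seed, for illustration
next_token_scores = torch.tensor([[2.0, 0.5, 0.1]])
probs = nn.functional.softmax(next_token_scores, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1, generator=gen).squeeze(1)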
@ -1093,12 +1108,6 @@ class HiggsAudioModel(nn.Module):
del model_inputs["audio_out_ids_start"]
if generation_config.use_cache:
if "audio_features" in model_inputs and model_inputs["audio_features"] is not None:
model_inputs["audio_features"] = model_inputs["audio_features"][:0, ...]
model_inputs["audio_feature_attention_mask"] = model_inputs["audio_feature_attention_mask"][
:0, ...
]
if "audio_in_ids" in model_inputs and model_inputs["audio_in_ids"] is not None:
model_inputs["audio_in_ids"] = None
model_inputs["audio_in_ids_start"] = None
@ -1159,6 +1168,9 @@ class HiggsAudioModel(nn.Module):
logits_processing_list=logits_processing_list,
device=input_ids.device,
generation_mode=generation_mode,
torch_generator = torch_generator,
do_sample = do_sample,
is_using_cuda_graphs = is_using_cuda_graphs
)
if next_audio_tokens is not None:
@ -1199,7 +1211,7 @@ class HiggsAudioModel(nn.Module):
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
input_ids_full = torch.cat([input_ids_full, next_tokens[:, None]], dim=-1)
finished, unfinished_sequences = check_stopping_criteria(input_ids_full, max_length, eos_token = eos_token_tensor)
finished, unfinished_sequences = check_stopping_criteria(input_ids_full, max_length, eos_token = eos_token_tensor, stop_strings = self.stop_strings)
this_peer_finished = finished.all()
cur_len += 1

View File

@ -71,7 +71,8 @@ class HiggsTokenizer(nn.Module):
wv_numpy = self.audio_tokenizer.decode(vq_code.unsqueeze(0))[0, 0]
outputs.append(wv_numpy)
return (None, {"waveform": torch.stack(outputs, dim = 0).unsqueeze(1), "sample_rate": self.audio_tokenizer.sample_rate}) # audio only
# currently only supports a batch size of 1
return (None, {"waveform": torch.cat(outputs, dim = 0).unsqueeze(0).unsqueeze(1), "sample_rate": self.audio_tokenizer.sample_rate}) # audio only
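A shape sketch of the change (dummy chunk lengths): the decoded chunks are now concatenated along time into a single batch item instead of being stacked as separate ones, which also allows chunks of different lengths.

import torch

outputs = [torch.zeros(16000), torch.zeros(8000)]              # two decoded chunks
waveform = torch.cat(outputs, dim=0).unsqueeze(0).unsqueeze(1)
print(waveform.shape)                                          # torch.Size([1, 1, 24000])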
def load_state_dict(self, sd, strict = False):
return self.audio_tokenizer.load_state_dict(sd, strict = strict)

View File

@ -22,7 +22,6 @@ from comfy.ldm.higgsv2.preprocess import (
prepare_chatml_sample, Message, ChatMLSample, ChatMLDatasetSample, AudioContent, transcript_normalize
)
AUDIO_PLACEHOLDER_TOKEN = "<|__AUDIO_PLACEHOLDER__|>"
MULTISPEAKER_DEFAULT_SYSTEM_MESSAGE = """You are an AI assistant designed to convert text into speech.
If the user's message includes a [SPEAKER*] tag, do not read out the tag and generate speech for the following text, using the specified voice.
@ -175,6 +174,8 @@ class CreateChatMLSample:
current_role = None
collecting_system = False
system_buffer = []
collecting_instruction = False
instruction_buffer = []
for line in lines:
line = line.strip()
@ -197,6 +198,26 @@ class CreateChatMLSample:
collecting_system = False
continue
# generation instruction start
if "<|generation_instruction_start|>" in line:
collecting_instruction = True
instruction_buffer = []
continue
if collecting_instruction:
if "<|generation_instruction_end|>" in line:
instruction_text = "\n".join(instruction_buffer)
# include both start and end tokens
messages.append(Message(
role="user",
content=f"<|generation_instruction_start|>\n{instruction_text}\n<|generation_instruction_end|>"
))
instruction_buffer = []
collecting_instruction = False
else:
instruction_buffer.append(line)
continue
# speaker lines, e.g. SPEAKER-0: text
match = re.match(r"SPEAKER-(\d+):\s*(.*)", line, re.IGNORECASE)
if match:
@ -221,14 +242,29 @@ class CreateChatMLSample:
lines = all_text.splitlines()
messages = [messages[0]] if messages[0].role == "system" else []
current_role = None
for line in lines:
match = re.match(r'\[SPEAKER\d+\]', line)
line = line.strip()
if not line:
continue
match = re.match(r'(\[SPEAKER\d+\])\s*(.*)', line)
if match:
current_role = match.group(0)
messages.append(Message(role="user", content=line.strip()))
current_role = match.group(1)
content = match.group(2).strip() # only take the text after the tag
messages.append(Message(role="user", content=f"{current_role} {content}" if content else current_role))
else:
if current_role and messages:
messages[-1].content += "\n" + line.strip()
messages[-1].content += "\n" + line
# deduplicate the messages
for idx, m in enumerate(messages):
double_eot = "<|eot_id|><|eot_id|>"
if double_eot in m.content:
cut_index = m.content.index(double_eot)
messages[idx].content = m.content[:cut_index + (len(double_eot) // 2)]
break
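As an illustrative example of the deduplication step (the message text is invented), a doubled end-of-turn marker is cut back to a single one:

content = "Hello there<|eot_id|><|eot_id|>"
double_eot = "<|eot_id|><|eot_id|>"
cut_index = content.index(double_eot)
content = content[:cut_index + len(double_eot) // 2]
# -> "Hello there<|eot_id|>"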
if audio is not None:
# for audio cloning, the first message is a transcript, second is the audio,
# third is the request for what the model should say

View File

@ -2065,7 +2065,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"KSampler": "KSampler",
"KSamplerAdvanced": "KSampler (Advanced)",
"AutoRegressiveGeneration": "Autoregressive Generation",
""
# Loaders
"CheckpointLoader": "Load Checkpoint With Config (DEPRECATED)",
"CheckpointLoaderSimple": "Load Checkpoint",