Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-01-15 16:50:57 +08:00
bug fixes + added some features
This commit is contained in:
parent 233e4415a1
commit 64124224e7
@@ -1,5 +1,6 @@
from __future__ import annotations
import math
import copy
import torch
import inspect

@@ -56,9 +57,9 @@ class TopKLogits:
class TemperatureLogitsWarper:
def __init__(self, temperature: float):
if not (temperature > 0):
raise ValueError(f"`temperature` (={temperature}) must be positive temperature > 0")
raise ValueError(f"`temperature` (={temperature}) must be a positive number > 0")
self.temperature = temperature
def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor:

@@ -86,10 +87,30 @@ class TopPLogitsWarper:
scores_processed = scores.masked_fill(indices_to_remove, self.filter_value)
return scores_processed
class MinLengthLogitsProcessor:
def __init__(self, min_length: int, eos_token_id: torch.Tensor):
self.min_length = min_length
self.eos_token_id = eos_token_id
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
if input_ids is None:
return scores
vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id)
scores_processed = scores.clone()
if input_ids.shape[-1] < self.min_length:
scores_processed = torch.where(eos_token_mask, -math.inf, scores)
return scores_processed
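For context, a minimal usage sketch of the new MinLengthLogitsProcessor added above (the toy vocabulary, the EOS id 2, and the tensor values are illustrative, not taken from the model config):

import torch
eos_token_id = torch.tensor([2])
processor = MinLengthLogitsProcessor(min_length=5, eos_token_id=eos_token_id)
input_ids = torch.tensor([[1, 7, 9]])   # only 3 tokens generated so far
scores = torch.zeros(1, 10)             # toy vocabulary of 10 logits
out = processor(input_ids, scores)
# out[0, 2] is -inf until the sequence reaches min_length, so EOS cannot be chosen yet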
def get_logits_processing(config: GenerationConfig):
# TODO: add support for beam search with diversity penalty
logits_processors = []
if config._eos_token_tensor is not None and config.min_length > 1:
logits_processors.append(MinLengthLogitsProcessor(config.min_length, config._eos_token_tensor))
if config.top_k is not None and config.top_k != 0:
logits_processors.append(TopKLogits(config.top_k))

@@ -101,28 +122,59 @@ def get_logits_processing(config: GenerationConfig):
return logits_processors
def apply_logits_processing(logits, logits_processing_list, **kwargs):
def apply_logits_processing(input_ids, logits, logits_processing_list, **kwargs):
for process in logits_processing_list:
func_args = inspect.signature(process.__call__).parameters
if not all(arg in kwargs for arg in list(func_args.keys())[1:]):
if not all(arg in kwargs for arg in list(func_args.keys())[3:]):
raise ValueError(
f"Make sure that all the required parameters: {list(func_args.keys())} for "
f"{process.__class__} are passed to the logits processor."
)
logits = process(logits, **kwargs)
if "input_ids" in func_args:
logits = process(input_ids, logits)
else:
logits = process(logits, **kwargs)
return logits
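As a rough illustration of the new dispatch (assuming a built processor list and an input_ids tensor are already in scope): processors whose __call__ declares an input_ids parameter now receive the generated prefix, while the remaining warpers are called exactly as before.

processors = get_logits_processing(config)
logits = apply_logits_processing(input_ids, logits, processors)
# processors that take input_ids (e.g. MinLengthLogitsProcessor) get it;
# the rest are called as process(logits, **kwargs)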
def check_stopping_criteria(input_ids: torch.Tensor, max_length: int, eos_token):
def check_stopping_strings(input_ids: torch.Tensor, stop_strings: list) -> torch.BoolTensor:
# stop_strings must be a list of lists: List[List[], List[]]
device = input_ids.device
batch_size, seq_len = input_ids.shape
finished = torch.zeros(batch_size, dtype = torch.bool, device = device)
for b in range(batch_size):
row = input_ids[b]
# check each stop token sequence
for stop_ids in stop_strings:
n = len(stop_ids)
if n == 0 or n > seq_len:
continue
# compare tail of the generated ids to the stop sequence
if torch.all(row[-n:] == torch.tensor(stop_ids, device = device, dtype = row.dtype)):
finished[b] = True
break
return finished
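A minimal sketch of the tail match above, using the same single-token stop ids the model registers later in this commit (the example tensors are made up):

ids = torch.tensor([[5, 6, 128009],
                    [5, 6, 7]])
check_stopping_strings(ids, [[128009], [128001]])
# -> tensor([ True, False]); row 0 ends with a stop sequence, row 1 does not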
def check_stopping_criteria(input_ids: torch.Tensor, max_length: int, eos_token, stop_strings: tuple = None):
device = input_ids.device
if not isinstance(eos_token, torch.Tensor):
eos_token = torch.tensor(eos_token, device=input_ids.device)
eos_token = torch.tensor(eos_token, device = device)
max_len_done = input_ids.shape[1] >= max_length
eos_done = torch.isin(input_ids[:, -1], eos_token)
# finished either by length or eos
finished_mask = max_len_done | eos_done
if stop_strings is not None:
stop_done = check_stopping_strings(input_ids, stop_strings)
else:
stop_done = torch.zeros(input_ids.size(0), dtype=torch.bool, device=device)
# finished either by length or eos or stop strings
finished_mask = max_len_done | eos_done | stop_done
unfinished_mask = ~finished_mask

@@ -500,6 +500,8 @@ class HiggsAudioModel(nn.Module):
torch.ones(kwargs["audio_num_codebooks"]) / kwargs["audio_num_codebooks"]
)
self.stop_strings = [[128009], [128001]]
def _sample_audio_tokens(
self,
audio_logits: torch.Tensor,

@@ -520,7 +522,7 @@ class HiggsAudioModel(nn.Module):
audio_eos_token_id = generation_config.generation_kwargs.get("audio_eos_token_id", None)
next_audio_token_logits = audio_logits.clone()[-1, :, :].float().to(device)
next_audio_token_scores = apply_logits_processing(next_audio_token_logits, logits_processing_list)
next_audio_token_scores = apply_logits_processing(None, next_audio_token_logits, logits_processing_list)
if do_sample:
probs = nn.functional.softmax(next_audio_token_scores, dim = -1)

@@ -588,6 +590,9 @@ class HiggsAudioModel(nn.Module):
logits_processing_list,
device: torch.device,
generation_mode: GenerationMode,
torch_generator,
is_using_cuda_graphs,
do_sample = False,
) -> torch.Tensor:
"""Sample text tokens from the logits"""

@@ -595,7 +600,7 @@ class HiggsAudioModel(nn.Module):
next_token_logits = next_token_logits.to(input_ids.device)
# pre-process distribution
next_token_scores = apply_logits_processing(next_token_logits, logits_processing_list)
next_token_scores = apply_logits_processing(input_ids, next_token_logits, logits_processing_list)
if generation_mode == GenerationMode.AUDIO_INIT:
# See the audio bos token, we should start generating audio tokens

@@ -612,7 +617,17 @@ class HiggsAudioModel(nn.Module):
device=device,
)
else:
next_tokens = torch.argmax(next_token_scores, dim=-1)
if do_sample:
probs = nn.functional.softmax(next_token_scores, dim = -1)
# same as for audio
if not is_using_cuda_graphs:
next_tokens = torch.multinomial(probs, num_samples = 1, generator=torch_generator).squeeze(1)
else:
next_tokens = categorical_sample(probs, generator = torch_generator)
else:
next_tokens = torch.argmax(next_token_scores, dim=-1)
next_audio_tokens = None
return next_tokens, next_audio_tokens
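For reference, a toy version of the do_sample branch above (the logits are made up; the CUDA-graph path delegates to categorical_sample and is not shown):

scores = torch.tensor([[2.0, 0.5, 0.1]])
probs = torch.nn.functional.softmax(scores, dim=-1)
greedy = torch.argmax(scores, dim=-1)    # do_sample=False -> tensor([0])
gen = torch.Generator().manual_seed(0)
sampled = torch.multinomial(probs, num_samples=1, generator=gen).squeeze(1)
# do_sample=True usually also picks 0 here, but can return 1 or 2 with small probability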
@@ -1093,12 +1108,6 @@ class HiggsAudioModel(nn.Module):
del model_inputs["audio_out_ids_start"]
if generation_config.use_cache:
if "audio_features" in model_inputs and model_inputs["audio_features"] is not None:
model_inputs["audio_features"] = model_inputs["audio_features"][:0, ...]
model_inputs["audio_feature_attention_mask"] = model_inputs["audio_feature_attention_mask"][
:0, ...
]
if "audio_in_ids" in model_inputs and model_inputs["audio_in_ids"] is not None:
model_inputs["audio_in_ids"] = None
model_inputs["audio_in_ids_start"] = None

@@ -1159,6 +1168,9 @@ class HiggsAudioModel(nn.Module):
logits_processing_list=logits_processing_list,
device=input_ids.device,
generation_mode=generation_mode,
torch_generator = torch_generator,
do_sample = do_sample,
is_using_cuda_graphs = is_using_cuda_graphs
)
if next_audio_tokens is not None:

@@ -1199,7 +1211,7 @@ class HiggsAudioModel(nn.Module):
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
input_ids_full = torch.cat([input_ids_full, next_tokens[:, None]], dim=-1)
finished, unfinished_sequences = check_stopping_criteria(input_ids_full, max_length, eos_token = eos_token_tensor)
finished, unfinished_sequences = check_stopping_criteria(input_ids_full, max_length, eos_token = eos_token_tensor, stop_strings = self.stop_strings)
this_peer_finished = finished.all()
cur_len += 1

@@ -71,7 +71,8 @@ class HiggsTokenizer(nn.Module):
wv_numpy = self.audio_tokenizer.decode(vq_code.unsqueeze(0))[0, 0]
outputs.append(wv_numpy)
return (None, {"waveform": torch.stack(outputs, dim = 0).unsqueeze(1), "sample_rate": self.audio_tokenizer.sample_rate}) # audio only
# currently only supports one batch size
return (None, {"waveform": torch.cat(outputs, dim = 0).unsqueeze(0).unsqueeze(1), "sample_rate": self.audio_tokenizer.sample_rate}) # audio only
def load_state_dict(self, sd, strict = False):
return self.audio_tokenizer.load_state_dict(sd, strict = strict)

@@ -22,7 +22,6 @@ from comfy.ldm.higgsv2.preprocess import (
prepare_chatml_sample, Message, ChatMLSample, ChatMLDatasetSample, AudioContent, transcript_normalize
)
AUDIO_PLACEHOLDER_TOKEN = "<|__AUDIO_PLACEHOLDER__|>"
MULTISPEAKER_DEFAULT_SYSTEM_MESSAGE = """You are an AI assistant designed to convert text into speech.
If the user's message includes a [SPEAKER*] tag, do not read out the tag and generate speech for the following text, using the specified voice.

@@ -175,6 +174,8 @@ class CreateChatMLSample:
current_role = None
collecting_system = False
system_buffer = []
collecting_instruction = False
instruction_buffer = []
for line in lines:
line = line.strip()

@@ -197,6 +198,26 @@ class CreateChatMLSample:
collecting_system = False
continue
# generation instruction start
if "<|generation_instruction_start|>" in line:
collecting_instruction = True
instruction_buffer = []
continue
if collecting_instruction:
if "<|generation_instruction_end|>" in line:
instruction_text = "\n".join(instruction_buffer)
# include both start and end tokens
messages.append(Message(
role="user",
content=f"<|generation_instruction_start|>\n{instruction_text}\n<|generation_instruction_end|>"
))
instruction_buffer = []
collecting_instruction = False
else:
instruction_buffer.append(line)
continue
# speaker lines SPEAKER-0: text
match = re.match(r"SPEAKER-(\d+):\s*(.*)", line, re.IGNORECASE)
if match:

@@ -221,14 +242,29 @@ class CreateChatMLSample:
lines = all_text.splitlines()
messages = [messages[0]] if messages[0].role == "system" else []
current_role = None
for line in lines:
match = re.match(r'\[SPEAKER\d+\]', line)
line = line.strip()
if not line:
continue
match = re.match(r'(\[SPEAKER\d+\])\s*(.*)', line)
if match:
current_role = match.group(0)
messages.append(Message(role="user", content=line.strip()))
current_role = match.group(1)
content = match.group(2).strip() # only take the text after the tag
messages.append(Message(role="user", content=f"{current_role} {content}" if content else current_role))
else:
if current_role and messages:
messages[-1].content += "\n" + line.strip()
messages[-1].content += "\n" + line
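As an aside, a tiny example of what the tightened speaker regex above captures (the sample line is invented):

import re
m = re.match(r'(\[SPEAKER\d+\])\s*(.*)', "[SPEAKER0]   Hello there")
m.group(1)  # '[SPEAKER0]'
m.group(2)  # 'Hello there' -> stored as '[SPEAKER0] Hello there'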
# deduplicate the messages
for idx, m in enumerate(messages):
double_eot = "<|eot_id|><|eot_id|>"
if double_eot in m.content:
cut_index = m.content.index(double_eot)
messages[idx].content = m.content[:cut_index + (len(double_eot) // 2)]
break
if audio is not None:
# for audio cloning, the first message is a transcript, second is the audio,
# third is the request of what the model should say
nodes.py
@@ -2065,7 +2065,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"KSampler": "KSampler",
"KSamplerAdvanced": "KSampler (Advanced)",
"AutoRegressiveGeneration": "Autoregressive Generation",
""
# Loaders
"CheckpointLoader": "Load Checkpoint With Config (DEPRECATED)",
"CheckpointLoaderSimple": "Load Checkpoint",