diff --git a/comfy/autoregressive_sampling.py b/comfy/autoregressive_sampling.py
index 424907fb6..8a2d3caff 100644
--- a/comfy/autoregressive_sampling.py
+++ b/comfy/autoregressive_sampling.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import math
 import copy
 import torch
 import inspect
@@ -56,9 +57,9 @@ class TopKLogits:
 
 class TemperatureLogitsWarper:
     def __init__(self, temperature: float):
         if not (temperature > 0):
-            raise ValueError(f"`temperature` (={temperature}) must be positive temperature > 0")
+            raise ValueError(f"`temperature` (={temperature}) must be a positive number > 0")
 
         self.temperature = temperature
 
     def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor:
@@ -86,10 +87,30 @@ class TopPLogitsWarper:
         scores_processed = scores.masked_fill(indices_to_remove, self.filter_value)
         return scores_processed
 
+class MinLengthLogitsProcessor:
+    def __init__(self, min_length: int, eos_token_id: torch.Tensor):
+        self.min_length = min_length
+        self.eos_token_id = eos_token_id
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+
+        if input_ids is None:
+            return scores
+
+        vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
+        eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id)
+        scores_processed = scores.clone()
+        if input_ids.shape[-1] < self.min_length:
+            scores_processed = torch.where(eos_token_mask, -math.inf, scores)
+        return scores_processed
+
 def get_logits_processing(config: GenerationConfig):
     # TODO: add support for beam search with diversity penalty
     logits_processors = []
 
+    if config._eos_token_tensor is not None and config.min_length > 1:
+        logits_processors.append(MinLengthLogitsProcessor(config.min_length, config._eos_token_tensor))
+
     if config.top_k is not None and config.top_k != 0:
         logits_processors.append(TopKLogits(config.top_k))
 
@@ -101,28 +122,59 @@ def get_logits_processing(config: GenerationConfig):
 
     return logits_processors
 
-def apply_logits_processing(logits, logits_processing_list, **kwargs):
+def apply_logits_processing(input_ids, logits, logits_processing_list, **kwargs):
     for process in logits_processing_list:
         func_args = inspect.signature(process.__call__).parameters
-        if not all(arg in kwargs for arg in list(func_args.keys())[1:]):
+        if not all(arg in kwargs for arg in list(func_args.keys())[3:]):
             raise ValueError(
                 f"Make sure that all the required parameters: {list(func_args.keys())} for "
                 f"{process.__class__} are passed to the logits processor."
             )
-        logits = process(logits, **kwargs)
+        if "input_ids" in func_args:
+            logits = process(input_ids, logits)
+        else:
+            logits = process(logits, **kwargs)
     return logits
 
-def check_stopping_criteria(input_ids: torch.Tensor, max_length: int, eos_token):
+def check_stopping_strings(input_ids: torch.Tensor, stop_strings: list) -> torch.BoolTensor:
+    # stop_strings must be a list of token-id lists: List[List[int]]
+
+    device = input_ids.device
+    batch_size, seq_len = input_ids.shape
+    finished = torch.zeros(batch_size, dtype = torch.bool, device = device)
+
+    for b in range(batch_size):
+        row = input_ids[b]
+        # check each stop token sequence
+        for stop_ids in stop_strings:
+            n = len(stop_ids)
+            if n == 0 or n > seq_len:
+                continue
+            # compare the tail of the generated ids to the stop sequence
+            if torch.all(row[-n:] == torch.tensor(stop_ids, device = device, dtype = row.dtype)):
+                finished[b] = True
+                break
+
+    return finished
+
+def check_stopping_criteria(input_ids: torch.Tensor, max_length: int, eos_token, stop_strings: list = None):
+
+    device = input_ids.device
     if not isinstance(eos_token, torch.Tensor):
-        eos_token = torch.tensor(eos_token, device=input_ids.device)
+        eos_token = torch.tensor(eos_token, device = device)
 
     max_len_done = input_ids.shape[1] >= max_length
     eos_done = torch.isin(input_ids[:, -1], eos_token)
 
-    # finished either by lenght or eos
-    finished_mask = max_len_done | eos_done
+    if stop_strings is not None:
+        stop_done = check_stopping_strings(input_ids, stop_strings)
+    else:
+        stop_done = torch.zeros(input_ids.size(0), dtype=torch.bool, device=device)
+
+    # finished either by length, eos, or stop strings
+    finished_mask = max_len_done | eos_done | stop_done
 
     unfinished_mask = ~finished_mask
diff --git a/comfy/ldm/higgsv2/model.py b/comfy/ldm/higgsv2/model.py
index 16180b657..1de6d61ac 100644
--- a/comfy/ldm/higgsv2/model.py
+++ b/comfy/ldm/higgsv2/model.py
@@ -500,6 +500,8 @@ class HiggsAudioModel(nn.Module):
                 torch.ones(kwargs["audio_num_codebooks"]) / kwargs["audio_num_codebooks"]
             )
 
+        self.stop_strings = [[128009], [128001]]  # llama3 <|eot_id|> and <|end_of_text|> token ids
+
     def _sample_audio_tokens(
         self,
         audio_logits: torch.Tensor,
@@ -520,7 +522,7 @@ class HiggsAudioModel(nn.Module):
         audio_eos_token_id = generation_config.generation_kwargs.get("audio_eos_token_id", None)
 
         next_audio_token_logits = audio_logits.clone()[-1, :, :].float().to(device)
-        next_audio_token_scores = apply_logits_processing(next_audio_token_logits, logits_processing_list)
+        next_audio_token_scores = apply_logits_processing(None, next_audio_token_logits, logits_processing_list)
 
         if do_sample:
             probs = nn.functional.softmax(next_audio_token_scores, dim = -1)
@@ -588,6 +590,9 @@ class HiggsAudioModel(nn.Module):
         logits_processing_list,
         device: torch.device,
         generation_mode: GenerationMode,
+        torch_generator,
+        is_using_cuda_graphs,
+        do_sample = False,
     ) -> torch.Tensor:
         """Sample text tokens from the logits"""
 
@@ -595,7 +600,7 @@ class HiggsAudioModel(nn.Module):
         next_token_logits = next_token_logits.to(input_ids.device)
 
         # pre-process distribution
-        next_token_scores = apply_logits_processing(next_token_logits, logits_processing_list)
+        next_token_scores = apply_logits_processing(input_ids, next_token_logits, logits_processing_list)
 
         if generation_mode == GenerationMode.AUDIO_INIT:
             # See the audio bos token, we should start generating audio tokens
@@ -612,7 +617,17 @@ class HiggsAudioModel(nn.Module):
                 device=device,
             )
         else:
-            next_tokens = torch.argmax(next_token_scores, dim=-1)
+
+            if do_sample:
+                probs = nn.functional.softmax(next_token_scores, dim = -1)
+                # same sampling path as for audio tokens
+                if not is_using_cuda_graphs:
+                    next_tokens = torch.multinomial(probs, num_samples = 1, generator=torch_generator).squeeze(1)
+                else:
+                    next_tokens = categorical_sample(probs, generator = torch_generator)
+            else:
+                next_tokens = torch.argmax(next_token_scores, dim=-1)
+
         next_audio_tokens = None
         return next_tokens, next_audio_tokens
@@ -1093,12 +1108,6 @@ class HiggsAudioModel(nn.Module):
                 del model_inputs["audio_out_ids_start"]
 
             if generation_config.use_cache:
-                if "audio_features" in model_inputs and model_inputs["audio_features"] is not None:
-                    model_inputs["audio_features"] = model_inputs["audio_features"][:0, ...]
-                    model_inputs["audio_feature_attention_mask"] = model_inputs["audio_feature_attention_mask"][
-                        :0, ...
-                    ]
-
                 if "audio_in_ids" in model_inputs and model_inputs["audio_in_ids"] is not None:
                     model_inputs["audio_in_ids"] = None
                     model_inputs["audio_in_ids_start"] = None
@@ -1159,6 +1168,9 @@ class HiggsAudioModel(nn.Module):
                 logits_processing_list=logits_processing_list,
                 device=input_ids.device,
                 generation_mode=generation_mode,
+                torch_generator = torch_generator,
+                do_sample = do_sample,
+                is_using_cuda_graphs = is_using_cuda_graphs
             )
 
             if next_audio_tokens is not None:
@@ -1199,7 +1211,7 @@ class HiggsAudioModel(nn.Module):
             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
             input_ids_full = torch.cat([input_ids_full, next_tokens[:, None]], dim=-1)
 
-            finished, unfinished_sequences = check_stopping_criteria(input_ids_full, max_length, eos_token = eos_token_tensor)
+            finished, unfinished_sequences = check_stopping_criteria(input_ids_full, max_length, eos_token = eos_token_tensor, stop_strings = self.stop_strings)
 
             this_peer_finished = finished.all()
             cur_len += 1
diff --git a/comfy/text_encoders/higgsv2.py b/comfy/text_encoders/higgsv2.py
index b861ac0e5..ad22a3836 100644
--- a/comfy/text_encoders/higgsv2.py
+++ b/comfy/text_encoders/higgsv2.py
@@ -71,7 +71,8 @@ class HiggsTokenizer(nn.Module):
             wv_numpy = self.audio_tokenizer.decode(vq_code.unsqueeze(0))[0, 0]
             outputs.append(wv_numpy)
 
-        return (None, {"waveform": torch.stack(outputs, dim = 0).unsqueeze(1), "sample_rate": self.audio_tokenizer.sample_rate}) # audio only
+        # currently only supports batch size 1
+        return (None, {"waveform": torch.cat(outputs, dim = 0).unsqueeze(0).unsqueeze(1), "sample_rate": self.audio_tokenizer.sample_rate}) # audio only
 
     def load_state_dict(self, sd, strict = False):
         return self.audio_tokenizer.load_state_dict(sd, strict = strict)
diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py
index 8651b6a57..63e853a24 100644
--- a/comfy_extras/nodes_audio.py
+++ b/comfy_extras/nodes_audio.py
@@ -22,7 +22,6 @@ from comfy.ldm.higgsv2.preprocess import (
     prepare_chatml_sample, Message, ChatMLSample, ChatMLDatasetSample, AudioContent, transcript_normalize
 )
 
-AUDIO_PLACEHOLDER_TOKEN = "<|__AUDIO_PLACEHOLDER__|>"
 MULTISPEAKER_DEFAULT_SYSTEM_MESSAGE = """You are an AI assistant designed to convert text into speech.
 If the user's message includes a [SPEAKER*] tag, do not read out the tag and generate speech for the following text, using the specified voice.
@@ -175,6 +174,8 @@ class CreateChatMLSample:
         current_role = None
         collecting_system = False
         system_buffer = []
+        collecting_instruction = False
+        instruction_buffer = []
 
         for line in lines:
             line = line.strip()
@@ -197,6 +198,26 @@ class CreateChatMLSample:
                     collecting_system = False
                 continue
 
+            # generation instruction start
+            if "<|generation_instruction_start|>" in line:
+                collecting_instruction = True
+                instruction_buffer = []
+                continue
+
+            if collecting_instruction:
+                if "<|generation_instruction_end|>" in line:
+                    instruction_text = "\n".join(instruction_buffer)
+                    # include both start and end tokens
+                    messages.append(Message(
+                        role="user",
+                        content=f"<|generation_instruction_start|>\n{instruction_text}\n<|generation_instruction_end|>"
+                    ))
+                    instruction_buffer = []
+                    collecting_instruction = False
+                else:
+                    instruction_buffer.append(line)
+                continue
+
             # speaker lines SPEAKER-0: text
             match = re.match(r"SPEAKER-(\d+):\s*(.*)", line, re.IGNORECASE)
             if match:
@@ -221,14 +242,29 @@ class CreateChatMLSample:
             lines = all_text.splitlines()
             messages = [messages[0]] if messages[0].role == "system" else []
             current_role = None
+
             for line in lines:
-                match = re.match(r'\[SPEAKER\d+\]', line)
+                line = line.strip()
+                if not line:
+                    continue
+
+                match = re.match(r'(\[SPEAKER\d+\])\s*(.*)', line)
                 if match:
-                    current_role = match.group(0)
-                    messages.append(Message(role="user", content=line.strip()))
+                    current_role = match.group(1)
+                    content = match.group(2).strip()  # only take the text after the tag
+                    messages.append(Message(role="user", content=f"{current_role} {content}" if content else current_role))
                 else:
                     if current_role and messages:
-                        messages[-1].content += "\n" + line.strip()
+                        messages[-1].content += "\n" + line
+
+            # deduplicate a doubled <|eot_id|> marker, keeping a single one
+            for idx, m in enumerate(messages):
+                double_eot = "<|eot_id|><|eot_id|>"
+                if double_eot in m.content:
+                    cut_index = m.content.index(double_eot)
+                    messages[idx].content = m.content[:cut_index + (len(double_eot) // 2)]
+                    break
+
         if audio is not None:
             # for audio cloning, the first message is a transcript, second is the audio,
             # third is the request of what the model should say
diff --git a/nodes.py b/nodes.py
index 024c31ac6..35c83500e 100644
--- a/nodes.py
+++ b/nodes.py
@@ -2065,7 +2065,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "KSampler": "KSampler",
     "KSamplerAdvanced": "KSampler (Advanced)",
     "AutoRegressiveGeneration": "Autoregressive Generation",
-    ""
     # Loaders
     "CheckpointLoader": "Load Checkpoint With Config (DEPRECATED)",
     "CheckpointLoaderSimple": "Load Checkpoint",