Improvements to ACE-Steps 1.5 text encoding (part 2) (#12350)
commit baf8c87455
parent 62315fbb15
@@ -3,6 +3,7 @@ import comfy.text_encoders.llama
 from comfy import sd1_clip
 import torch
 import math
+from tqdm.auto import trange
 import yaml
 import comfy.utils
 
@@ -23,6 +24,8 @@ def sample_manual_loop_no_classes(
     audio_end_id: int = 215669,
     eos_token_id: int = 151645,
 ):
+    if ids is None:
+        return []
     device = model.execution_device
 
     if execution_dtype is None:
@@ -32,6 +35,7 @@ def sample_manual_loop_no_classes(
         execution_dtype = torch.float32
 
     embeds, attention_mask, num_tokens, embeds_info = model.process_tokens(ids, device)
+    embeds_batch = embeds.shape[0]
     for i, t in enumerate(paddings):
         attention_mask[i, :t] = 0
         attention_mask[i, t:] = 1
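
Note: the paddings loop above implements left padding. A minimal standalone sketch of the idea (row count, widths, and pad amounts here are illustrative, not the model's real shapes):

    import torch

    # Row 0 was left-padded by 2 tokens, row 1 by 0, as in `paddings`.
    paddings = [2, 0]
    attention_mask = torch.empty((2, 5), dtype=torch.long)
    for i, t in enumerate(paddings):
        attention_mask[i, :t] = 0  # mask out the pad positions on the left
        attention_mask[i, t:] = 1  # attend to the real tokens
    print(attention_mask)
    # tensor([[0, 0, 1, 1, 1],
    #         [1, 1, 1, 1, 1]])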
@@ -41,22 +45,27 @@ def sample_manual_loop_no_classes(
     generator = torch.Generator(device=device)
     generator.manual_seed(seed)
     model_config = model.transformer.model.config
+    past_kv_shape = [embeds_batch, model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim]
 
     for x in range(model_config.num_hidden_layers):
-        past_key_values.append((torch.empty([embeds.shape[0], model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim], device=device, dtype=execution_dtype), torch.empty([embeds.shape[0], model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim], device=device, dtype=execution_dtype), 0))
+        past_key_values.append((torch.empty(past_kv_shape, device=device, dtype=execution_dtype), torch.empty(past_kv_shape, device=device, dtype=execution_dtype), 0))
 
     progress_bar = comfy.utils.ProgressBar(max_new_tokens)
 
-    for step in range(max_new_tokens):
+    for step in trange(max_new_tokens, desc="LM sampling"):
         outputs = model.transformer(None, attention_mask, embeds=embeds.to(execution_dtype), num_tokens=num_tokens, intermediate_output=None, dtype=execution_dtype, embeds_info=embeds_info, past_key_values=past_key_values)
         next_token_logits = model.transformer.logits(outputs[0])[:, -1]
        past_key_values = outputs[2]
 
-        cond_logits = next_token_logits[0:1]
-        uncond_logits = next_token_logits[1:2]
-        cfg_logits = uncond_logits + cfg_scale * (cond_logits - uncond_logits)
+        if cfg_scale != 1.0:
+            cond_logits = next_token_logits[0:1]
+            uncond_logits = next_token_logits[1:2]
+            cfg_logits = uncond_logits + cfg_scale * (cond_logits - uncond_logits)
+        else:
+            cfg_logits = next_token_logits[0:1]
 
-        if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step:
+        use_eos_score = eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step
+        if use_eos_score:
             eos_score = cfg_logits[:, eos_token_id].clone()
 
         remove_logit_value = torch.finfo(cfg_logits.dtype).min
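
Note: the restructured loop now only evaluates a two-row batch (conditional plus unconditional) when guidance is actually on. A minimal sketch of the classifier-free guidance arithmetic on logits, with an illustrative vocabulary size:

    import torch

    cfg_scale = 2.0
    next_token_logits = torch.randn(2, 1000)  # row 0: conditional, row 1: unconditional

    if cfg_scale != 1.0:
        cond_logits = next_token_logits[0:1]
        uncond_logits = next_token_logits[1:2]
        # Extrapolate from the unconditional prediction toward the conditional one.
        cfg_logits = uncond_logits + cfg_scale * (cond_logits - uncond_logits)
    else:
        # cfg_scale == 1.0 reduces to the conditional logits, so the
        # unconditional row is never computed at all.
        cfg_logits = next_token_logits[0:1]

The same hunk hoists the preallocated KV-cache shape into past_kv_shape and sizes it with embeds_batch, so the cache matches whichever batch size (1 or 2) was actually built.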
@@ -64,7 +73,7 @@ def sample_manual_loop_no_classes(
         cfg_logits[:, :audio_start_id] = remove_logit_value
         cfg_logits[:, audio_end_id:] = remove_logit_value
 
-        if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step:
+        if use_eos_score:
             cfg_logits[:, eos_token_id] = eos_score
 
         if top_k is not None and top_k > 0:
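
Note: logits outside the audio-code band are floored to the dtype minimum so those tokens can never be sampled, and the EOS score is saved and restored so generation can still stop. A standalone sketch with made-up IDs (the min_tokens < step part of the condition is elided here):

    import torch

    audio_start_id, audio_end_id, eos_token_id = 10, 20, 3
    cfg_logits = torch.randn(1, 25)

    use_eos_score = eos_token_id is not None and eos_token_id < audio_start_id
    if use_eos_score:
        eos_score = cfg_logits[:, eos_token_id].clone()

    remove_logit_value = torch.finfo(cfg_logits.dtype).min
    cfg_logits[:, :audio_start_id] = remove_logit_value  # below the audio band
    cfg_logits[:, audio_end_id:] = remove_logit_value    # above the audio band

    if use_eos_score:
        cfg_logits[:, eos_token_id] = eos_score  # re-enable stopping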
@@ -93,8 +102,8 @@ def sample_manual_loop_no_classes(
             break
 
         embed, _, _, _ = model.process_tokens([[token]], device)
-        embeds = embed.repeat(2, 1, 1)
-        attention_mask = torch.cat([attention_mask, torch.ones((2, 1), device=device, dtype=attention_mask.dtype)], dim=1)
+        embeds = embed.repeat(embeds_batch, 1, 1)
+        attention_mask = torch.cat([attention_mask, torch.ones((embeds_batch, 1), device=device, dtype=attention_mask.dtype)], dim=1)
 
         output_audio_codes.append(token - audio_start_id)
         progress_bar.update_absolute(step)
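
Note: replacing the hardcoded 2 with embeds_batch is what lets the loop run with a single row when cfg_scale == 1.0. Roughly, the per-step feedback looks like this (the embedding width is an assumption):

    import torch

    embeds_batch = 1                   # 2 when CFG is active
    attention_mask = torch.ones(embeds_batch, 7, dtype=torch.long)
    embed = torch.randn(1, 1, 1024)    # embedding of the one sampled token

    # Broadcast the new token to every batch row, extend the mask by one column.
    embeds = embed.repeat(embeds_batch, 1, 1)
    attention_mask = torch.cat(
        [attention_mask, torch.ones((embeds_batch, 1), dtype=attention_mask.dtype)],
        dim=1,
    )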
@@ -104,22 +113,29 @@ def sample_manual_loop_no_classes(
 
 def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0, cfg_scale=2.0, temperature=0.85, top_p=0.9, top_k=0):
     positive = [[token for token, _ in inner_list] for inner_list in positive]
-    negative = [[token for token, _ in inner_list] for inner_list in negative]
     positive = positive[0]
-    negative = negative[0]
 
-    neg_pad = 0
-    if len(negative) < len(positive):
-        neg_pad = (len(positive) - len(negative))
-        negative = [model.special_tokens["pad"]] * neg_pad + negative
+    if cfg_scale != 1.0:
+        negative = [[token for token, _ in inner_list] for inner_list in negative]
+        negative = negative[0]
 
-    pos_pad = 0
-    if len(negative) > len(positive):
-        pos_pad = (len(negative) - len(positive))
-        positive = [model.special_tokens["pad"]] * pos_pad + positive
+        neg_pad = 0
+        if len(negative) < len(positive):
+            neg_pad = (len(positive) - len(negative))
+            negative = [model.special_tokens["pad"]] * neg_pad + negative
 
-    paddings = [pos_pad, neg_pad]
-    return sample_manual_loop_no_classes(model, [positive, negative], paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
+        pos_pad = 0
+        if len(negative) > len(positive):
+            pos_pad = (len(negative) - len(positive))
+            positive = [model.special_tokens["pad"]] * pos_pad + positive
+
+        paddings = [pos_pad, neg_pad]
+        ids = [positive, negative]
+    else:
+        paddings = []
+        ids = [positive]
+
+    return sample_manual_loop_no_classes(model, ids, paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
 
 
 class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
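
Note: when CFG is on, the shorter of the two prompts is left-padded so both rows share one sequence length, and the pad counts feed the paddings masking shown earlier. An equivalent standalone sketch (the pad token value is illustrative; the diff uses two if statements, max(0, ...) gives the same result):

    pad = 0
    positive = [5, 6, 7, 8]
    negative = [9, 11]

    neg_pad = max(0, len(positive) - len(negative))
    negative = [pad] * neg_pad + negative
    pos_pad = max(0, len(negative) - len(positive))
    positive = [pad] * pos_pad + positive

    paddings = [pos_pad, neg_pad]
    assert len(positive) == len(negative)
    print(paddings, negative)  # [0, 2] [0, 0, 9, 11]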
@@ -129,12 +145,12 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
     def _metas_to_cot(self, *, return_yaml: bool = False, **kwargs) -> str:
         user_metas = {
             k: kwargs.pop(k)
-            for k in ("bpm", "duration", "keyscale", "timesignature", "language", "caption")
+            for k in ("bpm", "duration", "keyscale", "timesignature", "language")
             if k in kwargs
         }
         timesignature = user_metas.get("timesignature")
         if isinstance(timesignature, str) and timesignature.endswith("/4"):
-            user_metas["timesignature"] = timesignature.rsplit("/", 1)[0]
+            user_metas["timesignature"] = timesignature[:-2]
         user_metas = {
             k: v if not isinstance(v, str) or not v.isdigit() else int(v)
             for k, v in user_metas.items()
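
Note: for a value that ends in "/4", slicing off the last two characters matches what rsplit did here. For example:

    timesignature = "3/4"
    if isinstance(timesignature, str) and timesignature.endswith("/4"):
        timesignature = timesignature[:-2]  # "3/4" -> "3", same as rsplit("/", 1)[0]
    print(timesignature)  # 3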
@@ -147,8 +163,11 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
         return f"<think>\n{meta_yaml}\n</think>" if not return_yaml else meta_yaml
 
     def _metas_to_cap(self, **kwargs) -> str:
-        use_keys = ("bpm", "duration", "keyscale", "timesignature")
+        use_keys = ("bpm", "timesignature", "keyscale", "duration")
         user_metas = { k: kwargs.pop(k, "N/A") for k in use_keys }
+        timesignature = user_metas.get("timesignature")
+        if isinstance(timesignature, str) and timesignature.endswith("/4"):
+            user_metas["timesignature"] = timesignature[:-2]
         duration = user_metas["duration"]
         if duration == "N/A":
             user_metas["duration"] = "30 seconds"
@@ -159,9 +178,13 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
         return "\n".join(f"- {k}: {user_metas[k]}" for k in use_keys)
 
     def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
-        out = {}
+        text = text.strip()
+        text_negative = kwargs.get("caption_negative", text).strip()
         lyrics = kwargs.get("lyrics", "")
+        lyrics_negative = kwargs.get("lyrics_negative", lyrics)
         duration = kwargs.get("duration", 120)
+        if isinstance(duration, str):
+            duration = float(duration.split(None, 1)[0])
         language = kwargs.get("language")
         seed = kwargs.get("seed", 0)
 
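
Note: duration may now arrive as a string such as "120 seconds"; only the first whitespace-separated field is parsed:

    duration = "120 seconds"
    if isinstance(duration, str):
        duration = float(duration.split(None, 1)[0])  # -> 120.0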
@@ -171,21 +194,46 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
         top_p = kwargs.get("top_p", 0.9)
         top_k = kwargs.get("top_k", 0.0)
 
-
         duration = math.ceil(duration)
         kwargs["duration"] = duration
+        tokens_duration = duration * 5
+        min_tokens = int(kwargs.get("min_tokens", tokens_duration))
+        max_tokens = int(kwargs.get("max_tokens", tokens_duration))
 
+        metas_negative = {
+            k.rsplit("_", 1)[0]: kwargs.pop(k)
+            for k in ("bpm_negative", "duration_negative", "keyscale_negative", "timesignature_negative", "language_negative", "caption_negative")
+            if k in kwargs
+        }
+        if not kwargs.get("use_negative_caption"):
+            _ = metas_negative.pop("caption", None)
+
         cot_text = self._metas_to_cot(caption = text, **kwargs)
+        cot_text_negative = "<think>\n</think>" if not metas_negative else self._metas_to_cot(**metas_negative)
         meta_cap = self._metas_to_cap(**kwargs)
 
-        lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n# Lyric\n{}\n<|im_end|>\n<|im_start|>assistant\n{}\n<|im_end|>\n"
+        lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n\n# Lyric\n{}\n<|im_end|>\n<|im_start|>assistant\n{}\n\n<|im_end|>\n"
+        lyrics_template = "# Languages\n{}\n\n# Lyric\n{}<|endoftext|><|endoftext|>"
+        qwen3_06b_template = "# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}\n\n# Metas\n{}\n<|endoftext|>\n<|endoftext|>"
 
-        out["lm_prompt"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, cot_text), disable_weights=True)
-        out["lm_prompt_negative"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, "<think>\n</think>"), disable_weights=True)
+        llm_prompts = {
+            "lm_prompt": lm_template.format(text, lyrics.strip(), cot_text),
+            "lm_prompt_negative": lm_template.format(text_negative, lyrics_negative.strip(), cot_text_negative),
+            "lyrics": lyrics_template.format(language if language is not None else "", lyrics),
+            "qwen3_06b": qwen3_06b_template.format(text, meta_cap),
+        }
 
-        out["lyrics"] = self.qwen3_06b.tokenize_with_weights("# Languages\n{}\n\n# Lyric\n{}<|endoftext|><|endoftext|>".format(language if language is not None else "", lyrics), return_word_ids, disable_weights=True, **kwargs)
-        out["qwen3_06b"] = self.qwen3_06b.tokenize_with_weights("# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}\n# Metas\n{}\n<|endoftext|>\n<|endoftext|>".format(text, meta_cap), return_word_ids, **kwargs)
-        out["lm_metadata"] = {"min_tokens": duration * 5,
+        out = {
+            prompt_key: self.qwen3_06b.tokenize_with_weights(
+                prompt,
+                prompt_key == "qwen3_06b" and return_word_ids,
+                disable_weights = True,
+                **kwargs,
+            )
+            for prompt_key, prompt in llm_prompts.items()
+        }
+        out["lm_metadata"] = {"min_tokens": min_tokens,
+                              "max_tokens": max_tokens,
                               "seed": seed,
                               "generate_audio_codes": generate_audio_codes,
                               "cfg_scale": cfg_scale,
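
Note: the *_negative kwargs are folded back to their base meta names with rsplit, and the negative caption is only kept when explicitly requested. A standalone sketch:

    kwargs = {"bpm_negative": "90", "caption_negative": "calm piano", "bpm": "120"}

    metas_negative = {
        k.rsplit("_", 1)[0]: kwargs.pop(k)
        for k in ("bpm_negative", "duration_negative", "keyscale_negative",
                  "timesignature_negative", "language_negative", "caption_negative")
        if k in kwargs
    }
    print(metas_negative)  # {'bpm': '90', 'caption': 'calm piano'}

    if not kwargs.get("use_negative_caption"):
        metas_negative.pop("caption", None)  # negative caption is opt-in
    print(metas_negative)  # {'bpm': '90'}

The four prompt strings are then tokenized in a single dict comprehension; only the "qwen3_06b" entry forwards return_word_ids, via the expression prompt_key == "qwen3_06b" and return_word_ids.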
@@ -252,7 +300,7 @@ class ACE15TEModel(torch.nn.Module):
 
         lm_metadata = token_weight_pairs["lm_metadata"]
         if lm_metadata["generate_audio_codes"]:
-            audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"], cfg_scale=lm_metadata["cfg_scale"], temperature=lm_metadata["temperature"], top_p=lm_metadata["top_p"], top_k=lm_metadata["top_k"])
+            audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["max_tokens"], seed=lm_metadata["seed"], cfg_scale=lm_metadata["cfg_scale"], temperature=lm_metadata["temperature"], top_p=lm_metadata["top_p"], top_k=lm_metadata["top_k"])
             out["audio_codes"] = [audio_codes]
 
         return base_out, None, out
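
Note: this last hunk fixes the call to pass max_tokens instead of reusing min_tokens. Both default to the duration-derived budget of 5 audio tokens per second, per the tokenizer change above:

    import math

    duration = math.ceil(119.4)     # 120 seconds
    tokens_duration = duration * 5  # 600 tokens: 5 audio codes per second
    min_tokens = tokens_duration    # kwargs["min_tokens"] can override
    max_tokens = tokens_duration    # kwargs["max_tokens"] can override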