From a246cc02b274104d5f656b68ce505354c164aef8 Mon Sep 17 00:00:00 2001 From: blepping <157360029+blepping@users.noreply.github.com> Date: Wed, 4 Feb 2026 22:17:37 -0700 Subject: [PATCH] Improvements to ACE-Steps 1.5 text encoding (#12283) --- comfy/text_encoders/ace15.py | 56 +++++++++++++++++++++++++++++------- 1 file changed, 45 insertions(+), 11 deletions(-) diff --git a/comfy/text_encoders/ace15.py b/comfy/text_encoders/ace15.py index 74e62733e..00dd5ba90 100644 --- a/comfy/text_encoders/ace15.py +++ b/comfy/text_encoders/ace15.py @@ -3,6 +3,7 @@ import comfy.text_encoders.llama from comfy import sd1_clip import torch import math +import yaml import comfy.utils @@ -125,14 +126,43 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_06b", tokenizer=Qwen3Tokenizer) + def _metas_to_cot(self, *, return_yaml: bool = False, **kwargs) -> str: + user_metas = { + k: kwargs.pop(k) + for k in ("bpm", "duration", "keyscale", "timesignature", "language", "caption") + if k in kwargs + } + timesignature = user_metas.get("timesignature") + if isinstance(timesignature, str) and timesignature.endswith("/4"): + user_metas["timesignature"] = timesignature.rsplit("/", 1)[0] + user_metas = { + k: v if not isinstance(v, str) or not v.isdigit() else int(v) + for k, v in user_metas.items() + if v not in {"unspecified", None} + } + if len(user_metas): + meta_yaml = yaml.dump(user_metas, allow_unicode=True, sort_keys=True).strip() + else: + meta_yaml = "" + return f"\n{meta_yaml}\n" if not return_yaml else meta_yaml + + def _metas_to_cap(self, **kwargs) -> str: + use_keys = ("bpm", "duration", "keyscale", "timesignature") + user_metas = { k: kwargs.pop(k, "N/A") for k in use_keys } + duration = user_metas["duration"] + if duration == "N/A": + user_metas["duration"] = "30 seconds" + elif isinstance(duration, (str, int, float)): + user_metas["duration"] = f"{math.ceil(float(duration))} seconds" + else: + raise TypeError("Unexpected type for duration key, must be str, int or float") + return "\n".join(f"- {k}: {user_metas[k]}" for k in use_keys) + def tokenize_with_weights(self, text, return_word_ids=False, **kwargs): out = {} lyrics = kwargs.get("lyrics", "") - bpm = kwargs.get("bpm", 120) duration = kwargs.get("duration", 120) - keyscale = kwargs.get("keyscale", "C major") - timesignature = kwargs.get("timesignature", 2) - language = kwargs.get("language", "en") + language = kwargs.get("language") seed = kwargs.get("seed", 0) generate_audio_codes = kwargs.get("generate_audio_codes", True) @@ -141,16 +171,20 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer): top_p = kwargs.get("top_p", 0.9) top_k = kwargs.get("top_k", 0.0) + duration = math.ceil(duration) - meta_lm = 'bpm: {}\nduration: {}\nkeyscale: {}\ntimesignature: {}'.format(bpm, duration, keyscale, timesignature) - lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n{}\n<|im_end|>\n<|im_start|>assistant\n\n{}\n\n\n<|im_end|>\n" + kwargs["duration"] = duration - meta_cap = '- bpm: {}\n- timesignature: {}\n- keyscale: {}\n- duration: {}\n'.format(bpm, timesignature, keyscale, duration) - out["lm_prompt"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, meta_lm), disable_weights=True) - out["lm_prompt_negative"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, ""), disable_weights=True) + cot_text = self._metas_to_cot(caption = text, **kwargs) + meta_cap = self._metas_to_cap(**kwargs) - out["lyrics"] = self.qwen3_06b.tokenize_with_weights("# Languages\n{}\n\n# Lyric{}<|endoftext|><|endoftext|>".format(language, lyrics), return_word_ids, disable_weights=True, **kwargs) - out["qwen3_06b"] = self.qwen3_06b.tokenize_with_weights("# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}# Metas\n{}<|endoftext|>\n<|endoftext|>".format(text, meta_cap), return_word_ids, **kwargs) + lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n# Lyric\n{}\n<|im_end|>\n<|im_start|>assistant\n{}\n<|im_end|>\n" + + out["lm_prompt"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, cot_text), disable_weights=True) + out["lm_prompt_negative"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, "\n"), disable_weights=True) + + out["lyrics"] = self.qwen3_06b.tokenize_with_weights("# Languages\n{}\n\n# Lyric\n{}<|endoftext|><|endoftext|>".format(language if language is not None else "", lyrics), return_word_ids, disable_weights=True, **kwargs) + out["qwen3_06b"] = self.qwen3_06b.tokenize_with_weights("# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}\n# Metas\n{}\n<|endoftext|>\n<|endoftext|>".format(text, meta_cap), return_word_ids, **kwargs) out["lm_metadata"] = {"min_tokens": duration * 5, "seed": seed, "generate_audio_codes": generate_audio_codes,