This commit is contained in:
Subarasheese 2026-03-15 19:29:57 +01:00 committed by GitHub
commit 46bce0e4b9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 32 additions and 1 deletions

View File

@ -1671,6 +1671,35 @@ class ACEStep15(BaseModel):
refer_audio = torch.cat([refer_audio.to(pad), pad[:, :, refer_audio.shape[2]:]], dim=2)
out['refer_audio'] = comfy.conds.CONDRegular(refer_audio)
audio_cover_strength = kwargs.get('audio_cover_strength', 1.0)
is_cover_mode = out.get('is_covers', comfy.conds.CONDConstant(None)).cond != False
if audio_cover_strength < 1.0 and is_cover_mode and self.current_patcher is not None:
if not self.current_patcher.get_wrappers(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, 'ace_step_cover_strength'):
_strength = audio_cover_strength
def audio_cover_strength_wrapper(executor, x, timestep, model_options={}, seed=None):
sample_sigmas = model_options.get('transformer_options', {}).get('sample_sigmas', None)
if sample_sigmas is not None:
current_sigma = float(timestep.max())
max_sigma = float(sample_sigmas[0])
min_sigma = float(sample_sigmas[-1])
sigma_range = max_sigma - min_sigma
if sigma_range > 0:
progress = 1.0 - (current_sigma - min_sigma) / sigma_range
if progress >= _strength:
conds = model_options.get('conds', None)
if conds is not None:
for cond_list in conds.values():
for cond in cond_list:
if 'model_conds' in cond and 'is_covers' in cond['model_conds']:
cond['model_conds']['is_covers'] = comfy.conds.CONDConstant(False)
return executor(x, timestep, model_options, seed)
self.current_patcher.add_wrapper_with_key(
comfy.patcher_extension.WrappersMP.PREDICT_NOISE,
'ace_step_cover_strength',
audio_cover_strength_wrapper
)
return out
class Omnigen2(BaseModel):

View File

@ -50,14 +50,16 @@ class TextEncodeAceStepAudio15(io.ComfyNode):
io.Float.Input("top_p", default=0.9, min=0.0, max=2000.0, step=0.01, advanced=True),
io.Int.Input("top_k", default=0, min=0, max=100, advanced=True),
io.Float.Input("min_p", default=0.000, min=0.0, max=1.0, step=0.001, advanced=True),
io.Float.Input("audio_cover_strength", default=1.0, min=0.0, max=1.0, step=0.01, advanced=True, tooltip="Controls how many denoising steps use LM code conditioning. 1.0 = all steps, 0.5 = first half only."),
],
outputs=[io.Conditioning.Output()],
)
@classmethod
def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k, min_p) -> io.NodeOutput:
def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k, min_p, audio_cover_strength) -> io.NodeOutput:
tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p)
conditioning = clip.encode_from_tokens_scheduled(tokens)
conditioning = node_helpers.conditioning_set_values(conditioning, {"audio_cover_strength": audio_cover_strength})
return io.NodeOutput(conditioning)