mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-12 14:32:35 +08:00
Make ace step 1.5 base model work properly with default workflow. (#12337)
This commit is contained in:
parent
a1c101f861
commit
eba6c940fd
@ -1110,7 +1110,7 @@ class AceStepConditionGenerationModel(nn.Module):
|
|||||||
|
|
||||||
return encoder_hidden, encoder_mask, context_latents
|
return encoder_hidden, encoder_mask, context_latents
|
||||||
|
|
||||||
def forward(self, x, timestep, context, lyric_embed=None, refer_audio=None, audio_codes=None, is_covers=None, **kwargs):
|
def forward(self, x, timestep, context, lyric_embed=None, refer_audio=None, audio_codes=None, is_covers=None, replace_with_null_embeds=False, **kwargs):
|
||||||
text_attention_mask = None
|
text_attention_mask = None
|
||||||
lyric_attention_mask = None
|
lyric_attention_mask = None
|
||||||
refer_audio_order_mask = None
|
refer_audio_order_mask = None
|
||||||
@ -1140,6 +1140,9 @@ class AceStepConditionGenerationModel(nn.Module):
|
|||||||
src_latents, chunk_masks, is_covers, precomputed_lm_hints_25Hz=precomputed_lm_hints_25Hz, audio_codes=audio_codes
|
src_latents, chunk_masks, is_covers, precomputed_lm_hints_25Hz=precomputed_lm_hints_25Hz, audio_codes=audio_codes
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if replace_with_null_embeds:
|
||||||
|
enc_hidden[:] = self.null_condition_emb.to(enc_hidden)
|
||||||
|
|
||||||
out = self.decoder(hidden_states=x,
|
out = self.decoder(hidden_states=x,
|
||||||
timestep=timestep,
|
timestep=timestep,
|
||||||
timestep_r=timestep,
|
timestep_r=timestep,
|
||||||
|
|||||||
@ -1552,6 +1552,8 @@ class ACEStep15(BaseModel):
|
|||||||
|
|
||||||
cross_attn = kwargs.get("cross_attn", None)
|
cross_attn = kwargs.get("cross_attn", None)
|
||||||
if cross_attn is not None:
|
if cross_attn is not None:
|
||||||
|
if torch.count_nonzero(cross_attn) == 0:
|
||||||
|
out['replace_with_null_embeds'] = comfy.conds.CONDConstant(True)
|
||||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||||
|
|
||||||
conditioning_lyrics = kwargs.get("conditioning_lyrics", None)
|
conditioning_lyrics = kwargs.get("conditioning_lyrics", None)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user