mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-06 19:42:34 +08:00
Fix crash with ace step 1.5 (#12264)
This commit is contained in:
parent
855849c658
commit
a31681564d
@ -19,6 +19,7 @@ def sample_manual_loop_no_classes(
|
|||||||
min_tokens: int = 1,
|
min_tokens: int = 1,
|
||||||
max_new_tokens: int = 2048,
|
max_new_tokens: int = 2048,
|
||||||
audio_start_id: int = 151669, # The cutoff ID for audio codes
|
audio_start_id: int = 151669, # The cutoff ID for audio codes
|
||||||
|
audio_end_id: int = 215669,
|
||||||
eos_token_id: int = 151645,
|
eos_token_id: int = 151645,
|
||||||
):
|
):
|
||||||
device = model.execution_device
|
device = model.execution_device
|
||||||
@ -60,6 +61,7 @@ def sample_manual_loop_no_classes(
|
|||||||
remove_logit_value = torch.finfo(cfg_logits.dtype).min
|
remove_logit_value = torch.finfo(cfg_logits.dtype).min
|
||||||
# Only generate audio tokens
|
# Only generate audio tokens
|
||||||
cfg_logits[:, :audio_start_id] = remove_logit_value
|
cfg_logits[:, :audio_start_id] = remove_logit_value
|
||||||
|
cfg_logits[:, audio_end_id:] = remove_logit_value
|
||||||
|
|
||||||
if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step:
|
if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step:
|
||||||
cfg_logits[:, eos_token_id] = eos_score
|
cfg_logits[:, eos_token_id] = eos_score
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user