mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-16 01:00:49 +08:00
additional styling
This commit is contained in:
parent
57c15f970c
commit
f8d48917ba
@ -366,13 +366,13 @@ def prepare_chatml_sample(sample: Union[ChatMLSample, Dict], tokenizer):
|
||||
audio_contents.append(content)
|
||||
if role == "user" or role == "system":
|
||||
text_tokens = tokenizer.encode(
|
||||
f"<|audio_bos|><|AUDIO|><|audio_eos|>",
|
||||
"<|audio_bos|><|AUDIO|><|audio_eos|>",
|
||||
add_special_tokens=False,
|
||||
)
|
||||
input_tokens.extend(text_tokens)
|
||||
elif role == "assistant":
|
||||
text_tokens = tokenizer.encode(
|
||||
f"<|audio_out_bos|><|AUDIO_OUT|><|audio_eos|>",
|
||||
"<|audio_out_bos|><|AUDIO_OUT|><|audio_eos|>",
|
||||
add_special_tokens=False,
|
||||
)
|
||||
input_tokens.extend(text_tokens)
|
||||
@ -587,7 +587,7 @@ class HiggsAudioSampleCollator:
|
||||
# I tried to remove the for-loop in original implementation
|
||||
# but to do batching with padding caused problem so I turned it into a list compre.
|
||||
lengths = [seg.shape[1] for seg in audio_in_ids_l]
|
||||
aug_lengths = [l + 2 for l in lengths]
|
||||
aug_lengths = [length + 2 for length in lengths]
|
||||
audio_in_ids_start = torch.cumsum(
|
||||
torch.tensor([0] + aug_lengths[:-1], dtype=torch.long), dim=0
|
||||
)
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
import os
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
Loading…
Reference in New Issue
Block a user