import os

import torch
import torch.nn as nn

from transformers import AutoTokenizer

import comfy.ops
from comfy.ldm.higgsv2.tokenizer import HiggsAudioTokenizer
from comfy.ldm.higgsv2.preprocess import HiggsAudioSampleCollator


class DummyTokenizer:
    # Placeholder tokenizer: matches the constructor signature expected by the
    # loading code but holds no state.
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        pass


def revert_delay_pattern_vectorized(data: torch.Tensor) -> torch.Tensor:
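    """Undo the codebook delay pattern in a single gather.

    In the delayed layout, codebook k is shifted right by k steps, so row k
    holds its valid tokens at columns k .. k + seq_len - 1; this builds the
    full (num_codebooks, seq_len) index grid and gathers every row at once.

    Example (num_codebooks=2, total_len=4, so seq_len=3):
        [[a0, a1, a2, __],         [[a0, a1, a2],
         [__, b0, b1, b2]]   ->     [b0, b1, b2]]
    """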
    num_codebooks, total_len = data.shape
    seq_len = total_len - num_codebooks + 1

    # out[k, j] = data[k, j + k]: each row of the column-index grid is offset
    # by its codebook index.
    col_idx = torch.arange(seq_len, device=data.device)[None, :] \
        + torch.arange(num_codebooks, device=data.device)[:, None]
    out = data[torch.arange(num_codebooks, device=data.device)[:, None], col_idx]
    return out


class HiggsTokenizer(nn.Module):
    def __init__(self, device, dtype, model_options={}, **kwargs):
        super().__init__()

        # The audio tokenizer is always run in float32 (see decode_tokens).
        self.dtype = torch.float32
        self.device = device
        self.dtypes = [torch.float32]

        here = os.path.dirname(__file__)
        tokenizer_path = os.path.join(here, "higgs_text_tokenizer")

        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        scaled_fp8 = model_options.get("scaled_fp8", None)
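
        # Scaled-fp8 checkpoints carry a dtype marker in model_options; plain
        # checkpoints fall back to ComfyUI's on-the-fly casting ops.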
        if scaled_fp8 is not None:
            operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
        else:
            operations = comfy.ops.manual_cast

        self.audio_codebook_size = 1024
        self.audio_tokenizer = HiggsAudioTokenizer(device=device, dtype=dtype, operations=operations)

        if scaled_fp8 is not None:
            # Empty marker parameter used by the scaled-fp8 weight loading path.
            self.audio_tokenizer.scaled_fp8 = torch.nn.Parameter(torch.tensor([], dtype=scaled_fp8))

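        # The special token ids below sit in the Llama-3 reserved range used by
        # Higgs Audio v2; audio codes use an 8-codebook delay pattern, which
        # revert_delay_pattern_vectorized undoes on decode.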
        self.collator = HiggsAudioSampleCollator(
            audio_in_token_id=128015,
            audio_out_token_id=128016,
            audio_stream_bos_id=1024,
            audio_stream_eos_id=1025,
            pad_token_id=128001,
            return_audio_in_tokens=False,
            use_delay_pattern=True,
            audio_num_codebooks=8,
            round_to=1,
        )

        postfix = "<|start_header_id|>assistant<|end_header_id|>\n\n"
        self.postfix = postfix + "<|audio_out_bos|>"  # force audio generation

    def decode_tokens(self, audio_tokens):
        outputs = []

        # due to instability issues, I had to convert the audio tokenizer to
        # float32 to avoid outputting NaNs
        self.audio_tokenizer = self.audio_tokenizer.to(self.dtype)
        torch.cuda.synchronize()

        for audio in audio_tokens:
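            # The stream bos/eos ids (1024/1025) lie just past the 1024-entry
            # codebook: clip them into range, then drop the bos/eos columns
            # themselves via [:, 1:-1].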
            vq_code = revert_delay_pattern_vectorized(audio).clip(0, self.audio_codebook_size - 1)[:, 1:-1]
            wv_numpy = self.audio_tokenizer.decode(vq_code.unsqueeze(0))[0, 0]
            outputs.append(wv_numpy)

        # currently only supports a batch size of one
        return (None, {"waveform": torch.cat(outputs, dim=0).unsqueeze(0).unsqueeze(1), "sample_rate": self.audio_tokenizer.sample_rate})  # audio only

    def load_state_dict(self, sd, strict=False):
        return self.audio_tokenizer.load_state_dict(sd, strict=strict)

    def state_dict(self):
        return self.audio_tokenizer.state_dict()
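

# Minimal usage sketch (illustrative only; the state dict and token tensor
# below are hypothetical stand-ins):
#
#   tok = HiggsTokenizer(device="cuda", dtype=torch.float32)
#   tok.load_state_dict(audio_tokenizer_sd)  # audio tokenizer weights
#   tokens = [torch.randint(0, 1026, (8, 107), device="cuda")]  # (codebooks, seq + delay)
#   _, audio = tok.decode_tokens(tokens)
#   audio["waveform"]     # shape (1, 1, num_samples)
#   audio["sample_rate"]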