ComfyUI/comfy/text_encoders/higgsv2.py
2025-09-09 01:07:36 +03:00

82 lines
3.1 KiB
Python

import os
import torch
import comfy.ops
import torch.nn as nn
from transformers import AutoTokenizer
from comfy.ldm.higgsv2.tokenizer import HiggsAudioTokenizer
from comfy.ldm.higgsv2.preprocess import HiggsAudioSampleCollator
class DummyTokenizer:
    """Placeholder tokenizer that accepts the usual constructor arguments and ignores them."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        """Accept and discard ``embedding_directory`` and ``tokenizer_data``; no state is stored."""
def revert_delay_pattern_vectorized(data: torch.Tensor) -> torch.Tensor:
    """Undo the per-codebook delay pattern by shifting row ``i`` left by ``i`` columns.

    The result satisfies ``out[i, j] == data[i, i + j]``.

    Args:
        data: 2-D tensor of shape ``(num_codebooks, total_len)``.

    Returns:
        Tensor of shape ``(num_codebooks, total_len - num_codebooks + 1)``.
    """
    num_codebooks, total_len = data.shape
    seq_len = total_len - num_codebooks + 1
    # Build both index tensors on the same device as ``data`` so the advanced
    # indexing below never mixes CPU and accelerator indices.
    row_idx = torch.arange(num_codebooks, device=data.device)[:, None]
    col_idx = torch.arange(seq_len, device=data.device)[None, :] + row_idx
    return data[row_idx, col_idx]
class HiggsTokenizer(nn.Module):
    """Bundle the Higgs text tokenizer, the audio codec and the sample collator.

    Weight loading and saving (``load_state_dict`` / ``state_dict``) are
    delegated to the wrapped ``audio_tokenizer``, so checkpoint keys are those
    of the audio tokenizer itself, without an ``audio_tokenizer.`` prefix.
    """

    def __init__(self, device, dtype, model_options={}, **kwargs):
        """Build the text tokenizer, audio tokenizer and collator.

        Args:
            device: Device handed to ``HiggsAudioTokenizer``.
            dtype: Dtype handed to ``HiggsAudioTokenizer`` at construction.
            model_options: Options mapping; only ``"scaled_fp8"`` is read here
                (the mapping is never mutated).
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        # The module advertises float32 regardless of ``dtype``; decode_tokens
        # casts the audio tokenizer to this dtype before decoding.
        self.dtype = torch.float32
        self.device = device
        self.dtypes = [torch.float32]

        # The text tokenizer files live in a directory next to this module.
        here = os.path.dirname(__file__)
        tokenizer_path = os.path.join(here, "higgs_text_tokenizer")
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

        scaled_fp8 = model_options.get("scaled_fp8", None)
        if scaled_fp8 is not None:
            operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
        else:
            operations = comfy.ops.manual_cast

        self.audio_codebook_size = 1024
        self.audio_tokenizer = HiggsAudioTokenizer(device=device, dtype=dtype, operations=operations)
        if scaled_fp8 is not None:
            # Empty parameter whose dtype records the scaled-fp8 dtype in use.
            self.audio_tokenizer.scaled_fp8 = torch.nn.Parameter(torch.tensor([], dtype=scaled_fp8))

        self.collator = HiggsAudioSampleCollator(
            audio_in_token_id=128015,
            audio_out_token_id=128016,
            audio_stream_bos_id=1024,
            audio_stream_eos_id=1025,
            pad_token_id=128001,
            return_audio_in_tokens=False,
            use_delay_pattern=True,
            audio_num_codebooks=8,
            round_to=1,
        )

        postfix = "<|start_header_id|>assistant<|end_header_id|>\n\n"
        self.postfix = postfix + "<|audio_out_bos|>"  # force audio generation

    def decode_tokens(self, audio_tokens):
        """Decode delayed audio code sequences into a single waveform.

        Args:
            audio_tokens: Iterable of 2-D code tensors, each in the delayed
                layout accepted by ``revert_delay_pattern_vectorized``.

        Returns:
            A tuple ``(None, audio)`` where ``audio`` is a dict with keys
            ``"waveform"`` and ``"sample_rate"``.
        """
        # Due to instability issues, the audio tokenizer is converted to
        # float32 to avoid outputting NaNs.
        self.audio_tokenizer = self.audio_tokenizer.to(self.dtype)
        # Only synchronize when CUDA exists; an unconditional call raises on
        # CPU-only installations.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        outputs = []
        for audio in audio_tokens:
            # Undo the delay pattern, clamp ids into the valid codebook range,
            # then drop the first and last frame (presumably the stream
            # BOS/EOS frames -- confirm against the generator).
            vq_code = revert_delay_pattern_vectorized(audio).clip(0, self.audio_codebook_size - 1)[:, 1:-1]
            waveform = self.audio_tokenizer.decode(vq_code.unsqueeze(0))[0, 0]
            outputs.append(waveform)

        # currently only supports one batch size
        return (None, {"waveform": torch.cat(outputs, dim=0).unsqueeze(0).unsqueeze(1), "sample_rate": self.audio_tokenizer.sample_rate})  # audio only

    def load_state_dict(self, sd, strict=False):
        """Load ``sd`` into the wrapped audio tokenizer and return its result."""
        return self.audio_tokenizer.load_state_dict(sd, strict=strict)

    def state_dict(self, *args, **kwargs):
        """Return the wrapped audio tokenizer's state dict.

        Positional and keyword arguments (``destination``, ``prefix``,
        ``keep_vars``) are forwarded unchanged, so the standard
        ``nn.Module.state_dict`` calling convention keeps working.
        """
        return self.audio_tokenizer.state_dict(*args, **kwargs)