This commit is contained in:
Yousef Rafat 2025-09-09 23:04:03 +03:00
parent 64124224e7
commit 2ac8999287
3 changed files with 3 additions and 26 deletions

View File

@ -255,6 +255,8 @@ class AutoRegressiveGeneration:
self.dtype = model.dtype
self.model = model
self.model.generation_config = GenerationConfig.from_model_config(self.model.config)
self.model.generation_config.cache_implementation = self.model.cache_implementation
text_config = self.model.cache_config
self.cache_config = CacheConfig(
@ -331,8 +333,6 @@ class AutoRegressiveGeneration:
do_sample = do_sample,
temperature = temperature)
generation_config, kwargs = self._prepare_generation_config(generation_config, **kwargs)
generation_config, model_kwargs = self._prepare_generation_config(
generation_config, **kwargs
)

View File

@ -428,14 +428,13 @@ class HiggsAudioModel(nn.Module):
self.cache_config = kwargs["text_config"]
self.hidden_dim = kwargs["text_config"]["hidden_size"]
self.max_seq_len = kwargs["text_config"]["max_position_embeddings"]
self.cache_implementation = "static"
self.use_kv_buckets = kwargs.get("use_kv_buckets", False)
self.dtype = dtype
self.device = device
self.config = kwargs
self.generation_config = GenerationConfig.from_model_config(kwargs)
self.generation_config.cache_implementation = self.cache_implementation = "static"
self.audio_out_bos_token_id = 128013
self.audio_eos_token_id = 128012

View File

@ -174,8 +174,6 @@ class CreateChatMLSample:
current_role = None
collecting_system = False
system_buffer = []
collecting_instruction = False
instruction_buffer = []
for line in lines:
line = line.strip()
@ -198,26 +196,6 @@ class CreateChatMLSample:
collecting_system = False
continue
# generation instruction start
if "<|generation_instruction_start|>" in line:
collecting_instruction = True
instruction_buffer = []
continue
if collecting_instruction:
if "<|generation_instruction_end|>" in line:
instruction_text = "\n".join(instruction_buffer)
# include both start and end tokens
messages.append(Message(
role="user",
content=f"<|generation_instruction_start|>\n{instruction_text}\n<|generation_instruction_end|>"
))
instruction_buffer = []
collecting_instruction = False
else:
instruction_buffer.append(line)
continue
# speaker lines SPEAKER-0: text
match = re.match(r"SPEAKER-(\d+):\s*(.*)", line, re.IGNORECASE)
if match: