mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-16 01:00:49 +08:00
final
This commit is contained in:
parent
64124224e7
commit
2ac8999287
@ -255,6 +255,8 @@ class AutoRegressiveGeneration:
|
||||
self.dtype = model.dtype
|
||||
|
||||
self.model = model
|
||||
self.model.generation_config = GenerationConfig.from_model_config(self.model.config)
|
||||
self.model.generation_config.cache_implementation = self.model.cache_implementation
|
||||
|
||||
text_config = self.model.cache_config
|
||||
self.cache_config = CacheConfig(
|
||||
@ -331,8 +333,6 @@ class AutoRegressiveGeneration:
|
||||
do_sample = do_sample,
|
||||
temperature = temperature)
|
||||
|
||||
generation_config, kwargs = self._prepare_generation_config(generation_config, **kwargs)
|
||||
|
||||
generation_config, model_kwargs = self._prepare_generation_config(
|
||||
generation_config, **kwargs
|
||||
)
|
||||
|
||||
@ -428,14 +428,13 @@ class HiggsAudioModel(nn.Module):
|
||||
self.cache_config = kwargs["text_config"]
|
||||
self.hidden_dim = kwargs["text_config"]["hidden_size"]
|
||||
self.max_seq_len = kwargs["text_config"]["max_position_embeddings"]
|
||||
self.cache_implementation = "static"
|
||||
self.use_kv_buckets = kwargs.get("use_kv_buckets", False)
|
||||
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
self.config = kwargs
|
||||
|
||||
self.generation_config = GenerationConfig.from_model_config(kwargs)
|
||||
self.generation_config.cache_implementation = self.cache_implementation = "static"
|
||||
|
||||
self.audio_out_bos_token_id = 128013
|
||||
self.audio_eos_token_id = 128012
|
||||
|
||||
@ -174,8 +174,6 @@ class CreateChatMLSample:
|
||||
current_role = None
|
||||
collecting_system = False
|
||||
system_buffer = []
|
||||
collecting_instruction = False
|
||||
instruction_buffer = []
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
@ -198,26 +196,6 @@ class CreateChatMLSample:
|
||||
collecting_system = False
|
||||
continue
|
||||
|
||||
# generation instruction start
|
||||
if "<|generation_instruction_start|>" in line:
|
||||
collecting_instruction = True
|
||||
instruction_buffer = []
|
||||
continue
|
||||
|
||||
if collecting_instruction:
|
||||
if "<|generation_instruction_end|>" in line:
|
||||
instruction_text = "\n".join(instruction_buffer)
|
||||
# include both start and end tokens
|
||||
messages.append(Message(
|
||||
role="user",
|
||||
content=f"<|generation_instruction_start|>\n{instruction_text}\n<|generation_instruction_end|>"
|
||||
))
|
||||
instruction_buffer = []
|
||||
collecting_instruction = False
|
||||
else:
|
||||
instruction_buffer.append(line)
|
||||
continue
|
||||
|
||||
# speaker lines SPEAKER-0: text
|
||||
match = re.match(r"SPEAKER-(\d+):\s*(.*)", line, re.IGNORECASE)
|
||||
if match:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user