commit 2ac8999287
parent 64124224e7

    final
@@ -255,6 +255,8 @@ class AutoRegressiveGeneration:
         self.dtype = model.dtype

         self.model = model
+        self.model.generation_config = GenerationConfig.from_model_config(self.model.config)
+        self.model.generation_config.cache_implementation = self.model.cache_implementation

         text_config = self.model.cache_config
         self.cache_config = CacheConfig(
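For reference, the two added lines follow the stock Hugging Face transformers pattern of deriving a GenerationConfig from the model's own config and then selecting a KV-cache implementation on it. A minimal sketch of that pattern with a generic causal LM (the checkpoint name is illustrative and not part of this change):

    from transformers import AutoModelForCausalLM, GenerationConfig

    model = AutoModelForCausalLM.from_pretrained("gpt2")  # illustrative checkpoint

    # Derive generation defaults from the model's config ...
    model.generation_config = GenerationConfig.from_model_config(model.config)
    # ... then tell generate() which KV-cache implementation to allocate
    # ("static" pre-allocates a fixed-size cache, which plays well with torch.compile).
    model.generation_config.cache_implementation = "static"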
@@ -331,8 +333,6 @@ class AutoRegressiveGeneration:
             do_sample = do_sample,
             temperature = temperature)

-        generation_config, kwargs = self._prepare_generation_config(generation_config, **kwargs)
-
         generation_config, model_kwargs = self._prepare_generation_config(
             generation_config, **kwargs
         )
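The do_sample/temperature context lines above feed a GenerationConfig, and the surviving call unpacks into (generation_config, model_kwargs). A rough sketch of a helper with that shape, assuming (not verified against this repository) that it copies the config, folds call-time overrides into it, and returns the leftovers as model kwargs:

    import copy

    from transformers import GenerationConfig

    def _prepare_generation_config(generation_config, **kwargs):
        # Hypothetical body, for illustration only: GenerationConfig.update()
        # applies matching kwargs and returns the ones it did not consume.
        generation_config = copy.deepcopy(generation_config)
        model_kwargs = generation_config.update(**kwargs)
        return generation_config, model_kwargs

    generation_config, model_kwargs = _prepare_generation_config(
        GenerationConfig(do_sample=True, temperature=0.7),
        max_new_tokens=64,
    )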
@@ -428,14 +428,13 @@ class HiggsAudioModel(nn.Module):
         self.cache_config = kwargs["text_config"]
         self.hidden_dim = kwargs["text_config"]["hidden_size"]
         self.max_seq_len = kwargs["text_config"]["max_position_embeddings"]
+        self.cache_implementation = "static"
         self.use_kv_buckets = kwargs.get("use_kv_buckets", False)

         self.dtype = dtype
         self.device = device
         self.config = kwargs

-        self.generation_config = GenerationConfig.from_model_config(kwargs)
-        self.generation_config.cache_implementation = self.cache_implementation = "static"

         self.audio_out_bos_token_id = 128013
         self.audio_eos_token_id = 128012
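For reference, the constructor context above pulls its shape parameters out of a nested text_config dict passed through kwargs; a small sketch of that access pattern with placeholder values (the numbers below are illustrative, not HiggsAudio's real sizes):

    kwargs = {
        "text_config": {
            "hidden_size": 4096,               # placeholder
            "max_position_embeddings": 8192,   # placeholder
        },
        # "use_kv_buckets" may be absent; .get() supplies the default below
    }

    hidden_dim = kwargs["text_config"]["hidden_size"]
    max_seq_len = kwargs["text_config"]["max_position_embeddings"]
    use_kv_buckets = kwargs.get("use_kv_buckets", False)
    cache_implementation = "static"  # fixed choice, read by the generation wrapper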
@@ -174,8 +174,6 @@ class CreateChatMLSample:
         current_role = None
         collecting_system = False
         system_buffer = []
-        collecting_instruction = False
-        instruction_buffer = []

         for line in lines:
             line = line.strip()
@@ -198,26 +196,6 @@ class CreateChatMLSample:
                 collecting_system = False
                 continue

-            # generation instruction start
-            if "<|generation_instruction_start|>" in line:
-                collecting_instruction = True
-                instruction_buffer = []
-                continue
-
-            if collecting_instruction:
-                if "<|generation_instruction_end|>" in line:
-                    instruction_text = "\n".join(instruction_buffer)
-                    # include both start and end tokens
-                    messages.append(Message(
-                        role="user",
-                        content=f"<|generation_instruction_start|>\n{instruction_text}\n<|generation_instruction_end|>"
-                    ))
-                    instruction_buffer = []
-                    collecting_instruction = False
-                else:
-                    instruction_buffer.append(line)
-                continue
-
             # speaker lines SPEAKER-0: text
             match = re.match(r"SPEAKER-(\d+):\s*(.*)", line, re.IGNORECASE)
             if match:
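The remaining parse path keys on SPEAKER-N prefixes; a small sketch of what that regex captures from a transcript line (the sample line is made up):

    import re

    line = "SPEAKER-0: Hello, thanks for joining the call."
    match = re.match(r"SPEAKER-(\d+):\s*(.*)", line, re.IGNORECASE)
    if match:
        speaker_idx = int(match.group(1))   # -> 0
        speaker_text = match.group(2)       # -> "Hello, thanks for joining the call."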