This commit is contained in:
Yousef Rafat 2025-10-24 18:22:02 +03:00
parent faa866733d
commit 01b32521c5

View File

@ -14,7 +14,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from enum import Enum from enum import Enum
from dataclasses import dataclass from dataclasses import dataclass
from collections import defaultdict, OrderedDict from collections import OrderedDict
from typing import Optional, Tuple, Union, List from typing import Optional, Tuple, Union, List
class GenerationMode(Enum): class GenerationMode(Enum):
@ -1272,7 +1272,7 @@ class HiggsAudioModel(nn.Module):
for past_key_value in past_key_values: for past_key_value in past_key_values:
kv_cache_length = past_key_value.get_max_cache_shape() kv_cache_length = past_key_value.get_max_cache_shape()
for is_decoding_audio_token in [True, False]: for is_decoding_audio_token in [True, False]:
runner = CUDAGraphRunner(self._forward_core) runner = None#CUDAGraphRunner(self._forward_core)
batch_size = 1 batch_size = 1
hidden_dim = self.config["hidden_size"] hidden_dim = self.config["hidden_size"]