This commit is contained in:
Yousef Rafat 2025-10-24 18:22:02 +03:00
parent faa866733d
commit 01b32521c5

View File

@ -14,7 +14,7 @@ import torch
import torch.nn as nn
from enum import Enum
from dataclasses import dataclass
from collections import defaultdict, OrderedDict
from collections import OrderedDict
from typing import Optional, Tuple, Union, List
class GenerationMode(Enum):
@ -1272,7 +1272,7 @@ class HiggsAudioModel(nn.Module):
for past_key_value in past_key_values:
kv_cache_length = past_key_value.get_max_cache_shape()
for is_decoding_audio_token in [True, False]:
runner = CUDAGraphRunner(self._forward_core)
runner = None#CUDAGraphRunner(self._forward_core)
batch_size = 1
hidden_dim = self.config["hidden_size"]