forgot cudagraph import in higgsv2 model.py

Author: Yousef Rafat
Date: 2025-10-24 18:15:06 +03:00
parent b80cbecf8f
commit faa866733d

model.py

@@ -8,7 +8,6 @@ from comfy.autoregressive_sampling import GenerationConfig, apply_logits_process
 from transformers.cache_utils import Cache, StaticCache, DynamicCache
 from comfy.ldm.modules.attention import optimized_attention_for_device
 from comfy.ldm.modules.attention import optimized_attention
-from .cuda_graph_runner import CUDAGraphRunner
 from .preprocess import _ceil_to_nearest
 import torch
@@ -475,8 +474,8 @@ class HiggsAudioModel(nn.Module):
             raise NotImplementedError(f"Audio adapter type {kwargs['audio_adapter_type']} not implemented.")
         self.num_activation_checkpointing_layers = len(self.layers)
-        self.decode_graph_runners = defaultdict(dict[bool, CUDAGraphRunner])
+        self.decode_graph_runners = []
+        #self.decode_graph_runners = defaultdict(dict[bool, CUDAGraphRunner])
         self.norm = RMSNorm(
             kwargs["text_config"]["hidden_size"], eps = kwargs["text_config"]["rms_norm_eps"]
         )
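The commented-out line described a two-level registry: the outer key is the static cache's maximum shape, the inner key is whether the current step decodes an audio token, and the value is the captured CUDAGraphRunner (see the lookup in the next hunk). Below is a minimal sketch of that structure and of why replacing it with an empty list turns the fast path off; the CUDAGraphRunner stub here is a hypothetical placeholder, not the real comfy implementation.

from collections import defaultdict


class CUDAGraphRunner:
    """Hypothetical stub standing in for .cuda_graph_runner.CUDAGraphRunner."""
    def __call__(self, *args, **kwargs):
        raise NotImplementedError


# Outer key: max cache shape of the StaticCache (an int); inner key: whether the
# current step decodes an audio token (a bool). defaultdict accepts the generic
# alias dict[bool, CUDAGraphRunner] because calling it builds a plain empty dict.
decode_graph_runners = defaultdict(dict[bool, CUDAGraphRunner])
decode_graph_runners[1024][True] = CUDAGraphRunner()

# After this commit the attribute is an empty list, so the membership test in the
# decode guard below is always False, the CUDA-graph path is never taken, and the
# CUDAGraphRunner import can be dropped.
decode_graph_runners_disabled = []
assert 1024 not in decode_graph_runners_disabled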
@@ -886,9 +885,9 @@ class HiggsAudioModel(nn.Module):
         if (
             past_key_values is not None
-            and is_using_cuda_graphs
             and past_key_values.get_max_cache_shape() in self.decode_graph_runners
             and (input_ids.shape[-1] == 1)
+            and is_using_cuda_graphs
         ):
             _forward_core = self.decode_graph_runners[past_key_values.get_max_cache_shape()][is_decoding_audio_token]
             local_cuda_graph = True
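For context, the runners this guard bypasses follow PyTorch's standard capture/replay pattern: capture one single-token decode step into a torch.cuda.CUDAGraph with static input/output buffers, then replay the graph each step after copying new data into those buffers. The sketch below illustrates that generic pattern with a toy linear layer; it is not the comfy CUDAGraphRunner implementation and needs a CUDA device to run.

import torch

# Toy stand-in for the model's single-token decode core.
core = torch.nn.Linear(64, 64).cuda()
static_x = torch.zeros(1, 64, device="cuda")

# Warm up on a side stream before capture, as the PyTorch docs recommend.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        core(static_x)
torch.cuda.current_stream().wait_stream(s)

# Capture one decode step into a CUDA graph using the static buffers.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_y = core(static_x)

# Per decode step: copy the new input into the captured buffer and replay.
static_x.copy_(torch.randn(1, 64, device="cuda"))
graph.replay()
result = static_y.clone()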