mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-07 12:02:37 +08:00
forgot cudagraph import in higgsv2 model.py
This commit is contained in:
parent
b80cbecf8f
commit
faa866733d
@ -8,7 +8,6 @@ from comfy.autoregressive_sampling import GenerationConfig, apply_logits_process
|
||||
from transformers.cache_utils import Cache, StaticCache, DynamicCache
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from .cuda_graph_runner import CUDAGraphRunner
|
||||
from .preprocess import _ceil_to_nearest
|
||||
|
||||
import torch
|
||||
@ -475,8 +474,8 @@ class HiggsAudioModel(nn.Module):
|
||||
raise NotImplementedError(f"Audio adapter type {kwargs['audio_adapter_type']} not implemented.")
|
||||
|
||||
self.num_activation_checkpointing_layers = len(self.layers)
|
||||
|
||||
self.decode_graph_runners = defaultdict(dict[bool, CUDAGraphRunner])
|
||||
self.decode_graph_runners = []
|
||||
#self.decode_graph_runners = defaultdict(dict[bool, CUDAGraphRunner])
|
||||
self.norm = RMSNorm(
|
||||
kwargs["text_config"]["hidden_size"], eps = kwargs["text_config"]["rms_norm_eps"]
|
||||
)
|
||||
@ -886,9 +885,9 @@ class HiggsAudioModel(nn.Module):
|
||||
|
||||
if (
|
||||
past_key_values is not None
|
||||
and is_using_cuda_graphs
|
||||
and past_key_values.get_max_cache_shape() in self.decode_graph_runners
|
||||
and (input_ids.shape[-1] == 1)
|
||||
and is_using_cuda_graphs
|
||||
):
|
||||
_forward_core = self.decode_graph_runners[past_key_values.get_max_cache_shape()][is_decoding_audio_token]
|
||||
local_cuda_graph = True
|
||||
|
||||
Loading…
Reference in New Issue
Block a user