mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-28 15:20:25 +08:00
Remove leftover CUDAGraphRunner import from higgsv2 model.py
This commit is contained in:
parent
b80cbecf8f
commit
faa866733d
@ -8,7 +8,6 @@ from comfy.autoregressive_sampling import GenerationConfig, apply_logits_process
|
|||||||
from transformers.cache_utils import Cache, StaticCache, DynamicCache
|
from transformers.cache_utils import Cache, StaticCache, DynamicCache
|
||||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||||
from comfy.ldm.modules.attention import optimized_attention
|
from comfy.ldm.modules.attention import optimized_attention
|
||||||
from .cuda_graph_runner import CUDAGraphRunner
|
|
||||||
from .preprocess import _ceil_to_nearest
|
from .preprocess import _ceil_to_nearest
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -475,8 +474,8 @@ class HiggsAudioModel(nn.Module):
|
|||||||
raise NotImplementedError(f"Audio adapter type {kwargs['audio_adapter_type']} not implemented.")
|
raise NotImplementedError(f"Audio adapter type {kwargs['audio_adapter_type']} not implemented.")
|
||||||
|
|
||||||
self.num_activation_checkpointing_layers = len(self.layers)
|
self.num_activation_checkpointing_layers = len(self.layers)
|
||||||
|
self.decode_graph_runners = []
|
||||||
self.decode_graph_runners = defaultdict(dict[bool, CUDAGraphRunner])
|
#self.decode_graph_runners = defaultdict(dict[bool, CUDAGraphRunner])
|
||||||
self.norm = RMSNorm(
|
self.norm = RMSNorm(
|
||||||
kwargs["text_config"]["hidden_size"], eps = kwargs["text_config"]["rms_norm_eps"]
|
kwargs["text_config"]["hidden_size"], eps = kwargs["text_config"]["rms_norm_eps"]
|
||||||
)
|
)
|
||||||
@ -886,9 +885,9 @@ class HiggsAudioModel(nn.Module):
|
|||||||
|
|
||||||
if (
|
if (
|
||||||
past_key_values is not None
|
past_key_values is not None
|
||||||
|
and is_using_cuda_graphs
|
||||||
and past_key_values.get_max_cache_shape() in self.decode_graph_runners
|
and past_key_values.get_max_cache_shape() in self.decode_graph_runners
|
||||||
and (input_ids.shape[-1] == 1)
|
and (input_ids.shape[-1] == 1)
|
||||||
and is_using_cuda_graphs
|
|
||||||
):
|
):
|
||||||
_forward_core = self.decode_graph_runners[past_key_values.get_max_cache_shape()][is_decoding_audio_token]
|
_forward_core = self.decode_graph_runners[past_key_values.get_max_cache_shape()][is_decoding_audio_token]
|
||||||
local_cuda_graph = True
|
local_cuda_graph = True
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user