Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-15 16:50:57 +08:00)

commit 76fdb4bc9a
parent 01b32521c5

    cache updates
@@ -7,10 +7,112 @@ import inspect
 import warnings
 from enum import Enum
 from dataclasses import dataclass, fields
-from transformers.cache_utils import StaticCache, DynamicCache, Cache
 from typing import Optional, Union, Any
 from abc import ABC, abstractmethod
 import comfy.model_management
 
+class CacheLayerParent(ABC):
+    def __init__(self):
+        self.keys, self.values = None, None
+
+    @abstractmethod
+    def lazy_init(self, key_states): ...
+
+    @abstractmethod
+    def update(self, key_states, value_states, cache_kwargs = None): ...
+
+    @abstractmethod
+    def get_seq_length(self) -> int: ...
+
+    @abstractmethod
+    def get_max_cache_shape(self) -> int: ...
+
+    def reset(self):
+        if self.keys is not None:
+            self.keys.zero_()
+            self.values.zero_()
+
+class Cache:
+    def __init__(self):
+        self.layers = []
+        # TODO: offload logic
+    def get_seq_length(self, layer_idx = 0):
+        if layer_idx >= len(self.layers):
+            return 0
+        return self.layers[layer_idx].get_seq_length()
+
+    def get_max_cache_shape(self, layer_idx = 0):
+        if layer_idx >= len(self.layers): # for future dynamic cache impl.
+            return -1
+        return self.layers[layer_idx].get_max_cache_shape()
+    def reset(self):
+        for layer_idx in range(len(self.layers)):
+            self.layers[layer_idx].reset()
+
+    def update(self, key_states, value_states, layer_idx, cache_kwargs):
+        keys, values = self.layers[layer_idx].update(key_states, value_states, cache_kwargs)
+        return keys, values
+
+    @property
+    def max_cache_len(self):
+        values = [layer.max_cache_len for layer in self.layers]
+        return max(values)
+
+class StaticLayer(CacheLayerParent):
+    def __init__(self, max_cache_len):
+        super().__init__()
+        self.max_cache_len = max_cache_len
+
+    def get_max_cache_shape(self):
+        return self.max_cache_len
+
+    def get_seq_length(self):
+        return (self.keys[0, 0].any(dim=-1)).sum() if self.keys is not None else 0
+
+    def lazy_init(self, key_states):
+
+        self.max_batch_size, self.num_heads, _, self.head_dim = key_states.shape
+        self.dtype, self.device = key_states.dtype, key_states.device
+
+        self.keys = torch.zeros(
+            (self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
+            dtype=self.dtype,
+            device=self.device,
+        )
+
+        self.values = torch.zeros(
+            (self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
+            dtype=self.dtype,
+            device=self.device,
+        )
+
+    def update(self, key_states, value_states, cache_kwargs = None):
+
+        if self.keys is None:
+            self.lazy_init(key_states)
+
+        cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None
+        cache_position = (
+            cache_position if cache_position is not None else torch.arange(key_states.shape[-2], device=self.device)
+        )
+
+        try:
+            self.keys.index_copy_(2, cache_position, key_states)
+            self.values.index_copy_(2, cache_position, value_states)
+        except NotImplementedError:
+            # mps fallback
+            self.keys[:, :, cache_position] = key_states
+            self.values[:, :, cache_position] = value_states
+        return self.keys, self.values
+
+class StaticCache(Cache):
+    def __init__(self, config, max_cache_len):
+        self.layers = [StaticLayer(max_cache_len) for _ in range(config.num_hidden_layers)]
+
+# TODO
+class DynamicCache(Cache):
+    pass
+
+NEED_SETUP_CACHE_CLASSES_MAPPING = { "static": StaticCache }
 
 class TopKLogits:
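Usage sketch (not part of the commit, added for orientation): how the new lazy-init static KV cache above could be exercised, assuming the classes land in comfy.autoregressive_sampling as the import hunk further down suggests; the SimpleNamespace config is a hypothetical stand-in for the real model config.

    from types import SimpleNamespace
    import torch
    from comfy.autoregressive_sampling import StaticCache  # assumed module path

    config = SimpleNamespace(num_hidden_layers=2)   # hypothetical minimal config
    cache = StaticCache(config, max_cache_len=16)

    # The first update() lazily allocates per-layer buffers from key_states' shape/dtype/device.
    k = torch.randn(1, 4, 3, 64)                    # (batch, heads, new_tokens, head_dim)
    v = torch.randn(1, 4, 3, 64)
    pos = torch.arange(3)
    keys, values = cache.update(k, v, layer_idx=0, cache_kwargs={"cache_position": pos})

    print(keys.shape)                    # torch.Size([1, 4, 16, 64]) -- full static buffer
    print(int(cache.get_seq_length(0)))  # 3 occupied positions in layer 0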
@@ -240,13 +342,9 @@ class AutoRegressiveGeneration:
         self.model.cache_config = self.cache_config
 
         self.kv_caches = {
             length: StaticCache(
                 config=self.cache_config,
-                max_batch_size = self.cache_config.max_batch,
                 max_cache_len=length,
-                device=self.model.device,
-                dtype=self.model.dtype,
             )
             for length in sorted(kv_cache_lengths)
@@ -386,11 +484,11 @@ class AutoRegressiveGeneration:
 
     def _get_cache(
         self, cache_implementation: str, batch_size: int, max_cache_len: int, device: torch.device, model_kwargs
-    ) -> Cache:
+    ):
 
        assert cache_implementation == "static", f"Only 'static' cache is supported, got {cache_implementation}"
 
-        cache_cls: Cache = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation]
+        cache_cls = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation]
 
         if hasattr(self.model, "_cache"):
             cache_to_check = self.model._cache
@@ -405,14 +503,9 @@ class AutoRegressiveGeneration:
         )
 
         if need_new_cache:
-            cache_dtype = getattr(self.model.config, "_pre_quantization_dtype", self.dtype)
             cache_kwargs = {
                 "config": self.cache_config,
-                "max_batch_size": batch_size,
                 "max_cache_len": max_cache_len,
-                "dtype": cache_dtype,
-                "device": device,
-                "layer_device_map": None,
             }
             self.model._cache = cache_cls(**cache_kwargs)
 
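Sketch of the consequence of the hunk above (an assumption on my part, not from the commit): because the buffers are now allocated lazily from the first key/value tensors, batch size, dtype and device drop out of the constructor kwargs and only config and max_cache_len remain.

    import torch
    from types import SimpleNamespace
    from comfy.autoregressive_sampling import StaticCache  # assumed module path

    config = SimpleNamespace(num_hidden_layers=1)      # hypothetical minimal config
    cache = StaticCache(config, max_cache_len=32)      # no batch/dtype/device arguments

    k = torch.zeros(2, 4, 1, 64, dtype=torch.float16)
    cache.update(k, k.clone(), layer_idx=0, cache_kwargs=None)

    assert cache.layers[0].keys.dtype == torch.float16   # inferred from key_states
    assert cache.layers[0].keys.shape == (2, 4, 32, 64)  # batch inferred, length from constructor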
@@ -4,8 +4,7 @@ from comfy.text_encoders.llama import (
     RMSNorm, MLP, Attention, LlamaRoPE, Llama2Config
 )
 
-from comfy.autoregressive_sampling import GenerationConfig, apply_logits_processing, check_stopping_criteria
-from transformers.cache_utils import Cache, StaticCache, DynamicCache
+from comfy.autoregressive_sampling import GenerationConfig, apply_logits_processing, check_stopping_criteria, Cache, StaticCache, DynamicCache
-from comfy.ldm.modules.attention import optimized_attention_for_device
+from comfy.ldm.modules.attention import optimized_attention
 from .preprocess import _ceil_to_nearest
@@ -982,24 +981,9 @@ class HiggsAudioModel(nn.Module):
             from_layer = from_cache.layers[i]
             to_layer = to_cache.layers[i]
 
-            # lazy init
-            if getattr(to_layer, "keys", None) is None:
-                to_layer.keys = torch.zeros(
-                    (self.cache_config.max_batch, self.cache_config.num_key_value_heads,
-                     to_cache.get_max_cache_shape(),
-                     self.cache_config.head_dim),
-                    device=self.device, dtype=self.dtype
-                )
-
-            if getattr(to_layer, "values", None) is None:
-                to_layer.values = torch.zeros(
-                    (self.cache_config.max_batch, self.cache_config.num_key_value_heads,
-                     to_cache.get_max_cache_shape(),
-                     self.cache_config.head_dim),
-                    device=self.device, dtype=self.dtype
-                )
 
             seq_len = from_cache_size
 
+            to_layer.lazy_init(from_layer.keys)
             to_layer.keys[:, :, :seq_len, :] = from_layer.keys
             to_layer.values[:, :, :seq_len, :] = from_layer.values
 
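Sketch of my reading of the hunk above (module path assumed, not part of the commit): copying a shorter KV-cache bucket into a longer one now delegates buffer allocation to StaticLayer.lazy_init, which sizes the new zero buffers from the source layer's keys, instead of building the key/value tensors by hand.

    import torch
    from types import SimpleNamespace
    from comfy.autoregressive_sampling import StaticCache  # assumed module path

    config = SimpleNamespace(num_hidden_layers=1)           # hypothetical minimal config
    from_cache, to_cache = StaticCache(config, 8), StaticCache(config, 32)

    k = torch.randn(1, 4, 8, 64)                            # fill the small bucket completely
    from_cache.update(k, k.clone(), layer_idx=0, cache_kwargs=None)

    from_layer, to_layer = from_cache.layers[0], to_cache.layers[0]
    seq_len = from_cache.get_max_cache_shape(0)             # 8

    to_layer.lazy_init(from_layer.keys)                     # allocates (1, 4, 32, 64) zero buffers
    to_layer.keys[:, :, :seq_len, :] = from_layer.keys
    to_layer.values[:, :, :seq_len, :] = from_layer.values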
@@ -1266,43 +1250,3 @@ class HiggsAudioModel(nn.Module):
         )
 
         return generation_config
-
-    @torch.inference_mode()
-    def capture_model(self, past_key_values: list[Union[Cache, List[torch.FloatTensor]]]) -> None:
-        for past_key_value in past_key_values:
-            kv_cache_length = past_key_value.get_max_cache_shape()
-            for is_decoding_audio_token in [True, False]:
-                runner = None#CUDAGraphRunner(self._forward_core)
-
-                batch_size = 1
-                hidden_dim = self.config["hidden_size"]
-
-                hidden_states = torch.zeros(
-                    (batch_size, 1, hidden_dim), dtype=self.dtype, device=self.device
-                )
-                causal_mask = torch.ones(
-                    (batch_size, 1, 1, kv_cache_length), dtype=self.dtype, device=self.device
-                )
-                position_ids = torch.zeros((batch_size, 1), dtype=torch.long, device=self.device)
-                audio_discrete_codes_mask = torch.tensor(
-                    [[is_decoding_audio_token]], dtype=torch.bool, device=self.device
-                )
-                cache_position = torch.tensor([kv_cache_length - 1], dtype=torch.long, device=self.device)
-                audio_attention_mask = torch.ones_like(causal_mask)
-                fast_forward_attention_mask = torch.ones_like(causal_mask)
-
-                runner.capture(
-                    hidden_states=hidden_states,
-                    causal_mask=causal_mask,
-                    position_ids=position_ids,
-                    audio_discrete_codes_mask=audio_discrete_codes_mask,
-                    cache_position=cache_position,
-                    past_key_values=past_key_value,
-                    use_cache=True,
-                    audio_attention_mask=audio_attention_mask,
-                    fast_forward_attention_mask=fast_forward_attention_mask,
-                    is_decoding_audio_token=is_decoding_audio_token,
-                    is_using_cuda_graph=True,
-                )
-
-                self.decode_graph_runners[kv_cache_length][is_decoding_audio_token] = runner