cache updates

Yousef Rafat 2025-10-25 00:03:00 +03:00
parent 01b32521c5
commit 76fdb4bc9a
2 changed files with 108 additions and 71 deletions

View File

@@ -7,10 +7,112 @@ import inspect
import warnings
from enum import Enum
from dataclasses import dataclass, fields
from transformers.cache_utils import StaticCache, DynamicCache, Cache
from typing import Optional, Union, Any
from abc import ABC, abstractmethod
import comfy.model_management


class CacheLayerParent(ABC):
    def __init__(self):
        self.keys, self.values = None, None

    @abstractmethod
    def lazy_init(self, key_states): ...

    @abstractmethod
    def update(self, key_states, value_states, cache_kwargs=None): ...

    @abstractmethod
    def get_seq_length(self) -> int: ...

    @abstractmethod
    def get_max_cache_shape(self) -> int: ...

    def reset(self):
        if self.keys is not None:
            self.keys.zero_()
            self.values.zero_()


class Cache:
    def __init__(self):
        self.layers = []
        # TODO: offload logic

    def get_seq_length(self, layer_idx=0):
        if layer_idx >= len(self.layers):
            return 0
        return self.layers[layer_idx].get_seq_length()

    def get_max_cache_shape(self, layer_idx=0):
        if layer_idx >= len(self.layers):  # for future dynamic cache impl.
            return -1
        return self.layers[layer_idx].get_max_cache_shape()

    def reset(self):
        for layer_idx in range(len(self.layers)):
            self.layers[layer_idx].reset()

    def update(self, key_states, value_states, layer_idx, cache_kwargs):
        keys, values = self.layers[layer_idx].update(key_states, value_states, cache_kwargs)
        return keys, values

    @property
    def max_cache_len(self):
        values = [layer.max_cache_len for layer in self.layers]
        return max(values)


class StaticLayer(CacheLayerParent):
    def __init__(self, max_cache_len):
        super().__init__()
        self.max_cache_len = max_cache_len

    def get_max_cache_shape(self):
        return self.max_cache_len

    def get_seq_length(self):
        return (self.keys[0, 0].any(dim=-1)).sum() if self.keys is not None else 0

    def lazy_init(self, key_states):
        self.max_batch_size, self.num_heads, _, self.head_dim = key_states.shape
        self.dtype, self.device = key_states.dtype, key_states.device
        self.keys = torch.zeros(
            (self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
            dtype=self.dtype,
            device=self.device,
        )
        self.values = torch.zeros(
            (self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
            dtype=self.dtype,
            device=self.device,
        )

    def update(self, key_states, value_states, cache_kwargs=None):
        if self.keys is None:
            self.lazy_init(key_states)
        cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None
        cache_position = (
            cache_position if cache_position is not None else torch.arange(key_states.shape[-2], device=self.device)
        )
        try:
            self.keys.index_copy_(2, cache_position, key_states)
            self.values.index_copy_(2, cache_position, value_states)
        except NotImplementedError:
            # mps fallback
            self.keys[:, :, cache_position] = key_states
            self.values[:, :, cache_position] = value_states
        return self.keys, self.values


class StaticCache(Cache):
    def __init__(self, config, max_cache_len):
        self.layers = [StaticLayer(max_cache_len) for _ in range(config.num_hidden_layers)]


# TODO
class DynamicCache(Cache):
    pass


NEED_SETUP_CACHE_CLASSES_MAPPING = {"static": StaticCache}

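
Illustration (not part of the diff): a minimal usage sketch of the static cache added above. The SimpleNamespace config, tensor shapes, and cache length are assumptions for the example; StaticCache only reads num_hidden_layers from the config, and the key/value buffers are allocated lazily by the first update call.

    import torch
    from types import SimpleNamespace

    config = SimpleNamespace(num_hidden_layers=2)   # stand-in config; only num_hidden_layers is read
    cache = StaticCache(config, max_cache_len=16)

    # Prefill: write 4 positions into layer 0. lazy_init allocates the zero buffers here,
    # inferring batch size, head count, head_dim, dtype and device from key_states.
    k = torch.randn(1, 8, 4, 64)                    # (batch, num_heads, seq_len, head_dim)
    v = torch.randn(1, 8, 4, 64)
    pos = torch.arange(4)
    keys, values = cache.update(k, v, layer_idx=0, cache_kwargs={"cache_position": pos})

    print(keys.shape)                   # torch.Size([1, 8, 16, 64]), the full static buffer
    print(int(cache.get_seq_length()))  # 4, counted as positions whose keys are non-zero
    print(cache.get_max_cache_shape())  # 16
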
class TopKLogits:
@@ -240,13 +342,9 @@ class AutoRegressiveGeneration:
        self.model.cache_config = self.cache_config
        self.kv_caches = {
            length: StaticCache(
                config=self.cache_config,
                max_batch_size=self.cache_config.max_batch,
                max_cache_len=length,
                device=self.model.device,
                dtype=self.model.dtype,
            )
            for length in sorted(kv_cache_lengths)
@@ -386,11 +484,11 @@ class AutoRegressiveGeneration:
    def _get_cache(
        self, cache_implementation: str, batch_size: int, max_cache_len: int, device: torch.device, model_kwargs
    ) -> Cache:
    ):
        assert cache_implementation == "static", f"Only 'static' cache is supported, got {cache_implementation}"
        cache_cls: Cache = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation]
        cache_cls = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation]
        if hasattr(self.model, "_cache"):
            cache_to_check = self.model._cache
@@ -405,14 +503,9 @@ class AutoRegressiveGeneration:
        )
        if need_new_cache:
            cache_dtype = getattr(self.model.config, "_pre_quantization_dtype", self.dtype)
            cache_kwargs = {
                "config": self.cache_config,
                "max_batch_size": batch_size,
                "max_cache_len": max_cache_len,
                "dtype": cache_dtype,
                "device": device,
                "layer_device_map": None,
            }
            self.model._cache = cache_cls(**cache_kwargs)
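
Illustration (not part of the diff): what the trimmed-down construction path in _get_cache amounts to after this change. The import path and the SimpleNamespace config are assumptions for the example; only num_hidden_layers is actually read by StaticCache, and batch size, device, and dtype are now inferred later by StaticLayer.lazy_init from the first key tensor.

    from types import SimpleNamespace
    # assumed module path, based on the import added in the second file below
    from comfy.autoregressive_sampling import NEED_SETUP_CACHE_CLASSES_MAPPING

    cache_config = SimpleNamespace(num_hidden_layers=2)      # stand-in for self.cache_config
    cache_cls = NEED_SETUP_CACHE_CLASSES_MAPPING["static"]   # -> StaticCache
    cache = cache_cls(config=cache_config, max_cache_len=512)
    # max_batch_size / device / dtype / layer_device_map are no longer passed here;
    # StaticLayer.lazy_init picks them up from the first key_states it receives.
    print(cache.max_cache_len)                               # 512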

View File

@@ -4,8 +4,7 @@ from comfy.text_encoders.llama import (
    RMSNorm, MLP, Attention, LlamaRoPE, Llama2Config
)
from comfy.autoregressive_sampling import GenerationConfig, apply_logits_processing, check_stopping_criteria
from transformers.cache_utils import Cache, StaticCache, DynamicCache
from comfy.autoregressive_sampling import GenerationConfig, apply_logits_processing, check_stopping_criteria, Cache, StaticCache, DynamicCache
from comfy.ldm.modules.attention import optimized_attention_for_device
from comfy.ldm.modules.attention import optimized_attention
from .preprocess import _ceil_to_nearest
@@ -982,24 +981,9 @@ class HiggsAudioModel(nn.Module):
            from_layer = from_cache.layers[i]
            to_layer = to_cache.layers[i]
            # lazy init
            if getattr(to_layer, "keys", None) is None:
                to_layer.keys = torch.zeros(
                    (self.cache_config.max_batch, self.cache_config.num_key_value_heads,
                     to_cache.get_max_cache_shape(),
                     self.cache_config.head_dim),
                    device=self.device, dtype=self.dtype
                )
            if getattr(to_layer, "values", None) is None:
                to_layer.values = torch.zeros(
                    (self.cache_config.max_batch, self.cache_config.num_key_value_heads,
                     to_cache.get_max_cache_shape(),
                     self.cache_config.head_dim),
                    device=self.device, dtype=self.dtype
                )
            seq_len = from_cache_size
            to_layer.lazy_init(from_layer.keys)
            to_layer.keys[:, :, :seq_len, :] = from_layer.keys
            to_layer.values[:, :, :seq_len, :] = from_layer.values
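
Illustration (not part of the diff): a self-contained sketch of the copy pattern the replacement lines rely on, where lazy_init sizes the destination buffers from the source keys while keeping the destination's own max_cache_len. The one-layer config, shapes, and cache lengths are illustrative assumptions.

    import torch
    from types import SimpleNamespace

    config = SimpleNamespace(num_hidden_layers=1)
    from_cache = StaticCache(config, max_cache_len=8)    # e.g. prefill-sized cache
    to_cache = StaticCache(config, max_cache_len=32)     # e.g. decode-sized cache

    k = torch.randn(1, 4, 8, 64)
    v = torch.randn(1, 4, 8, 64)
    from_cache.update(k, v, layer_idx=0, cache_kwargs=None)

    from_cache_size = int(from_cache.get_seq_length())   # 8
    for i in range(len(to_cache.layers)):
        from_layer, to_layer = from_cache.layers[i], to_cache.layers[i]
        to_layer.lazy_init(from_layer.keys)               # allocates (1, 4, 32, 64) zeros, same dtype/device
        to_layer.keys[:, :, :from_cache_size, :] = from_layer.keys[:, :, :from_cache_size, :]
        to_layer.values[:, :, :from_cache_size, :] = from_layer.values[:, :, :from_cache_size, :]
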
@@ -1266,43 +1250,3 @@ class HiggsAudioModel(nn.Module):
        )
        return generation_config

    @torch.inference_mode()
    def capture_model(self, past_key_values: list[Union[Cache, List[torch.FloatTensor]]]) -> None:
        for past_key_value in past_key_values:
            kv_cache_length = past_key_value.get_max_cache_shape()
            for is_decoding_audio_token in [True, False]:
                runner = None  # CUDAGraphRunner(self._forward_core)
                batch_size = 1
                hidden_dim = self.config["hidden_size"]
                hidden_states = torch.zeros(
                    (batch_size, 1, hidden_dim), dtype=self.dtype, device=self.device
                )
                causal_mask = torch.ones(
                    (batch_size, 1, 1, kv_cache_length), dtype=self.dtype, device=self.device
                )
                position_ids = torch.zeros((batch_size, 1), dtype=torch.long, device=self.device)
                audio_discrete_codes_mask = torch.tensor(
                    [[is_decoding_audio_token]], dtype=torch.bool, device=self.device
                )
                cache_position = torch.tensor([kv_cache_length - 1], dtype=torch.long, device=self.device)
                audio_attention_mask = torch.ones_like(causal_mask)
                fast_forward_attention_mask = torch.ones_like(causal_mask)
                runner.capture(
                    hidden_states=hidden_states,
                    causal_mask=causal_mask,
                    position_ids=position_ids,
                    audio_discrete_codes_mask=audio_discrete_codes_mask,
                    cache_position=cache_position,
                    past_key_values=past_key_value,
                    use_cache=True,
                    audio_attention_mask=audio_attention_mask,
                    fast_forward_attention_mask=fast_forward_attention_mask,
                    is_decoding_audio_token=is_decoding_audio_token,
                    is_using_cuda_graph=True,
                )
                self.decode_graph_runners[kv_cache_length][is_decoding_audio_token] = runner