diff --git a/comfy/autoregressive_sampling.py b/comfy/autoregressive_sampling.py
index 8cecd7428..c8e8aa08e 100644
--- a/comfy/autoregressive_sampling.py
+++ b/comfy/autoregressive_sampling.py
@@ -7,10 +7,112 @@ import inspect
 import warnings
 from enum import Enum
 from dataclasses import dataclass, fields
-from transformers.cache_utils import StaticCache, DynamicCache, Cache
 from typing import Optional, Union, Any
+from abc import ABC, abstractmethod
 
 import comfy.model_management
 
+class CacheLayerParent(ABC):
+    def __init__(self):
+        self.keys, self.values = None, None
+
+    @abstractmethod
+    def lazy_init(self, key_states): ...
+
+    @abstractmethod
+    def update(self, key_states, value_states, cache_kwargs = None): ...
+
+    @abstractmethod
+    def get_seq_length(self) -> int: ...
+
+    @abstractmethod
+    def get_max_cache_shape(self) -> int: ...
+
+    def reset(self):
+        if self.keys is not None:
+            self.keys.zero_()
+            self.values.zero_()
+
+class Cache:
+    def __init__(self):
+        self.layers = []
+        # TODO: offload logic
+    def get_seq_length(self, layer_idx = 0):
+        if layer_idx >= len(self.layers):
+            return 0
+        return self.layers[layer_idx].get_seq_length()
+
+    def get_max_cache_shape(self, layer_idx = 0):
+        if layer_idx >= len(self.layers): # for future dynamic cache impl.
+            return -1
+        return self.layers[layer_idx].get_max_cache_shape()
+    def reset(self):
+        for layer_idx in range(len(self.layers)):
+            self.layers[layer_idx].reset()
+
+    def update(self, key_states, value_states, layer_idx, cache_kwargs):
+        keys, values = self.layers[layer_idx].update(key_states, value_states, cache_kwargs)
+        return keys, values
+
+    @property
+    def max_cache_len(self):
+        values = [layer.max_cache_len for layer in self.layers]
+        return max(values)
+
+class StaticLayer(CacheLayerParent):
+    def __init__(self, max_cache_len):
+        super().__init__()
+        self.max_cache_len = max_cache_len
+
+    def get_max_cache_shape(self):
+        return self.max_cache_len
+
+    def get_seq_length(self):
+        return (self.keys[0, 0].any(dim=-1)).sum() if self.keys is not None else 0
+
+    def lazy_init(self, key_states):
+
+        self.max_batch_size, self.num_heads, _, self.head_dim = key_states.shape
+        self.dtype, self.device = key_states.dtype, key_states.device
+
+        self.keys = torch.zeros(
+            (self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
+            dtype=self.dtype,
+            device=self.device,
+        )
+
+        self.values = torch.zeros(
+            (self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
+            dtype=self.dtype,
+            device=self.device,
+        )
+
+    def update(self, key_states, value_states, cache_kwargs = None):
+
+        if self.keys is None:
+            self.lazy_init(key_states)
+
+        cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None
+        cache_position = (
+            cache_position if cache_position is not None else torch.arange(key_states.shape[-2], device=self.device)
+        )
+
+        try:
+            self.keys.index_copy_(2, cache_position, key_states)
+            self.values.index_copy_(2, cache_position, value_states)
+        except NotImplementedError:
+            # mps fallback
+            self.keys[:, :, cache_position] = key_states
+            self.values[:, :, cache_position] = value_states
+        return self.keys, self.values
+
+class StaticCache(Cache):
+    def __init__(self, config, max_cache_len):
+        self.layers = [StaticLayer(max_cache_len) for _ in range(config.num_hidden_layers)]
+
+# TODO
+class DynamicCache(Cache):
+    pass
+
 NEED_SETUP_CACHE_CLASSES_MAPPING = { "static": StaticCache }
 class TopKLogits:
@@ -240,13 +342,9 @@ class AutoRegressiveGeneration:
 
         self.model.cache_config = self.cache_config
 
         self.kv_caches = {
-
             length: StaticCache(
                 config=self.cache_config,
-                max_batch_size = self.cache_config.max_batch,
                 max_cache_len=length,
-                device=self.model.device,
-                dtype=self.model.dtype,
             )
             for length in sorted(kv_cache_lengths)
@@ -386,11 +484,11 @@ class AutoRegressiveGeneration:
 
     def _get_cache(
         self, cache_implementation: str, batch_size: int, max_cache_len: int, device: torch.device, model_kwargs
-    ) -> Cache:
+    ):
         assert cache_implementation == "static", f"Only 'static' cache is supported, got {cache_implementation}"
 
-        cache_cls: Cache = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation]
+        cache_cls = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation]
 
         if hasattr(self.model, "_cache"):
             cache_to_check = self.model._cache
 
@@ -405,14 +503,9 @@ class AutoRegressiveGeneration:
         )
 
         if need_new_cache:
-            cache_dtype = getattr(self.model.config, "_pre_quantization_dtype", self.dtype)
             cache_kwargs = {
                 "config": self.cache_config,
-                "max_batch_size": batch_size,
                 "max_cache_len": max_cache_len,
-                "dtype": cache_dtype,
-                "device": device,
-                "layer_device_map": None,
             }
             self.model._cache = cache_cls(**cache_kwargs)
 
diff --git a/comfy/ldm/higgsv2/model.py b/comfy/ldm/higgsv2/model.py
index a8a811bab..81d10abad 100644
--- a/comfy/ldm/higgsv2/model.py
+++ b/comfy/ldm/higgsv2/model.py
@@ -4,8 +4,7 @@ from comfy.text_encoders.llama import (
     RMSNorm, MLP, Attention, LlamaRoPE, Llama2Config
 )
 
-from comfy.autoregressive_sampling import GenerationConfig, apply_logits_processing, check_stopping_criteria
-from transformers.cache_utils import Cache, StaticCache, DynamicCache
+from comfy.autoregressive_sampling import GenerationConfig, apply_logits_processing, check_stopping_criteria, Cache, StaticCache, DynamicCache
 from comfy.ldm.modules.attention import optimized_attention_for_device
 from comfy.ldm.modules.attention import optimized_attention
 from .preprocess import _ceil_to_nearest
@@ -982,24 +981,9 @@ class HiggsAudioModel(nn.Module):
             from_layer = from_cache.layers[i]
             to_layer = to_cache.layers[i]
 
-            # lazy init
-            if getattr(to_layer, "keys", None) is None:
-                to_layer.keys = torch.zeros(
-                    (self.cache_config.max_batch, self.cache_config.num_key_value_heads,
-                     to_cache.get_max_cache_shape(),
-                     self.cache_config.head_dim),
-                    device=self.device, dtype=self.dtype
-                )
-
-            if getattr(to_layer, "values", None) is None:
-                to_layer.values = torch.zeros(
-                    (self.cache_config.max_batch, self.cache_config.num_key_value_heads,
-                     to_cache.get_max_cache_shape(),
-                     self.cache_config.head_dim),
-                    device=self.device, dtype=self.dtype
-                )
-
             seq_len = from_cache_size
+
+            to_layer.lazy_init(from_layer.keys)
 
             to_layer.keys[:, :, :seq_len, :] = from_layer.keys
             to_layer.values[:, :, :seq_len, :] = from_layer.values
@@ -1266,43 +1250,3 @@ class HiggsAudioModel(nn.Module):
         )
 
         return generation_config
-
-    @torch.inference_mode()
-    def capture_model(self, past_key_values: list[Union[Cache, List[torch.FloatTensor]]]) -> None:
-        for past_key_value in past_key_values:
-            kv_cache_length = past_key_value.get_max_cache_shape()
-            for is_decoding_audio_token in [True, False]:
-                runner = None#CUDAGraphRunner(self._forward_core)
-
-                batch_size = 1
-                hidden_dim = self.config["hidden_size"]
-
-                hidden_states = torch.zeros(
-                    (batch_size, 1, hidden_dim), dtype=self.dtype, device=self.device
-                )
-                causal_mask = torch.ones(
-                    (batch_size, 1, 1, kv_cache_length), dtype=self.dtype, device=self.device
-                )
-                position_ids = torch.zeros((batch_size, 1), dtype=torch.long, device=self.device)
-                audio_discrete_codes_mask = torch.tensor(
-                    [[is_decoding_audio_token]], dtype=torch.bool, device=self.device
-                )
-                cache_position = torch.tensor([kv_cache_length - 1], dtype=torch.long, device=self.device)
-                audio_attention_mask = torch.ones_like(causal_mask)
-                fast_forward_attention_mask = torch.ones_like(causal_mask)
-
-                runner.capture(
-                    hidden_states=hidden_states,
-                    causal_mask=causal_mask,
-                    position_ids=position_ids,
-                    audio_discrete_codes_mask=audio_discrete_codes_mask,
-                    cache_position=cache_position,
-                    past_key_values=past_key_value,
-                    use_cache=True,
-                    audio_attention_mask=audio_attention_mask,
-                    fast_forward_attention_mask=fast_forward_attention_mask,
-                    is_decoding_audio_token=is_decoding_audio_token,
-                    is_using_cuda_graph=True,
-                )
-
-                self.decode_graph_runners[kv_cache_length][is_decoding_audio_token] = runner
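
For reference, a minimal usage sketch of the cache API this patch introduces in comfy/autoregressive_sampling.py. The SimpleNamespace config and the tensor shapes below are illustrative stand-ins (StaticCache only reads config.num_hidden_layers); this is not code from the repository.

# Usage sketch (illustrative, not from the repository): exercising the
# StaticCache added in comfy/autoregressive_sampling.py above.
import torch
from types import SimpleNamespace

from comfy.autoregressive_sampling import StaticCache

# StaticCache.__init__ only reads config.num_hidden_layers, so a bare
# namespace is enough as a stand-in for the real cache_config here.
config = SimpleNamespace(num_hidden_layers=2)
cache = StaticCache(config=config, max_cache_len=16)

# The first update lazily allocates that layer's (batch, heads, max_cache_len,
# head_dim) key/value buffers from the shape, dtype and device of key_states.
k = torch.randn(1, 4, 3, 64)  # (batch, heads, new_tokens, head_dim)
v = torch.randn(1, 4, 3, 64)
pos = torch.arange(3)
keys, values = cache.update(k, v, layer_idx=0, cache_kwargs={"cache_position": pos})

assert keys.shape == (1, 4, 16, 64)        # the full static buffer is returned
assert cache.get_seq_length(0) == 3        # rows already written for layer 0
assert cache.get_max_cache_shape(0) == 16
cache.reset()                              # zeroes the allocated buffers in place

Unlike the transformers StaticCache being replaced, buffer allocation is deferred to the first update() call, which is why the constructor no longer needs max_batch_size, device or dtype arguments.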