ruff checks

This commit is contained in:
Yousef Rafat 2025-10-25 00:05:35 +03:00
parent 76fdb4bc9a
commit 907ed37fac

View File

@ -14,10 +14,10 @@ import comfy.model_management
class CacheLayerParent(ABC):
def __init__(self):
self.keys, self.values = None, None
@abstractmethod
def lazy_init(self, key_states): ...
@abstractmethod
def update(self, key_states, value_states, cache_kwargs = None): ...
@ -40,7 +40,7 @@ class Cache:
if layer_idx >= len(self.layers):
return 0
return self.layers[layer_idx].get_seq_length()
def get_max_cache_shape(self, layer_idx = 0):
if layer_idx >= len(self.layers): # for future dynamic cache impl.
return -1
@ -52,7 +52,7 @@ class Cache:
def update(self, key_states, value_states, layer_idx, cache_kwargs):
keys, values = self.layers[layer_idx].update(key_states, value_states, cache_kwargs)
return keys, values
@property
def max_cache_len(self):
values = [layer.max_cache_len for layer in self.layers]
@ -73,7 +73,7 @@ class StaticLayer(CacheLayerParent):
self.max_batch_size, self.num_heads, _, self.head_dim = key_states.shape
self.dtype, self.device = key_states.dtype, key_states.device
self.keys = torch.zeros(
(self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
dtype=self.dtype,
@ -87,7 +87,7 @@ class StaticLayer(CacheLayerParent):
)
def update(self, key_states, value_states, cache_kwargs = None):
if self.keys is None:
self.lazy_init(key_states)
@ -104,7 +104,7 @@ class StaticLayer(CacheLayerParent):
self.keys[:, :, cache_position] = key_states
self.values[:, :, cache_position] = value_states
return self.keys, self.values
class StaticCache(Cache):
def __init__(self, config, max_cache_len):
self.layers = [StaticLayer(max_cache_len) for _ in range(config.num_hidden_layers)]