ruff checks

This commit is contained in:
Yousef Rafat 2025-10-25 00:05:35 +03:00
parent 76fdb4bc9a
commit 907ed37fac

View File

@ -14,10 +14,10 @@ import comfy.model_management
class CacheLayerParent(ABC): class CacheLayerParent(ABC):
def __init__(self): def __init__(self):
self.keys, self.values = None, None self.keys, self.values = None, None
@abstractmethod @abstractmethod
def lazy_init(self, key_states): ... def lazy_init(self, key_states): ...
@abstractmethod @abstractmethod
def update(self, key_states, value_states, cache_kwargs = None): ... def update(self, key_states, value_states, cache_kwargs = None): ...
@ -40,7 +40,7 @@ class Cache:
if layer_idx >= len(self.layers): if layer_idx >= len(self.layers):
return 0 return 0
return self.layers[layer_idx].get_seq_length() return self.layers[layer_idx].get_seq_length()
def get_max_cache_shape(self, layer_idx = 0): def get_max_cache_shape(self, layer_idx = 0):
if layer_idx >= len(self.layers): # for future dynamic cache impl. if layer_idx >= len(self.layers): # for future dynamic cache impl.
return -1 return -1
@ -52,7 +52,7 @@ class Cache:
def update(self, key_states, value_states, layer_idx, cache_kwargs): def update(self, key_states, value_states, layer_idx, cache_kwargs):
keys, values = self.layers[layer_idx].update(key_states, value_states, cache_kwargs) keys, values = self.layers[layer_idx].update(key_states, value_states, cache_kwargs)
return keys, values return keys, values
@property @property
def max_cache_len(self): def max_cache_len(self):
values = [layer.max_cache_len for layer in self.layers] values = [layer.max_cache_len for layer in self.layers]
@ -73,7 +73,7 @@ class StaticLayer(CacheLayerParent):
self.max_batch_size, self.num_heads, _, self.head_dim = key_states.shape self.max_batch_size, self.num_heads, _, self.head_dim = key_states.shape
self.dtype, self.device = key_states.dtype, key_states.device self.dtype, self.device = key_states.dtype, key_states.device
self.keys = torch.zeros( self.keys = torch.zeros(
(self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim), (self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
dtype=self.dtype, dtype=self.dtype,
@ -87,7 +87,7 @@ class StaticLayer(CacheLayerParent):
) )
def update(self, key_states, value_states, cache_kwargs = None): def update(self, key_states, value_states, cache_kwargs = None):
if self.keys is None: if self.keys is None:
self.lazy_init(key_states) self.lazy_init(key_states)
@ -104,7 +104,7 @@ class StaticLayer(CacheLayerParent):
self.keys[:, :, cache_position] = key_states self.keys[:, :, cache_position] = key_states
self.values[:, :, cache_position] = value_states self.values[:, :, cache_position] = value_states
return self.keys, self.values return self.keys, self.values
class StaticCache(Cache): class StaticCache(Cache):
def __init__(self, config, max_cache_len): def __init__(self, config, max_cache_len):
self.layers = [StaticLayer(max_cache_len) for _ in range(config.num_hidden_layers)] self.layers = [StaticLayer(max_cache_len) for _ in range(config.num_hidden_layers)]