mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-15 16:50:57 +08:00
ruff checks
This commit is contained in:
parent
76fdb4bc9a
commit
907ed37fac
@ -14,10 +14,10 @@ import comfy.model_management
|
||||
class CacheLayerParent(ABC):
|
||||
def __init__(self):
|
||||
self.keys, self.values = None, None
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def lazy_init(self, key_states): ...
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def update(self, key_states, value_states, cache_kwargs = None): ...
|
||||
|
||||
@ -40,7 +40,7 @@ class Cache:
|
||||
if layer_idx >= len(self.layers):
|
||||
return 0
|
||||
return self.layers[layer_idx].get_seq_length()
|
||||
|
||||
|
||||
def get_max_cache_shape(self, layer_idx = 0):
|
||||
if layer_idx >= len(self.layers): # for future dynamic cache impl.
|
||||
return -1
|
||||
@ -52,7 +52,7 @@ class Cache:
|
||||
def update(self, key_states, value_states, layer_idx, cache_kwargs):
|
||||
keys, values = self.layers[layer_idx].update(key_states, value_states, cache_kwargs)
|
||||
return keys, values
|
||||
|
||||
|
||||
@property
|
||||
def max_cache_len(self):
|
||||
values = [layer.max_cache_len for layer in self.layers]
|
||||
@ -73,7 +73,7 @@ class StaticLayer(CacheLayerParent):
|
||||
|
||||
self.max_batch_size, self.num_heads, _, self.head_dim = key_states.shape
|
||||
self.dtype, self.device = key_states.dtype, key_states.device
|
||||
|
||||
|
||||
self.keys = torch.zeros(
|
||||
(self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
|
||||
dtype=self.dtype,
|
||||
@ -87,7 +87,7 @@ class StaticLayer(CacheLayerParent):
|
||||
)
|
||||
|
||||
def update(self, key_states, value_states, cache_kwargs = None):
|
||||
|
||||
|
||||
if self.keys is None:
|
||||
self.lazy_init(key_states)
|
||||
|
||||
@ -104,7 +104,7 @@ class StaticLayer(CacheLayerParent):
|
||||
self.keys[:, :, cache_position] = key_states
|
||||
self.values[:, :, cache_position] = value_states
|
||||
return self.keys, self.values
|
||||
|
||||
|
||||
class StaticCache(Cache):
|
||||
def __init__(self, config, max_cache_len):
|
||||
self.layers = [StaticLayer(max_cache_len) for _ in range(config.num_hidden_layers)]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user