mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-23 04:40:15 +08:00
ruff checks
This commit is contained in:
parent
76fdb4bc9a
commit
907ed37fac
@ -14,10 +14,10 @@ import comfy.model_management
|
|||||||
class CacheLayerParent(ABC):
|
class CacheLayerParent(ABC):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.keys, self.values = None, None
|
self.keys, self.values = None, None
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def lazy_init(self, key_states): ...
|
def lazy_init(self, key_states): ...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def update(self, key_states, value_states, cache_kwargs = None): ...
|
def update(self, key_states, value_states, cache_kwargs = None): ...
|
||||||
|
|
||||||
@ -40,7 +40,7 @@ class Cache:
|
|||||||
if layer_idx >= len(self.layers):
|
if layer_idx >= len(self.layers):
|
||||||
return 0
|
return 0
|
||||||
return self.layers[layer_idx].get_seq_length()
|
return self.layers[layer_idx].get_seq_length()
|
||||||
|
|
||||||
def get_max_cache_shape(self, layer_idx = 0):
|
def get_max_cache_shape(self, layer_idx = 0):
|
||||||
if layer_idx >= len(self.layers): # for future dynamic cache impl.
|
if layer_idx >= len(self.layers): # for future dynamic cache impl.
|
||||||
return -1
|
return -1
|
||||||
@ -52,7 +52,7 @@ class Cache:
|
|||||||
def update(self, key_states, value_states, layer_idx, cache_kwargs):
|
def update(self, key_states, value_states, layer_idx, cache_kwargs):
|
||||||
keys, values = self.layers[layer_idx].update(key_states, value_states, cache_kwargs)
|
keys, values = self.layers[layer_idx].update(key_states, value_states, cache_kwargs)
|
||||||
return keys, values
|
return keys, values
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def max_cache_len(self):
|
def max_cache_len(self):
|
||||||
values = [layer.max_cache_len for layer in self.layers]
|
values = [layer.max_cache_len for layer in self.layers]
|
||||||
@ -73,7 +73,7 @@ class StaticLayer(CacheLayerParent):
|
|||||||
|
|
||||||
self.max_batch_size, self.num_heads, _, self.head_dim = key_states.shape
|
self.max_batch_size, self.num_heads, _, self.head_dim = key_states.shape
|
||||||
self.dtype, self.device = key_states.dtype, key_states.device
|
self.dtype, self.device = key_states.dtype, key_states.device
|
||||||
|
|
||||||
self.keys = torch.zeros(
|
self.keys = torch.zeros(
|
||||||
(self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
|
(self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
@ -87,7 +87,7 @@ class StaticLayer(CacheLayerParent):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def update(self, key_states, value_states, cache_kwargs = None):
|
def update(self, key_states, value_states, cache_kwargs = None):
|
||||||
|
|
||||||
if self.keys is None:
|
if self.keys is None:
|
||||||
self.lazy_init(key_states)
|
self.lazy_init(key_states)
|
||||||
|
|
||||||
@ -104,7 +104,7 @@ class StaticLayer(CacheLayerParent):
|
|||||||
self.keys[:, :, cache_position] = key_states
|
self.keys[:, :, cache_position] = key_states
|
||||||
self.values[:, :, cache_position] = value_states
|
self.values[:, :, cache_position] = value_states
|
||||||
return self.keys, self.values
|
return self.keys, self.values
|
||||||
|
|
||||||
class StaticCache(Cache):
|
class StaticCache(Cache):
|
||||||
def __init__(self, config, max_cache_len):
|
def __init__(self, config, max_cache_len):
|
||||||
self.layers = [StaticLayer(max_cache_len) for _ in range(config.num_hidden_layers)]
|
self.layers = [StaticLayer(max_cache_len) for _ in range(config.num_hidden_layers)]
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user