From 907ed37facdea08eb3059f9e3586f9badcf6019d Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Sat, 25 Oct 2025 00:05:35 +0300 Subject: [PATCH] ruff checks --- comfy/autoregressive_sampling.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/comfy/autoregressive_sampling.py b/comfy/autoregressive_sampling.py index c8e8aa08e..d0ab90673 100644 --- a/comfy/autoregressive_sampling.py +++ b/comfy/autoregressive_sampling.py @@ -14,10 +14,10 @@ import comfy.model_management class CacheLayerParent(ABC): def __init__(self): self.keys, self.values = None, None - + @abstractmethod def lazy_init(self, key_states): ... - + @abstractmethod def update(self, key_states, value_states, cache_kwargs = None): ... @@ -40,7 +40,7 @@ class Cache: if layer_idx >= len(self.layers): return 0 return self.layers[layer_idx].get_seq_length() - + def get_max_cache_shape(self, layer_idx = 0): if layer_idx >= len(self.layers): # for future dynamic cache impl. return -1 @@ -52,7 +52,7 @@ class Cache: def update(self, key_states, value_states, layer_idx, cache_kwargs): keys, values = self.layers[layer_idx].update(key_states, value_states, cache_kwargs) return keys, values - + @property def max_cache_len(self): values = [layer.max_cache_len for layer in self.layers] @@ -73,7 +73,7 @@ class StaticLayer(CacheLayerParent): self.max_batch_size, self.num_heads, _, self.head_dim = key_states.shape self.dtype, self.device = key_states.dtype, key_states.device - + self.keys = torch.zeros( (self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim), dtype=self.dtype, @@ -87,7 +87,7 @@ class StaticLayer(CacheLayerParent): ) def update(self, key_states, value_states, cache_kwargs = None): - + if self.keys is None: self.lazy_init(key_states) @@ -104,7 +104,7 @@ class StaticLayer(CacheLayerParent): self.keys[:, :, cache_position] = key_states self.values[:, :, cache_position] = value_states return self.keys, self.values - + class StaticCache(Cache): def __init__(self, config, max_cache_len): self.layers = [StaticLayer(max_cache_len) for _ in range(config.num_hidden_layers)]