From 08bf8f4d958dedbadbc768030b80c13862c25848 Mon Sep 17 00:00:00 2001
From: Talmaj Marinc
Date: Wed, 25 Mar 2026 21:59:22 +0100
Subject: [PATCH] Move KV cache end counter to Python int to avoid per-step
 host synchronization in AR sampling loops.

---
 comfy/k_diffusion/sampling.py | 4 ++--
 comfy/ldm/wan/ar_model.py     | 8 ++++----
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py
index 646a6ae93..5bab263bd 100644
--- a/comfy/k_diffusion/sampling.py
+++ b/comfy/k_diffusion/sampling.py
@@ -1875,14 +1875,14 @@ def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=No
         noisy_input = (1.0 - sigma_next) * denoised + sigma_next * fresh_noise
 
         for cache in kv_caches:
-            cache["end"].fill_(cache["end"].item() - bf * frame_seq_len)
+            cache["end"] -= bf * frame_seq_len
 
         step_count += 1
 
         output[:, :, fs:fe] = noisy_input
 
     for cache in kv_caches:
-        cache["end"].fill_(cache["end"].item() - bf * frame_seq_len)
+        cache["end"] -= bf * frame_seq_len
 
     zero_sigma = sigmas.new_zeros([1])
     _ = model(noisy_input, zero_sigma * s_in, **extra_args)
diff --git a/comfy/ldm/wan/ar_model.py b/comfy/ldm/wan/ar_model.py
index 54a2ef704..d72f53602 100644
--- a/comfy/ldm/wan/ar_model.py
+++ b/comfy/ldm/wan/ar_model.py
@@ -64,13 +64,13 @@ class CausalWanSelfAttention(nn.Module):
                 transformer_options=transformer_options,
             )
         else:
-            end = kv_cache["end"].item()
+            end = kv_cache["end"]
             new_end = end + s
 
             # Roped K and plain V go into cache
             kv_cache["k"][:, end:new_end] = k
             kv_cache["v"][:, end:new_end] = v
-            kv_cache["end"].fill_(new_end)
+            kv_cache["end"] = new_end
 
             x = optimized_attention(
                 q.view(b, s, n * d),
@@ -232,7 +232,7 @@ class CausalWanModel(WanModel):
             caches.append({
                 "k": torch.zeros(batch_size, max_seq_len, self.num_heads, self.head_dim, device=device, dtype=dtype),
                 "v": torch.zeros(batch_size, max_seq_len, self.num_heads, self.head_dim, device=device, dtype=dtype),
-                "end": torch.tensor([0], dtype=torch.long, device=device),
+                "end": 0,
             })
         return caches
 
@@ -246,7 +246,7 @@ class CausalWanModel(WanModel):
     def reset_kv_caches(self, kv_caches):
         """Reset KV caches to empty (reuse allocated memory)."""
         for cache in kv_caches:
-            cache["end"].fill_(0)
+            cache["end"] = 0
 
     def reset_crossattn_caches(self, crossattn_caches):
         """Reset cross-attention caches."""