Move KV cache end counter to Python int to avoid per-step host synchronization in AR sampling loops.

This commit is contained in:
Talmaj Marinc 2026-03-25 21:59:22 +01:00
parent 3440c57f67
commit 08bf8f4d95
2 changed files with 6 additions and 6 deletions

View File

@ -1875,14 +1875,14 @@ def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=No
noisy_input = (1.0 - sigma_next) * denoised + sigma_next * fresh_noise
for cache in kv_caches:
cache["end"].fill_(cache["end"].item() - bf * frame_seq_len)
cache["end"] -= bf * frame_seq_len
step_count += 1
output[:, :, fs:fe] = noisy_input
for cache in kv_caches:
cache["end"].fill_(cache["end"].item() - bf * frame_seq_len)
cache["end"] -= bf * frame_seq_len
zero_sigma = sigmas.new_zeros([1])
_ = model(noisy_input, zero_sigma * s_in, **extra_args)

View File

@ -64,13 +64,13 @@ class CausalWanSelfAttention(nn.Module):
transformer_options=transformer_options,
)
else:
end = kv_cache["end"].item()
end = kv_cache["end"]
new_end = end + s
# Roped K and plain V go into cache
kv_cache["k"][:, end:new_end] = k
kv_cache["v"][:, end:new_end] = v
kv_cache["end"].fill_(new_end)
kv_cache["end"] = new_end
x = optimized_attention(
q.view(b, s, n * d),
@ -232,7 +232,7 @@ class CausalWanModel(WanModel):
caches.append({
"k": torch.zeros(batch_size, max_seq_len, self.num_heads, self.head_dim, device=device, dtype=dtype),
"v": torch.zeros(batch_size, max_seq_len, self.num_heads, self.head_dim, device=device, dtype=dtype),
"end": torch.tensor([0], dtype=torch.long, device=device),
"end": 0,
})
return caches
@ -246,7 +246,7 @@ class CausalWanModel(WanModel):
def reset_kv_caches(self, kv_caches):
"""Reset KV caches to empty (reuse allocated memory)."""
for cache in kv_caches:
cache["end"].fill_(0)
cache["end"] = 0
def reset_crossattn_caches(self, crossattn_caches):
"""Reset cross-attention caches."""