mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-30 13:33:42 +08:00
Move KV cache end counter to Python int to avoid per-step host synchronization in AR sampling loops.
This commit is contained in:
parent
3440c57f67
commit
08bf8f4d95
@ -1875,14 +1875,14 @@ def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=No
|
||||
noisy_input = (1.0 - sigma_next) * denoised + sigma_next * fresh_noise
|
||||
|
||||
for cache in kv_caches:
|
||||
cache["end"].fill_(cache["end"].item() - bf * frame_seq_len)
|
||||
cache["end"] -= bf * frame_seq_len
|
||||
|
||||
step_count += 1
|
||||
|
||||
output[:, :, fs:fe] = noisy_input
|
||||
|
||||
for cache in kv_caches:
|
||||
cache["end"].fill_(cache["end"].item() - bf * frame_seq_len)
|
||||
cache["end"] -= bf * frame_seq_len
|
||||
zero_sigma = sigmas.new_zeros([1])
|
||||
_ = model(noisy_input, zero_sigma * s_in, **extra_args)
|
||||
|
||||
|
||||
@ -64,13 +64,13 @@ class CausalWanSelfAttention(nn.Module):
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
else:
|
||||
end = kv_cache["end"].item()
|
||||
end = kv_cache["end"]
|
||||
new_end = end + s
|
||||
|
||||
# Roped K and plain V go into cache
|
||||
kv_cache["k"][:, end:new_end] = k
|
||||
kv_cache["v"][:, end:new_end] = v
|
||||
kv_cache["end"].fill_(new_end)
|
||||
kv_cache["end"] = new_end
|
||||
|
||||
x = optimized_attention(
|
||||
q.view(b, s, n * d),
|
||||
@ -232,7 +232,7 @@ class CausalWanModel(WanModel):
|
||||
caches.append({
|
||||
"k": torch.zeros(batch_size, max_seq_len, self.num_heads, self.head_dim, device=device, dtype=dtype),
|
||||
"v": torch.zeros(batch_size, max_seq_len, self.num_heads, self.head_dim, device=device, dtype=dtype),
|
||||
"end": torch.tensor([0], dtype=torch.long, device=device),
|
||||
"end": 0,
|
||||
})
|
||||
return caches
|
||||
|
||||
@ -246,7 +246,7 @@ class CausalWanModel(WanModel):
|
||||
def reset_kv_caches(self, kv_caches):
|
||||
"""Reset KV caches to empty (reuse allocated memory)."""
|
||||
for cache in kv_caches:
|
||||
cache["end"].fill_(0)
|
||||
cache["end"] = 0
|
||||
|
||||
def reset_crossattn_caches(self, crossattn_caches):
|
||||
"""Reset cross-attention caches."""
|
||||
|
||||
Loading…
Reference in New Issue
Block a user