mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-05 08:16:47 +08:00
Move KV cache end counter to Python int to avoid per-step host synchronization in AR sampling loops.
This commit is contained in:
parent
3440c57f67
commit
08bf8f4d95
@ -1875,14 +1875,14 @@ def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=No
|
|||||||
noisy_input = (1.0 - sigma_next) * denoised + sigma_next * fresh_noise
|
noisy_input = (1.0 - sigma_next) * denoised + sigma_next * fresh_noise
|
||||||
|
|
||||||
for cache in kv_caches:
|
for cache in kv_caches:
|
||||||
cache["end"].fill_(cache["end"].item() - bf * frame_seq_len)
|
cache["end"] -= bf * frame_seq_len
|
||||||
|
|
||||||
step_count += 1
|
step_count += 1
|
||||||
|
|
||||||
output[:, :, fs:fe] = noisy_input
|
output[:, :, fs:fe] = noisy_input
|
||||||
|
|
||||||
for cache in kv_caches:
|
for cache in kv_caches:
|
||||||
cache["end"].fill_(cache["end"].item() - bf * frame_seq_len)
|
cache["end"] -= bf * frame_seq_len
|
||||||
zero_sigma = sigmas.new_zeros([1])
|
zero_sigma = sigmas.new_zeros([1])
|
||||||
_ = model(noisy_input, zero_sigma * s_in, **extra_args)
|
_ = model(noisy_input, zero_sigma * s_in, **extra_args)
|
||||||
|
|
||||||
|
|||||||
@ -64,13 +64,13 @@ class CausalWanSelfAttention(nn.Module):
|
|||||||
transformer_options=transformer_options,
|
transformer_options=transformer_options,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
end = kv_cache["end"].item()
|
end = kv_cache["end"]
|
||||||
new_end = end + s
|
new_end = end + s
|
||||||
|
|
||||||
# Roped K and plain V go into cache
|
# Roped K and plain V go into cache
|
||||||
kv_cache["k"][:, end:new_end] = k
|
kv_cache["k"][:, end:new_end] = k
|
||||||
kv_cache["v"][:, end:new_end] = v
|
kv_cache["v"][:, end:new_end] = v
|
||||||
kv_cache["end"].fill_(new_end)
|
kv_cache["end"] = new_end
|
||||||
|
|
||||||
x = optimized_attention(
|
x = optimized_attention(
|
||||||
q.view(b, s, n * d),
|
q.view(b, s, n * d),
|
||||||
@ -232,7 +232,7 @@ class CausalWanModel(WanModel):
|
|||||||
caches.append({
|
caches.append({
|
||||||
"k": torch.zeros(batch_size, max_seq_len, self.num_heads, self.head_dim, device=device, dtype=dtype),
|
"k": torch.zeros(batch_size, max_seq_len, self.num_heads, self.head_dim, device=device, dtype=dtype),
|
||||||
"v": torch.zeros(batch_size, max_seq_len, self.num_heads, self.head_dim, device=device, dtype=dtype),
|
"v": torch.zeros(batch_size, max_seq_len, self.num_heads, self.head_dim, device=device, dtype=dtype),
|
||||||
"end": torch.tensor([0], dtype=torch.long, device=device),
|
"end": 0,
|
||||||
})
|
})
|
||||||
return caches
|
return caches
|
||||||
|
|
||||||
@ -246,7 +246,7 @@ class CausalWanModel(WanModel):
|
|||||||
def reset_kv_caches(self, kv_caches):
|
def reset_kv_caches(self, kv_caches):
|
||||||
"""Reset KV caches to empty (reuse allocated memory)."""
|
"""Reset KV caches to empty (reuse allocated memory)."""
|
||||||
for cache in kv_caches:
|
for cache in kv_caches:
|
||||||
cache["end"].fill_(0)
|
cache["end"] = 0
|
||||||
|
|
||||||
def reset_crossattn_caches(self, crossattn_caches):
|
def reset_crossattn_caches(self, crossattn_caches):
|
||||||
"""Reset cross-attention caches."""
|
"""Reset cross-attention caches."""
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user