mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-30 13:33:42 +08:00
Fix 'Process the tail block instead of truncating it', fix 'Don't mutate the patcher's shared transformer_options in place'.
This commit is contained in:
parent
4b2734889c
commit
de66e64ec2
@ -1828,7 +1828,7 @@ def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=No
|
|||||||
|
|
||||||
bs, c, lat_t, lat_h, lat_w = x.shape
|
bs, c, lat_t, lat_h, lat_w = x.shape
|
||||||
frame_seq_len = (lat_h // 2) * (lat_w // 2)
|
frame_seq_len = (lat_h // 2) * (lat_w // 2)
|
||||||
num_blocks = lat_t // num_frame_per_block
|
num_blocks = -(-lat_t // num_frame_per_block) # ceiling division
|
||||||
|
|
||||||
inner_model = model.inner_model.inner_model
|
inner_model = model.inner_model.inner_model
|
||||||
causal_model = inner_model.diffusion_model
|
causal_model = inner_model.diffusion_model
|
||||||
@ -1845,49 +1845,49 @@ def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=No
|
|||||||
total_real_steps = num_blocks * num_sigma_steps
|
total_real_steps = num_blocks * num_sigma_steps
|
||||||
step_count = 0
|
step_count = 0
|
||||||
|
|
||||||
for block_idx in trange(num_blocks, disable=disable):
|
try:
|
||||||
bf = num_frame_per_block
|
for block_idx in trange(num_blocks, disable=disable):
|
||||||
fs, fe = current_start_frame, current_start_frame + bf
|
bf = min(num_frame_per_block, lat_t - current_start_frame)
|
||||||
noisy_input = x[:, :, fs:fe]
|
fs, fe = current_start_frame, current_start_frame + bf
|
||||||
|
noisy_input = x[:, :, fs:fe]
|
||||||
|
|
||||||
ar_state = {
|
ar_state = {
|
||||||
"start_frame": current_start_frame,
|
"start_frame": current_start_frame,
|
||||||
"kv_caches": kv_caches,
|
"kv_caches": kv_caches,
|
||||||
"crossattn_caches": crossattn_caches,
|
"crossattn_caches": crossattn_caches,
|
||||||
}
|
}
|
||||||
transformer_options["ar_state"] = ar_state
|
transformer_options["ar_state"] = ar_state
|
||||||
|
|
||||||
for i in range(num_sigma_steps):
|
for i in range(num_sigma_steps):
|
||||||
denoised = model(noisy_input, sigmas[i] * s_in, **extra_args)
|
denoised = model(noisy_input, sigmas[i] * s_in, **extra_args)
|
||||||
|
|
||||||
if callback is not None:
|
if callback is not None:
|
||||||
# Scale step_count to [0, num_sigma_steps) so the progress bar fills gradually
|
scaled_i = step_count * num_sigma_steps // total_real_steps
|
||||||
scaled_i = step_count * num_sigma_steps // total_real_steps
|
callback({"x": noisy_input, "i": scaled_i, "sigma": sigmas[i],
|
||||||
callback({"x": noisy_input, "i": scaled_i, "sigma": sigmas[i],
|
"sigma_hat": sigmas[i], "denoised": denoised})
|
||||||
"sigma_hat": sigmas[i], "denoised": denoised})
|
|
||||||
|
|
||||||
if sigmas[i + 1] == 0:
|
if sigmas[i + 1] == 0:
|
||||||
noisy_input = denoised
|
noisy_input = denoised
|
||||||
else:
|
else:
|
||||||
sigma_next = sigmas[i + 1]
|
sigma_next = sigmas[i + 1]
|
||||||
torch.manual_seed(seed + block_idx * 1000 + i)
|
torch.manual_seed(seed + block_idx * 1000 + i)
|
||||||
fresh_noise = torch.randn_like(denoised)
|
fresh_noise = torch.randn_like(denoised)
|
||||||
noisy_input = (1.0 - sigma_next) * denoised + sigma_next * fresh_noise
|
noisy_input = (1.0 - sigma_next) * denoised + sigma_next * fresh_noise
|
||||||
|
|
||||||
for cache in kv_caches:
|
for cache in kv_caches:
|
||||||
cache["end"].fill_(cache["end"].item() - bf * frame_seq_len)
|
cache["end"].fill_(cache["end"].item() - bf * frame_seq_len)
|
||||||
|
|
||||||
step_count += 1
|
step_count += 1
|
||||||
|
|
||||||
output[:, :, fs:fe] = noisy_input
|
output[:, :, fs:fe] = noisy_input
|
||||||
|
|
||||||
# Cache update: run model at t=0 with clean output to fill KV cache
|
for cache in kv_caches:
|
||||||
for cache in kv_caches:
|
cache["end"].fill_(cache["end"].item() - bf * frame_seq_len)
|
||||||
cache["end"].fill_(cache["end"].item() - bf * frame_seq_len)
|
zero_sigma = sigmas.new_zeros([1])
|
||||||
zero_sigma = sigmas.new_zeros([1])
|
_ = model(noisy_input, zero_sigma * s_in, **extra_args)
|
||||||
_ = model(noisy_input, zero_sigma * s_in, **extra_args)
|
|
||||||
|
|
||||||
current_start_frame += bf
|
current_start_frame += bf
|
||||||
|
finally:
|
||||||
|
transformer_options.pop("ar_state", None)
|
||||||
|
|
||||||
transformer_options.pop("ar_state", None)
|
|
||||||
return output
|
return output
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user