Fix two issues: process the tail block instead of truncating it, and don't mutate the patcher's shared transformer_options in place.

This commit is contained in:
Talmaj Marinc 2026-03-25 21:13:23 +01:00
parent 4b2734889c
commit de66e64ec2

View File

@ -1828,7 +1828,7 @@ def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=No
bs, c, lat_t, lat_h, lat_w = x.shape bs, c, lat_t, lat_h, lat_w = x.shape
frame_seq_len = (lat_h // 2) * (lat_w // 2) frame_seq_len = (lat_h // 2) * (lat_w // 2)
num_blocks = lat_t // num_frame_per_block num_blocks = -(-lat_t // num_frame_per_block) # ceiling division
inner_model = model.inner_model.inner_model inner_model = model.inner_model.inner_model
causal_model = inner_model.diffusion_model causal_model = inner_model.diffusion_model
@ -1845,8 +1845,9 @@ def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=No
total_real_steps = num_blocks * num_sigma_steps total_real_steps = num_blocks * num_sigma_steps
step_count = 0 step_count = 0
try:
for block_idx in trange(num_blocks, disable=disable): for block_idx in trange(num_blocks, disable=disable):
bf = num_frame_per_block bf = min(num_frame_per_block, lat_t - current_start_frame)
fs, fe = current_start_frame, current_start_frame + bf fs, fe = current_start_frame, current_start_frame + bf
noisy_input = x[:, :, fs:fe] noisy_input = x[:, :, fs:fe]
@ -1861,7 +1862,6 @@ def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=No
denoised = model(noisy_input, sigmas[i] * s_in, **extra_args) denoised = model(noisy_input, sigmas[i] * s_in, **extra_args)
if callback is not None: if callback is not None:
# Scale step_count to [0, num_sigma_steps) so the progress bar fills gradually
scaled_i = step_count * num_sigma_steps // total_real_steps scaled_i = step_count * num_sigma_steps // total_real_steps
callback({"x": noisy_input, "i": scaled_i, "sigma": sigmas[i], callback({"x": noisy_input, "i": scaled_i, "sigma": sigmas[i],
"sigma_hat": sigmas[i], "denoised": denoised}) "sigma_hat": sigmas[i], "denoised": denoised})
@ -1881,13 +1881,13 @@ def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=No
output[:, :, fs:fe] = noisy_input output[:, :, fs:fe] = noisy_input
# Cache update: run model at t=0 with clean output to fill KV cache
for cache in kv_caches: for cache in kv_caches:
cache["end"].fill_(cache["end"].item() - bf * frame_seq_len) cache["end"].fill_(cache["end"].item() - bf * frame_seq_len)
zero_sigma = sigmas.new_zeros([1]) zero_sigma = sigmas.new_zeros([1])
_ = model(noisy_input, zero_sigma * s_in, **extra_args) _ = model(noisy_input, zero_sigma * s_in, **extra_args)
current_start_frame += bf current_start_frame += bf
finally:
transformer_options.pop("ar_state", None) transformer_options.pop("ar_state", None)
return output return output