diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 870bff369..646a6ae93 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1828,7 +1828,7 @@ def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=No bs, c, lat_t, lat_h, lat_w = x.shape frame_seq_len = (lat_h // 2) * (lat_w // 2) - num_blocks = lat_t // num_frame_per_block + num_blocks = -(-lat_t // num_frame_per_block) # ceiling division inner_model = model.inner_model.inner_model causal_model = inner_model.diffusion_model @@ -1845,49 +1845,49 @@ def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=No total_real_steps = num_blocks * num_sigma_steps step_count = 0 - for block_idx in trange(num_blocks, disable=disable): - bf = num_frame_per_block - fs, fe = current_start_frame, current_start_frame + bf - noisy_input = x[:, :, fs:fe] + try: + for block_idx in trange(num_blocks, disable=disable): + bf = min(num_frame_per_block, lat_t - current_start_frame) + fs, fe = current_start_frame, current_start_frame + bf + noisy_input = x[:, :, fs:fe] - ar_state = { - "start_frame": current_start_frame, - "kv_caches": kv_caches, - "crossattn_caches": crossattn_caches, - } - transformer_options["ar_state"] = ar_state + ar_state = { + "start_frame": current_start_frame, + "kv_caches": kv_caches, + "crossattn_caches": crossattn_caches, + } + transformer_options["ar_state"] = ar_state - for i in range(num_sigma_steps): - denoised = model(noisy_input, sigmas[i] * s_in, **extra_args) + for i in range(num_sigma_steps): + denoised = model(noisy_input, sigmas[i] * s_in, **extra_args) - if callback is not None: - # Scale step_count to [0, num_sigma_steps) so the progress bar fills gradually - scaled_i = step_count * num_sigma_steps // total_real_steps - callback({"x": noisy_input, "i": scaled_i, "sigma": sigmas[i], - "sigma_hat": sigmas[i], "denoised": denoised}) + if callback is not None: + scaled_i = step_count * num_sigma_steps // total_real_steps + callback({"x": noisy_input, "i": scaled_i, "sigma": sigmas[i], + "sigma_hat": sigmas[i], "denoised": denoised}) - if sigmas[i + 1] == 0: - noisy_input = denoised - else: - sigma_next = sigmas[i + 1] - torch.manual_seed(seed + block_idx * 1000 + i) - fresh_noise = torch.randn_like(denoised) - noisy_input = (1.0 - sigma_next) * denoised + sigma_next * fresh_noise + if sigmas[i + 1] == 0: + noisy_input = denoised + else: + sigma_next = sigmas[i + 1] + torch.manual_seed(seed + block_idx * 1000 + i) + fresh_noise = torch.randn_like(denoised) + noisy_input = (1.0 - sigma_next) * denoised + sigma_next * fresh_noise - for cache in kv_caches: - cache["end"].fill_(cache["end"].item() - bf * frame_seq_len) + for cache in kv_caches: + cache["end"].fill_(cache["end"].item() - bf * frame_seq_len) - step_count += 1 + step_count += 1 - output[:, :, fs:fe] = noisy_input + output[:, :, fs:fe] = noisy_input - # Cache update: run model at t=0 with clean output to fill KV cache - for cache in kv_caches: - cache["end"].fill_(cache["end"].item() - bf * frame_seq_len) - zero_sigma = sigmas.new_zeros([1]) - _ = model(noisy_input, zero_sigma * s_in, **extra_args) + for cache in kv_caches: + cache["end"].fill_(cache["end"].item() - bf * frame_seq_len) + zero_sigma = sigmas.new_zeros([1]) + _ = model(noisy_input, zero_sigma * s_in, **extra_args) - current_start_frame += bf + current_start_frame += bf + finally: + transformer_options.pop("ar_state", None) - transformer_options.pop("ar_state", None) return output