Base frame_seq_len on the padded token grid.

This commit is contained in:
Talmaj Marinc 2026-03-25 22:05:12 +01:00
parent 08bf8f4d95
commit e9cf4659d2

View File

@ -1827,8 +1827,8 @@ def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=No
seed = extra_args.get("seed", 0)
bs, c, lat_t, lat_h, lat_w = x.shape
frame_seq_len = (lat_h // 2) * (lat_w // 2)
num_blocks = -(-lat_t // num_frame_per_block) # ceiling division
frame_seq_len = -(-lat_h // 2) * -(-lat_w // 2) # ceiling division
num_blocks = -(-lat_t // num_frame_per_block) # ceiling division
inner_model = model.inner_model.inner_model
causal_model = inner_model.diffusion_model