LTX2 context windows - Fix audio index mapping for wrapped/strided primary windows

The previous window-level calculation collapsed wrapped or strided primary windows into a contiguous audio tail, so audio attended to a different temporal region than the video. Replace with per-frame mapping that computes each primary index's audio span independently and concatenates in order.
This commit is contained in:
ozbayb 2026-04-12 14:52:54 -06:00
parent ae3830a6d2
commit f1f3182be1

View File

@ -1088,17 +1088,24 @@ class LTXAV(BaseModel):
return result
video_total = latent_shapes[0][dim]
video_window_len = len(primary_indices)
for i in range(1, len(latent_shapes)):
mod_total = latent_shapes[i][dim]
# Length proportional to video window frame count
mod_window_len = max(round(video_window_len * mod_total / video_total), 1)
# Anchor to end of video range
v_end = max(primary_indices) + 1
mod_end = min(round(v_end * mod_total / video_total), mod_total)
mod_start = max(mod_end - mod_window_len, 0)
result.append(list(range(mod_start, min(mod_start + mod_window_len, mod_total))))
# Map each primary index to its proportional range of modality indices and
# concatenate in order. Preserves wrapped/strided geometry so the modality
# attends to the same temporal regions as the primary window.
mod_indices = []
seen = set()
for v_idx in primary_indices:
a_start = min(int(round(v_idx * mod_total / video_total)), mod_total - 1)
a_end = min(int(round((v_idx + 1) * mod_total / video_total)), mod_total)
if a_end <= a_start:
a_end = a_start + 1
for a in range(a_start, a_end):
if a not in seen:
seen.add(a)
mod_indices.append(a)
result.append(mod_indices)
return result