diff --git a/comfy/model_base.py b/comfy/model_base.py index 08f8b058d..18b3f5c0e 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1088,17 +1088,24 @@ class LTXAV(BaseModel): return result video_total = latent_shapes[0][dim] - video_window_len = len(primary_indices) for i in range(1, len(latent_shapes)): mod_total = latent_shapes[i][dim] - # Length proportional to video window frame count - mod_window_len = max(round(video_window_len * mod_total / video_total), 1) - # Anchor to end of video range - v_end = max(primary_indices) + 1 - mod_end = min(round(v_end * mod_total / video_total), mod_total) - mod_start = max(mod_end - mod_window_len, 0) - result.append(list(range(mod_start, min(mod_start + mod_window_len, mod_total)))) + # Map each primary index to its proportional range of modality indices and + # concatenate in order. Preserves wrapped/strided geometry so the modality + # attends to the same temporal regions as the primary window. + mod_indices = [] + seen = set() + for v_idx in primary_indices: + a_start = min(int(round(v_idx * mod_total / video_total)), mod_total - 1) + a_end = min(int(round((v_idx + 1) * mod_total / video_total)), mod_total) + if a_end <= a_start: + a_end = a_start + 1 + for a in range(a_start, a_end): + if a not in seen: + seen.add(a) + mod_indices.append(a) + result.append(mod_indices) return result