diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py index bfbc08357..d3ba8ad2e 100644 --- a/comfy/ldm/lightricks/model.py +++ b/comfy/ldm/lightricks/model.py @@ -1252,24 +1252,21 @@ class LTXVModel(LTXBaseModel): dtype: Target dtype. Returns: - (1, 1, total_tokens, total_tokens) additive bias mask. + (1, 1, 1, total_tokens) additive bias mask, broadcast over every query + row inside attention, dropping the persistent allocation from O(T²) to O(T). NOTE(review): unlike the former (T, T) mask, guide-query rows now also receive this key-side bias and lose their per-guide bias toward noisy columns — confirm this asymmetry change is intended. 0.0 = full attention, negative = attenuated, finfo.min = effectively fully masked. """ finfo = torch.finfo(dtype) - mask = torch.zeros((1, 1, total_tokens, total_tokens), device=device, dtype=dtype) + mask = torch.zeros((1, 1, 1, total_tokens), device=device, dtype=dtype) tracked_end = guide_start + tracked_count - # Convert weights to log-space bias w = tracked_weights.to(device=device, dtype=dtype) # (1, tracked_count) log_w = torch.full_like(w, finfo.min) positive_mask = w > 0 if positive_mask.any(): log_w[positive_mask] = torch.log(w[positive_mask].clamp(min=finfo.tiny)) - # noisy → tracked guides: each noisy row gets the same per-guide weight - mask[:, :, :guide_start, guide_start:tracked_end] = log_w.view(1, 1, 1, -1) - # tracked guides → noisy: each guide row broadcasts its weight across noisy cols - mask[:, :, guide_start:tracked_end, :guide_start] = log_w.view(1, 1, -1, 1) + mask[:, :, :, guide_start:tracked_end] = log_w.view(1, 1, 1, -1) return mask