Alternative self_attn_mask

Drastically lower memory use, different effect, for testing
kijai 2026-05-06 16:34:01 +03:00
parent 9c34f5f36a
commit e6e3e6f628

@@ -1252,24 +1252,21 @@ class LTXVModel(LTXBaseModel):
             dtype: Target dtype.
         Returns:
-            (1, 1, total_tokens, total_tokens) additive bias mask.
+            (1, 1, 1, total_tokens) additive bias mask. Broadcasts across queries
+            inside attention, dropping the persistent allocation from O(T^2) to O(T).
             0.0 = full attention, negative = attenuated, finfo.min = effectively fully masked.
         """
         finfo = torch.finfo(dtype)
-        mask = torch.zeros((1, 1, total_tokens, total_tokens), device=device, dtype=dtype)
+        mask = torch.zeros((1, 1, 1, total_tokens), device=device, dtype=dtype)
         tracked_end = guide_start + tracked_count
         # Convert weights to log-space bias
         w = tracked_weights.to(device=device, dtype=dtype) # (1, tracked_count)
         log_w = torch.full_like(w, finfo.min)
         positive_mask = w > 0
         if positive_mask.any():
             log_w[positive_mask] = torch.log(w[positive_mask].clamp(min=finfo.tiny))
-        # noisy → tracked guides: each noisy row gets the same per-guide weight
-        mask[:, :, :guide_start, guide_start:tracked_end] = log_w.view(1, 1, 1, -1)
-        # tracked guides → noisy: each guide row broadcasts its weight across noisy cols
-        mask[:, :, guide_start:tracked_end, :guide_start] = log_w.view(1, 1, -1, 1)
+        mask[:, :, :, guide_start:tracked_end] = log_w.view(1, 1, 1, -1)
         return mask
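
For context, here is a minimal, self-contained sketch of the difference between the two mask layouts. The token counts, weights, and the use of torch.nn.functional.scaled_dot_product_attention as the consumer of the additive mask are assumptions for illustration only, not taken from the model code. The point it demonstrates: the new (1, 1, 1, total_tokens) mask broadcasts across query rows, so it stores T elements instead of T^2, but every query row (guide tokens included) now receives the same per-guide key bias, which is the "different effect" the commit message mentions.

import torch
import torch.nn.functional as F

# Hypothetical sizes and weights, chosen only to illustrate the shapes.
total_tokens, guide_start, tracked_count = 8, 5, 2
dtype, device = torch.float32, "cpu"
tracked_weights = torch.tensor([[0.5, 1.0]])  # (1, tracked_count), positive weights
tracked_end = guide_start + tracked_count
finfo = torch.finfo(dtype)
log_w = torch.log(tracked_weights.clamp(min=finfo.tiny)).to(device=device, dtype=dtype)

# Old layout: full (1, 1, T, T) pairwise bias, O(T^2) elements.
full_mask = torch.zeros((1, 1, total_tokens, total_tokens), device=device, dtype=dtype)
full_mask[:, :, :guide_start, guide_start:tracked_end] = log_w.view(1, 1, 1, -1)
full_mask[:, :, guide_start:tracked_end, :guide_start] = log_w.view(1, 1, -1, 1)

# New layout: per-key bias (1, 1, 1, T), O(T) elements; broadcasts over query rows.
key_mask = torch.zeros((1, 1, 1, total_tokens), device=device, dtype=dtype)
key_mask[:, :, :, guide_start:tracked_end] = log_w.view(1, 1, 1, -1)

print(full_mask.numel(), key_mask.numel())  # 64 vs 8; at realistic T the gap is T^2 vs T

# Both shapes are valid additive attn_mask inputs to SDPA because they broadcast
# to (batch, heads, query_len, key_len).
q = k = v = torch.randn(1, 1, total_tokens, 16)
out_full = F.scaled_dot_product_attention(q, k, v, attn_mask=full_mask)
out_bcast = F.scaled_dot_product_attention(q, k, v, attn_mask=key_mask)

# Noisy query rows (before guide_start) see the same bias in both layouts...
print(torch.allclose(out_full[:, :, :guide_start], out_bcast[:, :, :guide_start]))  # True
# ...but guide rows do not: the old guide → noisy bias is gone, and guides now also
# attend to the tracked guides with the per-guide weight.
print(torch.allclose(out_full[:, :, guide_start:], out_bcast[:, :, guide_start:]))  # False in general

This is presumably the trade the commit message describes: the persistent mask shrinks from T^2 to T elements (e.g. for T = 20,000 tokens in fp16, roughly 800 MB down to 40 KB), at the cost of applying the guide weighting uniformly to all query rows and dropping the separate guide → noisy direction of the old mask.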