Small fixes

This commit is contained in:
Mihail Karaev 2025-12-17 11:40:14 +00:00
parent a3f78be5c2
commit 296b7c7b6d
2 changed files with 6 additions and 3 deletions

View File

@@ -391,19 +391,22 @@ class Kandinsky5(nn.Module):
if T > NABLA_THR:
assert self.patch_size[0] == 1
# pro video model uses lower P at higher resolutions
P = 0.7 if self.model_dim == 4096 and H * W >= 14080 else 0.9
freqs = freqs.view(freqs.shape[0], *visual_shape[1:], *freqs.shape[2:])
visual_embed, freqs = fractal_flatten(visual_embed, freqs, visual_shape[1:])
pt, ph, pw = self.patch_size
T, H, W = T // pt, H // ph, W // pw
-wT, wW, wH = 11, 11, 3
+wT, wW, wH = 11, 3, 3
sta_mask = fast_sta_nabla(T, H // 8, W // 8, wT, wH, wW, device=x.device)
sparse_params = dict(
sta_mask=sta_mask.unsqueeze_(0).unsqueeze_(0),
attention_type="nabla",
to_fractal=True,
-P=0.8,
+P=P,
wT=wT, wW=wW, wH=wH,
add_sta=True,
visual_shape=(T, H, W),

View File

@@ -143,4 +143,4 @@ def nabla(query, key, value, sparse_params=None):
.contiguous()
)
out = out.flatten(-2, -1)
-return out
+return out