mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 06:40:48 +08:00
Small fixes
This commit is contained in:
parent
a3f78be5c2
commit
296b7c7b6d
@ -391,19 +391,22 @@ class Kandinsky5(nn.Module):
|
|||||||
if T > NABLA_THR:
|
if T > NABLA_THR:
|
||||||
assert self.patch_size[0] == 1
|
assert self.patch_size[0] == 1
|
||||||
|
|
||||||
|
# pro video model uses lower P at higher resolutions
|
||||||
|
P = 0.7 if self.model_dim == 4096 and H * W >= 14080 else 0.9
|
||||||
|
|
||||||
freqs = freqs.view(freqs.shape[0], *visual_shape[1:], *freqs.shape[2:])
|
freqs = freqs.view(freqs.shape[0], *visual_shape[1:], *freqs.shape[2:])
|
||||||
visual_embed, freqs = fractal_flatten(visual_embed, freqs, visual_shape[1:])
|
visual_embed, freqs = fractal_flatten(visual_embed, freqs, visual_shape[1:])
|
||||||
pt, ph, pw = self.patch_size
|
pt, ph, pw = self.patch_size
|
||||||
T, H, W = T // pt, H // ph, W // pw
|
T, H, W = T // pt, H // ph, W // pw
|
||||||
|
|
||||||
wT, wW, wH = 11, 11, 3
|
wT, wW, wH = 11, 3, 3
|
||||||
sta_mask = fast_sta_nabla(T, H // 8, W // 8, wT, wH, wW, device=x.device)
|
sta_mask = fast_sta_nabla(T, H // 8, W // 8, wT, wH, wW, device=x.device)
|
||||||
|
|
||||||
sparse_params = dict(
|
sparse_params = dict(
|
||||||
sta_mask=sta_mask.unsqueeze_(0).unsqueeze_(0),
|
sta_mask=sta_mask.unsqueeze_(0).unsqueeze_(0),
|
||||||
attention_type="nabla",
|
attention_type="nabla",
|
||||||
to_fractal=True,
|
to_fractal=True,
|
||||||
P=0.8,
|
P=P,
|
||||||
wT=wT, wW=wW, wH=wH,
|
wT=wT, wW=wW, wH=wH,
|
||||||
add_sta=True,
|
add_sta=True,
|
||||||
visual_shape=(T, H, W),
|
visual_shape=(T, H, W),
|
||||||
|
|||||||
@ -143,4 +143,4 @@ def nabla(query, key, value, sparse_params=None):
|
|||||||
.contiguous()
|
.contiguous()
|
||||||
)
|
)
|
||||||
out = out.flatten(-2, -1)
|
out = out.flatten(-2, -1)
|
||||||
return out
|
return out
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user