7B-specific changes

This commit is contained in:
Yousef R. Gamaleldin 2026-01-23 01:24:15 +02:00
parent e3fa1aa415
commit ba59569e4b

View File

@ -721,6 +721,7 @@ class NaSwinAttention(NaMMAttention):
**kwargs,
):
super().__init__(*args, **kwargs)
self.version_7b = kwargs.get("version", False)
self.window = _triple(window)
self.window_method = window_method
assert all(map(lambda v: isinstance(v, int) and v >= 0, self.window))
@ -775,28 +776,32 @@ class NaSwinAttention(NaMMAttention):
)
# window rope
if self.rope:
if self.rope.mm:
# repeat text q and k for window mmrope
_, num_h, _ = txt_q.shape
txt_q_repeat = rearrange(txt_q, "l h d -> l (h d)")
txt_q_repeat = unflatten(txt_q_repeat, txt_shape)
txt_q_repeat = [[x] * n for x, n in zip(txt_q_repeat, window_count)]
txt_q_repeat = list(chain(*txt_q_repeat))
txt_q_repeat, txt_shape_repeat = flatten(txt_q_repeat)
txt_q_repeat = rearrange(txt_q_repeat, "l (h d) -> l h d", h=num_h)
if not self.version_7b:
if self.rope:
if self.rope.mm:
# repeat text q and k for window mmrope
_, num_h, _ = txt_q.shape
txt_q_repeat = rearrange(txt_q, "l h d -> l (h d)")
txt_q_repeat = unflatten(txt_q_repeat, txt_shape)
txt_q_repeat = [[x] * n for x, n in zip(txt_q_repeat, window_count)]
txt_q_repeat = list(chain(*txt_q_repeat))
txt_q_repeat, txt_shape_repeat = flatten(txt_q_repeat)
txt_q_repeat = rearrange(txt_q_repeat, "l (h d) -> l h d", h=num_h)
txt_k_repeat = rearrange(txt_k, "l h d -> l (h d)")
txt_k_repeat = unflatten(txt_k_repeat, txt_shape)
txt_k_repeat = [[x] * n for x, n in zip(txt_k_repeat, window_count)]
txt_k_repeat = list(chain(*txt_k_repeat))
txt_k_repeat, _ = flatten(txt_k_repeat)
txt_k_repeat = rearrange(txt_k_repeat, "l (h d) -> l h d", h=num_h)
txt_k_repeat = rearrange(txt_k, "l h d -> l (h d)")
txt_k_repeat = unflatten(txt_k_repeat, txt_shape)
txt_k_repeat = [[x] * n for x, n in zip(txt_k_repeat, window_count)]
txt_k_repeat = list(chain(*txt_k_repeat))
txt_k_repeat, _ = flatten(txt_k_repeat)
txt_k_repeat = rearrange(txt_k_repeat, "l (h d) -> l h d", h=num_h)
vid_q, vid_k, txt_q, txt_k = self.rope(
vid_q, vid_k, window_shape, txt_q_repeat, txt_k_repeat, txt_shape_repeat, cache_win
)
else:
vid_q, vid_k, txt_q, txt_k = self.rope(
vid_q, vid_k, window_shape, txt_q_repeat, txt_k_repeat, txt_shape_repeat, cache_win
)
else:
vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win)
else:
if self.rope:
vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win)
out = optimized_attention(
@ -899,6 +904,7 @@ class NaMMSRTransformerBlock(nn.Module):
**kwargs,
):
super().__init__()
version = kwargs.get("version", False)
dim = MMArg(vid_dim, txt_dim)
self.attn_norm = MMModule(norm, normalized_shape=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, device=device, dtype=dtype)
@ -915,6 +921,7 @@ class NaMMSRTransformerBlock(nn.Module):
shared_weights=shared_weights,
window=kwargs.pop("window", None),
window_method=kwargs.pop("window_method", None),
version=version,
device=device, dtype=dtype, operations=operations
)
@ -929,6 +936,7 @@ class NaMMSRTransformerBlock(nn.Module):
)
self.ada = MMModule(ada, dim=dim, emb_dim=emb_dim, layers=["attn", "mlp"], shared_weights=shared_weights, vid_only=is_last_layer, device=device, dtype=dtype)
self.is_last_layer = is_last_layer
self.version = version
def forward(
self,
@ -1260,6 +1268,7 @@ class NaDiT(nn.Module):
operations = None,
**kwargs,
):
self._7b_version = vid_dim == 3072
self.dtype = dtype
factory_kwargs = {"device": device, "dtype": dtype}
window_method = num_layers // 2 * ["720pwin_by_size_bysize","720pswin_by_size_bysize"]
@ -1333,6 +1342,7 @@ class NaDiT(nn.Module):
shared_weights=not (
(i < mm_layers) if isinstance(mm_layers, int) else mm_layers[i]
),
version = self._7b_version,
operations = operations,
**kwargs,
**factory_kwargs