testing the model

Yousef Rafat 2025-12-07 23:43:49 +02:00
parent 4b9332cc21
commit 44a5bf353a
2 changed files with 119 additions and 61 deletions

View File

@@ -428,7 +428,7 @@ else:
 SDP_BATCH_LIMIT = 2**31
 
-def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=Falsez):
+def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
     if skip_reshape:
         b, _, _, dim_head = q.shape
     else:

View File

@@ -145,56 +145,77 @@ def get_kwargs(key: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
     return {k: getattr(v, key) if isinstance(v, MMArg) else v for k, v in kwargs.items()}
 
-def make_720Pwindows(size, num_windows, shift = False):
-    t, h, w = size
-    resized_nt, resized_nh, resized_nw = num_windows
-    # calc windows under 720p
-    scale = sqrt((45 * 80) / (h * w))
-    resized_h, resized_w = round(h * scale), round(w * scale)
-    wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw)
-    wt = ceil(min(t, 30) / resized_nt)
-    st, sh, sw = (0.5 * shift if wt < t else 0,
-                  0.5 * shift if wh < h else 0,
-                  0.5 * shift if ww < w else 0)
-    nt, nh, nw = ceil((t - st) / wt), ceil((h - sh) / wh), ceil((w - sw) / ww)
-    if shift:
-        nt += 1 if st > 0 else 0
-        nh += 1 if sh > 0 else 0
-        nw += 1 if sw > 0 else 0
-
-    windows = []
-    for iw in range(nw):
-        w_start = max(int((iw - sw) * ww), 0)
-        w_end = min(int((iw - sw + 1) * ww), w)
-        if w_end <= w_start:
-            continue
-
-        for ih in range(nh):
-            h_start = max(int((ih - sh) * wh), 0)
-            h_end = min(int((ih - sh + 1) * wh), h)
-            if h_end <= h_start:
-                continue
-
-            for it in range(nt):
-                t_start = max(int((it - st) * wt), 0)
-                t_end = min(int((it - st + 1) * wt), t)
-                if t_end <= t_start:
-                    continue
-                windows.append((slice(t_start, t_end),
-                                slice(h_start, h_end),
-                                slice(w_start, w_end)))
-    return windows
+def get_window_op(name: str):
+    if name == "720pwin_by_size_bysize":
+        return make_720Pwindows_bysize
+    if name == "720pswin_by_size_bysize":
+        return make_shifted_720Pwindows_bysize
+    raise ValueError(f"Unknown windowing method: {name}")
+
+# -------------------------------- Windowing -------------------------------- #
+def make_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int, int, int]):
+    t, h, w = size
+    resized_nt, resized_nh, resized_nw = num_windows
+    # calc windows under 720p
+    scale = math.sqrt((45 * 80) / (h * w))
+    resized_h, resized_w = round(h * scale), round(w * scale)
+    wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw)  # window size.
+    wt = ceil(min(t, 30) / resized_nt)  # window size.
+    nt, nh, nw = ceil(t / wt), ceil(h / wh), ceil(w / ww)  # number of windows.
+    return [
+        (
+            slice(it * wt, min((it + 1) * wt, t)),
+            slice(ih * wh, min((ih + 1) * wh, h)),
+            slice(iw * ww, min((iw + 1) * ww, w)),
+        )
+        for iw in range(nw)
+        if min((iw + 1) * ww, w) > iw * ww
+        for ih in range(nh)
+        if min((ih + 1) * wh, h) > ih * wh
+        for it in range(nt)
+        if min((it + 1) * wt, t) > it * wt
+    ]
+
+def make_shifted_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int, int, int]):
+    t, h, w = size
+    resized_nt, resized_nh, resized_nw = num_windows
+    # calc windows under 720p
+    scale = math.sqrt((45 * 80) / (h * w))
+    resized_h, resized_w = round(h * scale), round(w * scale)
+    wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw)  # window size.
+    wt = ceil(min(t, 30) / resized_nt)  # window size.
+    st, sh, sw = (  # shift size.
+        0.5 if wt < t else 0,
+        0.5 if wh < h else 0,
+        0.5 if ww < w else 0,
+    )
+    nt, nh, nw = ceil((t - st) / wt), ceil((h - sh) / wh), ceil((w - sw) / ww)  # number of windows.
+    nt, nh, nw = (  # account for the extra partial window introduced by the shift.
+        nt + 1 if st > 0 else 1,
+        nh + 1 if sh > 0 else 1,
+        nw + 1 if sw > 0 else 1,
+    )
+    return [
+        (
+            slice(max(int((it - st) * wt), 0), min(int((it - st + 1) * wt), t)),
+            slice(max(int((ih - sh) * wh), 0), min(int((ih - sh + 1) * wh), h)),
+            slice(max(int((iw - sw) * ww), 0), min(int((iw - sw + 1) * ww), w)),
+        )
+        for iw in range(nw)
+        if min(int((iw - sw + 1) * ww), w) > max(int((iw - sw) * ww), 0)
+        for ih in range(nh)
+        if min(int((ih - sh + 1) * wh), h) > max(int((ih - sh) * wh), 0)
+        for it in range(nt)
+        if min(int((it - st + 1) * wt), t) > max(int((it - st) * wt), 0)
+    ]
 
 class RotaryEmbedding(nn.Module):
     def __init__(
         self,
         dim,
-        custom_freqs,
+        custom_freqs = None,
         freqs_for = 'lang',
         theta = 10000,
         max_freq = 10,
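For illustration, a minimal sketch of what the two windowing helpers return, assuming they are in scope and using invented sizes (16 latent frames on a 90x160 grid, 1x3x3 windows):

    # Hypothetical sizes; every value here is made up for illustration.
    size, num_windows = (16, 90, 160), (1, 3, 3)
    windows = make_720Pwindows_bysize(size, num_windows)           # aligned window grid
    shifted = make_shifted_720Pwindows_bysize(size, num_windows)   # half-window-offset grid
    print(len(windows), windows[0])  # each entry is a (t, h, w) triple of slices
    # A latent laid out as [T, H, W, C] could then be cut up with: [x[w] for w in windows]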
@@ -566,6 +587,7 @@ class NaMMAttention(nn.Module):
     ):
         super().__init__()
         dim = MMArg(vid_dim, txt_dim)
+        self.heads = heads
         inner_dim = heads * head_dim
         qkv_dim = inner_dim * 3
         self.head_dim = head_dim
@@ -575,19 +597,20 @@
         self.proj_out = MMModule(nn.Linear, inner_dim, dim, shared_weights=shared_weights)
 
         self.norm_q = MMModule(
             qk_norm,
-            dim=head_dim,
+            normalized_shape=head_dim,
             eps=qk_norm_eps,
             elementwise_affine=True,
             shared_weights=shared_weights,
         )
         self.norm_k = MMModule(
             qk_norm,
-            dim=head_dim,
+            normalized_shape=head_dim,
             eps=qk_norm_eps,
             elementwise_affine=True,
             shared_weights=shared_weights,
         )
         self.rope = get_na_rope(rope_type=rope_type, dim=rope_dim)
 
     def forward(
@@ -634,7 +657,7 @@ class NaMMAttention(nn.Module):
         _, unconcat = cache("mm_pnp", lambda: concat_idx(vid_len, txt_len))
 
-        attn = optimized_attention(q, k, v, skip_reshape=True, skip_output_reshape=True)
+        attn = optimized_attention(q, k, v, heads=self.heads, skip_reshape=True, skip_output_reshape=True)
         attn = attn.flatten(0, 1)  # to continue working with the rest of the code
         attn = rearrange(attn, "l h d -> l (h d)")
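For context, `attention_pytorch` in the first file (presumably one of the backends behind `optimized_attention` here) takes `heads` as a required positional parameter, so a call that omits it fails before any attention math runs. A minimal sketch of the difference, using only the signature shown in that hunk:

    def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None,
                          skip_reshape=False, skip_output_reshape=False):
        ...

    # attention_pytorch(q, k, v, skip_reshape=True)                    # TypeError: missing 'heads'
    # attention_pytorch(q, k, v, heads=self.heads, skip_reshape=True)  # what the new call supplies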
@@ -682,7 +705,7 @@ class NaSwinAttention(NaMMAttention):
         self.window_method = window_method
         assert all(map(lambda v: isinstance(v, int) and v >= 0, self.window))
-        self.window_op = window_method
+        self.window_op = get_window_op(window_method)
 
     def forward(
         self,
@@ -755,19 +778,16 @@ class NaSwinAttention(NaMMAttention):
         else:
             vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win)
 
-        out = self.attn(
-            q=concat_win(vid_q, txt_q).bfloat16(),
-            k=concat_win(vid_k, txt_k).bfloat16(),
-            v=concat_win(vid_v, txt_v).bfloat16(),
-            cu_seqlens_q=cache_win(
-                "vid_seqlens_q", lambda: safe_pad_operation(all_len_win.cumsum(0), (1, 0)).int()
-            ),
-            cu_seqlens_k=cache_win(
-                "vid_seqlens_k", lambda: safe_pad_operation(all_len_win.cumsum(0), (1, 0)).int()
-            ),
-            max_seqlen_q=cache_win("vid_max_seqlen_q", lambda: all_len_win.max().item()),
-            max_seqlen_k=cache_win("vid_max_seqlen_k", lambda: all_len_win.max().item()),
-        ).type_as(vid_q)
+        # TODO: continue testing
+        b = len(vid_len_win)
+        vq, vk, vv = [t.view(b, -1, *vid_q.shape[1:]) for t in (vid_q, vid_k, vid_v)]
+        tq, tk, tv = [t.view(b, -1, *txt_q.shape[1:]) for t in (txt_q, txt_k, txt_v)]
+
+        q = torch.cat([vq, tq], dim=1)
+        k = torch.cat([vk, tk], dim=1)
+        v = torch.cat([vv, tv], dim=1)
+        out = optimized_attention(q, k, v, heads=self.heads, skip_reshape=True, skip_output_reshape=True)
+        out = out.flatten(0, 1)
 
         # text pooling
         vid_out, txt_out = unconcat_win(out)
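A minimal sketch of the shape bookkeeping in the new dense path, with invented sizes; the `optimized_attention` call itself is left as a comment since its layout contract is not restated here:

    import torch

    # Invented sizes: 2 windows, 8 video tokens and 3 text tokens per window,
    # 4 heads of dim 16; packed inputs are (tokens, heads, head_dim) as above.
    b, n_vid, n_txt, h, d = 2, 8, 3, 4, 16
    vid_q = torch.randn(b * n_vid, h, d)
    txt_q = torch.randn(b * n_txt, h, d)

    vq = vid_q.view(b, -1, h, d)        # (b, n_vid, h, d)
    tq = txt_q.view(b, -1, h, d)        # (b, n_txt, h, d)
    q = torch.cat([vq, tq], dim=1)      # (b, n_vid + n_txt, h, d)

    # ... optimized_attention(q, k, v, heads=..., skip_reshape=True, skip_output_reshape=True) ...

    out = q.flatten(0, 1)               # back to (b * (n_vid + n_txt), h, d), token-major
    assert out.shape == (b * (n_vid + n_txt), h, d)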
@@ -847,7 +867,7 @@ class NaMMSRTransformerBlock(nn.Module):
     ):
         super().__init__()
         dim = MMArg(vid_dim, txt_dim)
-        self.attn_norm = MMModule(norm, dim=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights,)
+        self.attn_norm = MMModule(norm, normalized_shape=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights,)
         self.attn = NaSwinAttention(
             vid_dim=vid_dim,
@@ -864,7 +884,7 @@ class NaMMSRTransformerBlock(nn.Module):
             window_method=kwargs.pop("window_method", None),
         )
-        self.mlp_norm = MMModule(norm, dim=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, vid_only=is_last_layer)
+        self.mlp_norm = MMModule(norm, normalized_shape=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, vid_only=is_last_layer)
         self.mlp = MMModule(
             get_mlp(mlp_type),
             dim=dim,
@@ -1155,6 +1175,7 @@ class NaDiT(nn.Module):
         txt_in_dim = 5120,
         heads = 20,
         head_dim = 128,
+        mm_layers = 10,
         expand_ratio = 4,
         qk_bias = False,
         patch_size = [ 1,2,2 ],
@@ -1163,8 +1184,12 @@ class NaDiT(nn.Module):
         window_method: Optional[Tuple[str]] = None,
         temporal_window_size: int = None,
         temporal_shifted: bool = False,
+        rope_dim = 128,
+        rope_type = "mmrope3d",
+        vid_out_norm: Optional[str] = None,
         **kwargs,
     ):
+        window_method = num_layers // 2 * ["720pwin_by_size_bysize", "720pswin_by_size_bysize"]
         txt_dim = vid_dim
         emb_dim = vid_dim * 6
         block_type = ["mmdit_sr"] * num_layers
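A quick check of what the hard-coded `window_method` list expands to, with an invented `num_layers`:

    num_layers = 4
    window_method = num_layers // 2 * ["720pwin_by_size_bysize", "720pswin_by_size_bysize"]
    print(window_method)
    # ['720pwin_by_size_bysize', '720pswin_by_size_bysize',
    #  '720pwin_by_size_bysize', '720pswin_by_size_bysize']

It alternates the plain and shifted 720p windowing ops layer by layer; note that an odd `num_layers` would leave the list one entry short of the block count.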
@@ -1202,6 +1227,7 @@ class NaDiT(nn.Module):
         if temporal_shifted is None or isinstance(temporal_shifted, bool):
             temporal_shifted = [temporal_shifted] * num_layers
+        rope_dim = rope_dim if rope_dim is not None else head_dim // 2
         self.blocks = nn.ModuleList(
             [
                 NaMMSRTransformerBlock(
@@ -1220,10 +1246,16 @@ class NaDiT(nn.Module):
                     shared_qkv=shared_qkv,
                     shared_mlp=shared_mlp,
                     mlp_type=mlp_type,
+                    rope_dim=rope_dim,
                     window=window[i],
                     window_method=window_method[i],
                     temporal_window_size=temporal_window_size[i],
                     temporal_shifted=temporal_shifted[i],
+                    is_last_layer=(i == num_layers - 1),
+                    rope_type=rope_type,
+                    shared_weights=not (
+                        (i < mm_layers) if isinstance(mm_layers, int) else mm_layers[i]
+                    ),
                     **kwargs,
                 )
                 for i in range(num_layers)
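To make the `shared_weights` wiring concrete, a small sketch with invented values; reading `shared_weights=False` as "separate video/text parameters" is an assumption here:

    num_layers, mm_layers = 6, 4
    shared = [
        not ((i < mm_layers) if isinstance(mm_layers, int) else mm_layers[i])
        for i in range(num_layers)
    ]
    print(shared)  # [False, False, False, False, True, True]

That is, an integer `mm_layers` marks the leading blocks as unshared and the rest as shared, while a per-layer boolean sequence selects the unshared blocks explicitly.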
@@ -1241,6 +1273,20 @@ class NaDiT(nn.Module):
             "mmdit_stwin_3d_spatial",
         ]
 
+        self.vid_out_norm = None
+        if vid_out_norm is not None:
+            self.vid_out_norm = RMSNorm(
+                normalized_shape=vid_dim,
+                eps=norm_eps,
+                elementwise_affine=True,
+            )
+            self.vid_out_ada = ada(
+                dim=vid_dim,
+                emb_dim=emb_dim,
+                layers=["out"],
+                modes=["in"],
+            )
+
     def forward(
         self,
         x,
@@ -1284,6 +1330,18 @@ class NaDiT(nn.Module):
                 cache=cache,
             )
 
+        if self.vid_out_norm:
+            vid = self.vid_out_norm(vid)
+            vid = self.vid_out_ada(
+                vid,
+                emb=emb,
+                layer="out",
+                mode="in",
+                hid_len=cache("vid_len", lambda: vid_shape.prod(-1)),
+                cache=cache,
+                branch_tag="vid",
+            )
+
         vid, vid_shape = self.vid_out(vid, vid_shape, cache)
         vid = unflatten(vid, vid_shape)
         return vid