From 44a5bf353af34f248b137ee7fcbad912b9f6c09b Mon Sep 17 00:00:00 2001
From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com>
Date: Sun, 7 Dec 2025 23:43:49 +0200
Subject: [PATCH] testing the model

---
 comfy/ldm/modules/attention.py |   2 +-
 comfy/ldm/seedvr/model.py      | 178 ++++++++++++++++++++++-----------
 2 files changed, 119 insertions(+), 61 deletions(-)

diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 256f9a989..35d2270ee 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -428,7 +428,7 @@ else:
 
 SDP_BATCH_LIMIT = 2**31
 
-def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=Falsez):
+def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
     if skip_reshape:
         b, _, _, dim_head = q.shape
     else:
diff --git a/comfy/ldm/seedvr/model.py b/comfy/ldm/seedvr/model.py
index cf6287b03..86836468f 100644
--- a/comfy/ldm/seedvr/model.py
+++ b/comfy/ldm/seedvr/model.py
@@ -145,56 +145,77 @@ def get_kwargs(key: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
     return {k: getattr(v, key) if isinstance(v, MMArg) else v for k, v in kwargs.items()}
 
 
-def make_720Pwindows(size, num_windows, shift = False):
+def get_window_op(name: str):
+    if name == "720pwin_by_size_bysize":
+        return make_720Pwindows_bysize
+    if name == "720pswin_by_size_bysize":
+        return make_shifted_720Pwindows_bysize
+    raise ValueError(f"Unknown windowing method: {name}")
+
+
+# -------------------------------- Windowing -------------------------------- #
+def make_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int, int, int]):
     t, h, w = size
     resized_nt, resized_nh, resized_nw = num_windows
-
-    scale = sqrt((45 * 80) / (h * w))
+    # Compute window sizes at the 720p (45x80 latent) scale.
+    scale = math.sqrt((45 * 80) / (h * w))
     resized_h, resized_w = round(h * scale), round(w * scale)
+    wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw)  # window size.
+    wt = ceil(min(t, 30) / resized_nt)  # window size.
+    nt, nh, nw = ceil(t / wt), ceil(h / wh), ceil(w / ww)  # number of windows.
+    return [
+        (
+            slice(it * wt, min((it + 1) * wt, t)),
+            slice(ih * wh, min((ih + 1) * wh, h)),
+            slice(iw * ww, min((iw + 1) * ww, w)),
+        )
+        for iw in range(nw)
+        if min((iw + 1) * ww, w) > iw * ww
+        for ih in range(nh)
+        if min((ih + 1) * wh, h) > ih * wh
+        for it in range(nt)
+        if min((it + 1) * wt, t) > it * wt
+    ]
 
-    wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw)
-    wt = ceil(min(t, 30) / resized_nt)
+
+def make_shifted_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int, int, int]):
+    t, h, w = size
+    resized_nt, resized_nh, resized_nw = num_windows
+    # Compute window sizes at the 720p (45x80 latent) scale.
+    scale = math.sqrt((45 * 80) / (h * w))
+    resized_h, resized_w = round(h * scale), round(w * scale)
+    wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw)  # window size.
+    wt = ceil(min(t, 30) / resized_nt)  # window size.
-    st, sh, sw = (0.5 * shift if wt < t else 0,
-                  0.5 * shift if wh < h else 0,
-                  0.5 * shift if ww < w else 0)
-
-    nt, nh, nw = ceil((t - st) / wt), ceil((h - sh) / wh), ceil((w - sw) / ww)
-    if shift:
-        nt += 1 if st > 0 else 0
-        nh += 1 if sh > 0 else 0
-        nw += 1 if sw > 0 else 0
-
-    windows = []
-    for iw in range(nw):
-        w_start = max(int((iw - sw) * ww), 0)
-        w_end = min(int((iw - sw + 1) * ww), w)
-        if w_end <= w_start:
-            continue
-
-        for ih in range(nh):
-            h_start = max(int((ih - sh) * wh), 0)
-            h_end = min(int((ih - sh + 1) * wh), h)
-            if h_end <= h_start:
-                continue
-
-            for it in range(nt):
-                t_start = max(int((it - st) * wt), 0)
-                t_end = min(int((it - st + 1) * wt), t)
-                if t_end <= t_start:
-                    continue
-
-                windows.append((slice(t_start, t_end),
-                                slice(h_start, h_end),
-                                slice(w_start, w_end)))
-
-    return windows
+    st, sh, sw = (  # shift size.
+        0.5 if wt < t else 0,
+        0.5 if wh < h else 0,
+        0.5 if ww < w else 0,
+    )
+    nt, nh, nw = ceil((t - st) / wt), ceil((h - sh) / wh), ceil((w - sw) / ww)  # number of windows.
+    nt, nh, nw = (  # add one window per shifted dimension.
+        nt + 1 if st > 0 else nt,
+        nh + 1 if sh > 0 else nh,
+        nw + 1 if sw > 0 else nw,
+    )
+    return [
+        (
+            slice(max(int((it - st) * wt), 0), min(int((it - st + 1) * wt), t)),
+            slice(max(int((ih - sh) * wh), 0), min(int((ih - sh + 1) * wh), h)),
+            slice(max(int((iw - sw) * ww), 0), min(int((iw - sw + 1) * ww), w)),
+        )
+        for iw in range(nw)
+        if min(int((iw - sw + 1) * ww), w) > max(int((iw - sw) * ww), 0)
+        for ih in range(nh)
+        if min(int((ih - sh + 1) * wh), h) > max(int((ih - sh) * wh), 0)
+        for it in range(nt)
+        if min(int((it - st + 1) * wt), t) > max(int((it - st) * wt), 0)
+    ]
 
 
 class RotaryEmbedding(nn.Module):
     def __init__(
         self,
         dim,
-        custom_freqs,
+        custom_freqs = None,
         freqs_for = 'lang',
         theta = 10000,
         max_freq = 10,
@@ -566,6 +587,7 @@ class NaMMAttention(nn.Module):
     ):
         super().__init__()
         dim = MMArg(vid_dim, txt_dim)
+        self.heads = heads
         inner_dim = heads * head_dim
         qkv_dim = inner_dim * 3
         self.head_dim = head_dim
@@ -575,19 +597,20 @@ class NaMMAttention(nn.Module):
         self.proj_out = MMModule(nn.Linear, inner_dim, dim, shared_weights=shared_weights)
         self.norm_q = MMModule(
             qk_norm,
-            dim=head_dim,
+            normalized_shape=head_dim,
            eps=qk_norm_eps,
            elementwise_affine=True,
            shared_weights=shared_weights,
        )
         self.norm_k = MMModule(
             qk_norm,
-            dim=head_dim,
+            normalized_shape=head_dim,
            eps=qk_norm_eps,
            elementwise_affine=True,
            shared_weights=shared_weights,
        )
+
         self.rope = get_na_rope(rope_type=rope_type, dim=rope_dim)
 
     def forward(
@@ -634,7 +657,7 @@ class NaMMAttention(nn.Module):
 
         _, unconcat = cache("mm_pnp", lambda: concat_idx(vid_len, txt_len))
 
-        attn = optimized_attention(q, k, v, skip_reshape=True, skip_output_reshape=True)
+        attn = optimized_attention(q, k, v, heads=self.heads, skip_reshape=True, skip_output_reshape=True)
 
         attn = attn.flatten(0, 1)  # to continue working with the rest of the code
         attn = rearrange(attn, "l h d -> l (h d)")
@@ -682,7 +705,7 @@ class NaSwinAttention(NaMMAttention):
         self.window_method = window_method
 
         assert all(map(lambda v: isinstance(v, int) and v >= 0, self.window))
-        self.window_op = window_method
+        self.window_op = get_window_op(window_method)
 
     def forward(
         self,
@@ -754,20 +777,17 @@ class NaSwinAttention(NaMMAttention):
             )
         else:
             vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win)
-
-        out = self.attn(
-            q=concat_win(vid_q, txt_q).bfloat16(),
-            k=concat_win(vid_k, txt_k).bfloat16(),
-            v=concat_win(vid_v, txt_v).bfloat16(),
-            cu_seqlens_q=cache_win(
-                "vid_seqlens_q", lambda: safe_pad_operation(all_len_win.cumsum(0), (1, 0)).int()
-            ),
-            cu_seqlens_k=cache_win(
-                "vid_seqlens_k", lambda: safe_pad_operation(all_len_win.cumsum(0), (1, 0)).int()
-            ),
-            max_seqlen_q=cache_win("vid_max_seqlen_q", lambda: all_len_win.max().item()),
-            max_seqlen_k=cache_win("vid_max_seqlen_k", lambda: all_len_win.max().item()),
-        ).type_as(vid_q)
+
+        # TODO: continue testing
+        b = len(vid_len_win)
+        vq, vk, vv = [t.view(b, -1, *vid_q.shape[1:]) for t in (vid_q, vid_k, vid_v)]
+        tq, tk, tv = [t.view(b, -1, *txt_q.shape[1:]) for t in (txt_q, txt_k, txt_v)]
+
+        q = torch.cat([vq, tq], dim=1)
+        k = torch.cat([vk, tk], dim=1)
+        v = torch.cat([vv, tv], dim=1)
+        out = optimized_attention(q, k, v, heads=self.heads, skip_reshape=True, skip_output_reshape=True)
+        out = out.flatten(0, 1)
 
         # text pooling
         vid_out, txt_out = unconcat_win(out)
@@ -847,7 +867,7 @@ class NaMMSRTransformerBlock(nn.Module):
     ):
         super().__init__()
         dim = MMArg(vid_dim, txt_dim)
-        self.attn_norm = MMModule(norm, dim=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights,)
+        self.attn_norm = MMModule(norm, normalized_shape=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights,)
 
         self.attn = NaSwinAttention(
             vid_dim=vid_dim,
@@ -864,7 +884,7 @@ class NaMMSRTransformerBlock(nn.Module):
             window_method=kwargs.pop("window_method", None),
         )
 
-        self.mlp_norm = MMModule(norm, dim=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, vid_only=is_last_layer)
+        self.mlp_norm = MMModule(norm, normalized_shape=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, vid_only=is_last_layer)
         self.mlp = MMModule(
             get_mlp(mlp_type),
             dim=dim,
@@ -1155,6 +1175,7 @@ class NaDiT(nn.Module):
         txt_in_dim = 5120,
         heads = 20,
         head_dim = 128,
+        mm_layers = 10,
         expand_ratio = 4,
         qk_bias = False,
         patch_size = [ 1,2,2 ],
@@ -1163,8 +1184,12 @@ class NaDiT(nn.Module):
         window_method: Optional[Tuple[str]] = None,
         temporal_window_size: int = None,
         temporal_shifted: bool = False,
+        rope_dim = 128,
+        rope_type = "mmrope3d",
+        vid_out_norm: Optional[str] = None,
         **kwargs,
     ):
+        window_method = num_layers // 2 * ["720pwin_by_size_bysize", "720pswin_by_size_bysize"]
         txt_dim = vid_dim
         emb_dim = vid_dim * 6
         block_type = ["mmdit_sr"] * num_layers
@@ -1202,6 +1227,7 @@ class NaDiT(nn.Module):
         if temporal_shifted is None or isinstance(temporal_shifted, bool):
             temporal_shifted = [temporal_shifted] * num_layers
 
+        rope_dim = rope_dim if rope_dim is not None else head_dim // 2
         self.blocks = nn.ModuleList(
             [
                 NaMMSRTransformerBlock(
@@ -1220,10 +1246,16 @@ class NaDiT(nn.Module):
                     shared_qkv=shared_qkv,
                     shared_mlp=shared_mlp,
                     mlp_type=mlp_type,
+                    rope_dim=rope_dim,
                     window=window[i],
                     window_method=window_method[i],
                     temporal_window_size=temporal_window_size[i],
                     temporal_shifted=temporal_shifted[i],
+                    is_last_layer=(i == num_layers - 1),
+                    rope_type=rope_type,
+                    shared_weights=not (
+                        (i < mm_layers) if isinstance(mm_layers, int) else mm_layers[i]
+                    ),
                     **kwargs,
                 )
                 for i in range(num_layers)
@@ -1241,6 +1273,20 @@ class NaDiT(nn.Module):
             "mmdit_stwin_3d_spatial",
         ]
 
+        self.vid_out_norm = None
+        if vid_out_norm is not None:
+            self.vid_out_norm = RMSNorm(
+                normalized_shape=vid_dim,
+                eps=norm_eps,
+                elementwise_affine=True,
+            )
+            self.vid_out_ada = ada(
+                dim=vid_dim,
+                emb_dim=emb_dim,
+                layers=["out"],
+                modes=["in"],
+            )
+
     def forward(
         self,
         x,
@@ -1284,6 +1330,18 @@ class NaDiT(nn.Module):
                 cache=cache,
             )
 
+        if self.vid_out_norm is not None:
+            vid = self.vid_out_norm(vid)
+            vid = self.vid_out_ada(
+                vid,
+                emb=emb,
+                layer="out",
+                mode="in",
+                hid_len=cache("vid_len", lambda: vid_shape.prod(-1)),
+                cache=cache,
+                branch_tag="vid",
+            )
+
         vid, vid_shape = self.vid_out(vid, vid_shape, cache)
         vid = unflatten(vid, vid_shape)
         return vid
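
A minimal smoke-test sketch for the windowing helpers above, assuming they are importable from comfy/ldm/seedvr/model.py. The latent size and window grid below are illustrative values, not ones taken from the model config; the check only relies on both helpers tiling the (t, h, w) latent grid into non-overlapping slices.

    import torch
    from comfy.ldm.seedvr.model import make_720Pwindows_bysize, make_shifted_720Pwindows_bysize

    size = (16, 68, 120)     # (t, h, w) of a latent video; chosen arbitrarily
    num_windows = (1, 3, 3)  # target window grid; illustrative only

    for fn in (make_720Pwindows_bysize, make_shifted_720Pwindows_bysize):
        covered = torch.zeros(size, dtype=torch.int32)
        for ts, hs, ws in fn(size, num_windows):
            covered[ts, hs, ws] += 1
        # Every latent position should land in exactly one window, shifted or not.
        assert int(covered.min()) == 1 and int(covered.max()) == 1, fn.__name__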
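The NaDiT constructor now hard-codes an alternating plain/shifted window schedule and resolves each name through get_window_op. A small sketch of that dispatch, assuming an even num_layers (the 32 below is a stand-in for the real config value):

    from comfy.ldm.seedvr.model import (
        get_window_op,
        make_720Pwindows_bysize,
        make_shifted_720Pwindows_bysize,
    )

    num_layers = 32  # stand-in; the actual depth comes from the model config
    window_method = num_layers // 2 * ["720pwin_by_size_bysize", "720pswin_by_size_bysize"]

    ops = [get_window_op(name) for name in window_method]
    assert len(ops) == num_layers
    assert ops[0] is make_720Pwindows_bysize          # even blocks: plain windows
    assert ops[1] is make_shifted_720Pwindows_bysize  # odd blocks: shifted windows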