From ba59569e4b60c82fa50ba4344d05c62606c82a79 Mon Sep 17 00:00:00 2001
From: "Yousef R. Gamaleldin"
Date: Fri, 23 Jan 2026 01:24:15 +0200
Subject: [PATCH] 7b specific

---
 comfy/ldm/seedvr/model.py | 50 +++++++++++++++++++++++----------------
 1 file changed, 30 insertions(+), 20 deletions(-)

diff --git a/comfy/ldm/seedvr/model.py b/comfy/ldm/seedvr/model.py
index e7570699e..21f16bc4b 100644
--- a/comfy/ldm/seedvr/model.py
+++ b/comfy/ldm/seedvr/model.py
@@ -721,6 +721,7 @@ class NaSwinAttention(NaMMAttention):
         **kwargs,
     ):
         super().__init__(*args, **kwargs)
+        self.version_7b = kwargs.get("version", False)
         self.window = _triple(window)
         self.window_method = window_method
         assert all(map(lambda v: isinstance(v, int) and v >= 0, self.window))
@@ -775,28 +776,32 @@
         )
 
         # window rope
-        if self.rope:
-            if self.rope.mm:
-                # repeat text q and k for window mmrope
-                _, num_h, _ = txt_q.shape
-                txt_q_repeat = rearrange(txt_q, "l h d -> l (h d)")
-                txt_q_repeat = unflatten(txt_q_repeat, txt_shape)
-                txt_q_repeat = [[x] * n for x, n in zip(txt_q_repeat, window_count)]
-                txt_q_repeat = list(chain(*txt_q_repeat))
-                txt_q_repeat, txt_shape_repeat = flatten(txt_q_repeat)
-                txt_q_repeat = rearrange(txt_q_repeat, "l (h d) -> l h d", h=num_h)
+        if not self.version_7b:
+            if self.rope:
+                if self.rope.mm:
+                    # repeat text q and k for window mmrope
+                    _, num_h, _ = txt_q.shape
+                    txt_q_repeat = rearrange(txt_q, "l h d -> l (h d)")
+                    txt_q_repeat = unflatten(txt_q_repeat, txt_shape)
+                    txt_q_repeat = [[x] * n for x, n in zip(txt_q_repeat, window_count)]
+                    txt_q_repeat = list(chain(*txt_q_repeat))
+                    txt_q_repeat, txt_shape_repeat = flatten(txt_q_repeat)
+                    txt_q_repeat = rearrange(txt_q_repeat, "l (h d) -> l h d", h=num_h)
 
-                txt_k_repeat = rearrange(txt_k, "l h d -> l (h d)")
-                txt_k_repeat = unflatten(txt_k_repeat, txt_shape)
-                txt_k_repeat = [[x] * n for x, n in zip(txt_k_repeat, window_count)]
-                txt_k_repeat = list(chain(*txt_k_repeat))
-                txt_k_repeat, _ = flatten(txt_k_repeat)
-                txt_k_repeat = rearrange(txt_k_repeat, "l (h d) -> l h d", h=num_h)
+                    txt_k_repeat = rearrange(txt_k, "l h d -> l (h d)")
+                    txt_k_repeat = unflatten(txt_k_repeat, txt_shape)
+                    txt_k_repeat = [[x] * n for x, n in zip(txt_k_repeat, window_count)]
+                    txt_k_repeat = list(chain(*txt_k_repeat))
+                    txt_k_repeat, _ = flatten(txt_k_repeat)
+                    txt_k_repeat = rearrange(txt_k_repeat, "l (h d) -> l h d", h=num_h)
 
-                vid_q, vid_k, txt_q, txt_k = self.rope(
-                    vid_q, vid_k, window_shape, txt_q_repeat, txt_k_repeat, txt_shape_repeat, cache_win
-                )
-            else:
+                    vid_q, vid_k, txt_q, txt_k = self.rope(
+                        vid_q, vid_k, window_shape, txt_q_repeat, txt_k_repeat, txt_shape_repeat, cache_win
+                    )
+                else:
+                    vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win)
+        else:
+            if self.rope:
                 vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win)
 
         out = optimized_attention(
@@ -899,6 +904,7 @@ class NaMMSRTransformerBlock(nn.Module):
         **kwargs,
     ):
         super().__init__()
+        version = kwargs.get("version", False)
         dim = MMArg(vid_dim, txt_dim)
 
         self.attn_norm = MMModule(norm, normalized_shape=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, device=device, dtype=dtype)
@@ -915,6 +921,7 @@
             shared_weights=shared_weights,
             window=kwargs.pop("window", None),
            window_method=kwargs.pop("window_method", None),
+            version=version,
             device=device, dtype=dtype, operations=operations
         )
 
@@ -929,6 +936,7 @@
         )
         self.ada = MMModule(ada, dim=dim, emb_dim=emb_dim, layers=["attn", "mlp"], shared_weights=shared_weights, vid_only=is_last_layer, device=device, dtype=dtype)
         self.is_last_layer = is_last_layer
+        self.version = version
 
     def forward(
         self,
@@ -1260,6 +1268,7 @@
         operations = None,
         **kwargs,
     ):
+        self._7b_version = vid_dim == 3072
         self.dtype = dtype
         factory_kwargs = {"device": device, "dtype": dtype}
         window_method = num_layers // 2 * ["720pwin_by_size_bysize","720pswin_by_size_bysize"]
@@ -1333,6 +1342,7 @@
                 shared_weights=not (
                     (i < mm_layers) if isinstance(mm_layers, int) else mm_layers[i]
                 ),
+                version = self._7b_version,
                 operations = operations,
                 **kwargs,
                 **factory_kwargs