mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-01 14:33:30 +08:00
7b specific
This commit is contained in:
parent
e3fa1aa415
commit
ba59569e4b
@@ -721,6 +721,7 @@ class NaSwinAttention(NaMMAttention):
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.version_7b = kwargs.get("version", False)
|
||||
self.window = _triple(window)
|
||||
self.window_method = window_method
|
||||
assert all(map(lambda v: isinstance(v, int) and v >= 0, self.window))
|
||||
@@ -775,28 +776,32 @@ class NaSwinAttention(NaMMAttention):
|
||||
)
|
||||
|
||||
# window rope
|
||||
if self.rope:
|
||||
if self.rope.mm:
|
||||
# repeat text q and k for window mmrope
|
||||
_, num_h, _ = txt_q.shape
|
||||
txt_q_repeat = rearrange(txt_q, "l h d -> l (h d)")
|
||||
txt_q_repeat = unflatten(txt_q_repeat, txt_shape)
|
||||
txt_q_repeat = [[x] * n for x, n in zip(txt_q_repeat, window_count)]
|
||||
txt_q_repeat = list(chain(*txt_q_repeat))
|
||||
txt_q_repeat, txt_shape_repeat = flatten(txt_q_repeat)
|
||||
txt_q_repeat = rearrange(txt_q_repeat, "l (h d) -> l h d", h=num_h)
|
||||
if not self.version_7b:
|
||||
if self.rope:
|
||||
if self.rope.mm:
|
||||
# repeat text q and k for window mmrope
|
||||
_, num_h, _ = txt_q.shape
|
||||
txt_q_repeat = rearrange(txt_q, "l h d -> l (h d)")
|
||||
txt_q_repeat = unflatten(txt_q_repeat, txt_shape)
|
||||
txt_q_repeat = [[x] * n for x, n in zip(txt_q_repeat, window_count)]
|
||||
txt_q_repeat = list(chain(*txt_q_repeat))
|
||||
txt_q_repeat, txt_shape_repeat = flatten(txt_q_repeat)
|
||||
txt_q_repeat = rearrange(txt_q_repeat, "l (h d) -> l h d", h=num_h)
|
||||
|
||||
txt_k_repeat = rearrange(txt_k, "l h d -> l (h d)")
|
||||
txt_k_repeat = unflatten(txt_k_repeat, txt_shape)
|
||||
txt_k_repeat = [[x] * n for x, n in zip(txt_k_repeat, window_count)]
|
||||
txt_k_repeat = list(chain(*txt_k_repeat))
|
||||
txt_k_repeat, _ = flatten(txt_k_repeat)
|
||||
txt_k_repeat = rearrange(txt_k_repeat, "l (h d) -> l h d", h=num_h)
|
||||
txt_k_repeat = rearrange(txt_k, "l h d -> l (h d)")
|
||||
txt_k_repeat = unflatten(txt_k_repeat, txt_shape)
|
||||
txt_k_repeat = [[x] * n for x, n in zip(txt_k_repeat, window_count)]
|
||||
txt_k_repeat = list(chain(*txt_k_repeat))
|
||||
txt_k_repeat, _ = flatten(txt_k_repeat)
|
||||
txt_k_repeat = rearrange(txt_k_repeat, "l (h d) -> l h d", h=num_h)
|
||||
|
||||
vid_q, vid_k, txt_q, txt_k = self.rope(
|
||||
vid_q, vid_k, window_shape, txt_q_repeat, txt_k_repeat, txt_shape_repeat, cache_win
|
||||
)
|
||||
else:
|
||||
vid_q, vid_k, txt_q, txt_k = self.rope(
|
||||
vid_q, vid_k, window_shape, txt_q_repeat, txt_k_repeat, txt_shape_repeat, cache_win
|
||||
)
|
||||
else:
|
||||
vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win)
|
||||
else:
|
||||
if self.rope:
|
||||
vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win)
|
||||
|
||||
out = optimized_attention(
|
||||
@@ -899,6 +904,7 @@ class NaMMSRTransformerBlock(nn.Module):
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
version = kwargs.get("version", False)
|
||||
dim = MMArg(vid_dim, txt_dim)
|
||||
self.attn_norm = MMModule(norm, normalized_shape=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, device=device, dtype=dtype)
|
||||
|
||||
@@ -915,6 +921,7 @@ class NaMMSRTransformerBlock(nn.Module):
|
||||
shared_weights=shared_weights,
|
||||
window=kwargs.pop("window", None),
|
||||
window_method=kwargs.pop("window_method", None),
|
||||
version=version,
|
||||
device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
|
||||
@@ -929,6 +936,7 @@ class NaMMSRTransformerBlock(nn.Module):
|
||||
)
|
||||
self.ada = MMModule(ada, dim=dim, emb_dim=emb_dim, layers=["attn", "mlp"], shared_weights=shared_weights, vid_only=is_last_layer, device=device, dtype=dtype)
|
||||
self.is_last_layer = is_last_layer
|
||||
self.version = version
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -1260,6 +1268,7 @@ class NaDiT(nn.Module):
|
||||
operations = None,
|
||||
**kwargs,
|
||||
):
|
||||
self._7b_version = vid_dim == 3072
|
||||
self.dtype = dtype
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
window_method = num_layers // 2 * ["720pwin_by_size_bysize","720pswin_by_size_bysize"]
|
||||
@@ -1333,6 +1342,7 @@ class NaDiT(nn.Module):
|
||||
shared_weights=not (
|
||||
(i < mm_layers) if isinstance(mm_layers, int) else mm_layers[i]
|
||||
),
|
||||
version = self._7b_version,
|
||||
operations = operations,
|
||||
**kwargs,
|
||||
**factory_kwargs
|
||||
|
||||
Loading…
Reference in New Issue
Block a user