mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-05 00:06:38 +08:00
7b specific
This commit is contained in:
parent
e3fa1aa415
commit
ba59569e4b
@ -721,6 +721,7 @@ class NaSwinAttention(NaMMAttention):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
self.version_7b = kwargs.get("version", False)
|
||||||
self.window = _triple(window)
|
self.window = _triple(window)
|
||||||
self.window_method = window_method
|
self.window_method = window_method
|
||||||
assert all(map(lambda v: isinstance(v, int) and v >= 0, self.window))
|
assert all(map(lambda v: isinstance(v, int) and v >= 0, self.window))
|
||||||
@ -775,28 +776,32 @@ class NaSwinAttention(NaMMAttention):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# window rope
|
# window rope
|
||||||
if self.rope:
|
if not self.version_7b:
|
||||||
if self.rope.mm:
|
if self.rope:
|
||||||
# repeat text q and k for window mmrope
|
if self.rope.mm:
|
||||||
_, num_h, _ = txt_q.shape
|
# repeat text q and k for window mmrope
|
||||||
txt_q_repeat = rearrange(txt_q, "l h d -> l (h d)")
|
_, num_h, _ = txt_q.shape
|
||||||
txt_q_repeat = unflatten(txt_q_repeat, txt_shape)
|
txt_q_repeat = rearrange(txt_q, "l h d -> l (h d)")
|
||||||
txt_q_repeat = [[x] * n for x, n in zip(txt_q_repeat, window_count)]
|
txt_q_repeat = unflatten(txt_q_repeat, txt_shape)
|
||||||
txt_q_repeat = list(chain(*txt_q_repeat))
|
txt_q_repeat = [[x] * n for x, n in zip(txt_q_repeat, window_count)]
|
||||||
txt_q_repeat, txt_shape_repeat = flatten(txt_q_repeat)
|
txt_q_repeat = list(chain(*txt_q_repeat))
|
||||||
txt_q_repeat = rearrange(txt_q_repeat, "l (h d) -> l h d", h=num_h)
|
txt_q_repeat, txt_shape_repeat = flatten(txt_q_repeat)
|
||||||
|
txt_q_repeat = rearrange(txt_q_repeat, "l (h d) -> l h d", h=num_h)
|
||||||
|
|
||||||
txt_k_repeat = rearrange(txt_k, "l h d -> l (h d)")
|
txt_k_repeat = rearrange(txt_k, "l h d -> l (h d)")
|
||||||
txt_k_repeat = unflatten(txt_k_repeat, txt_shape)
|
txt_k_repeat = unflatten(txt_k_repeat, txt_shape)
|
||||||
txt_k_repeat = [[x] * n for x, n in zip(txt_k_repeat, window_count)]
|
txt_k_repeat = [[x] * n for x, n in zip(txt_k_repeat, window_count)]
|
||||||
txt_k_repeat = list(chain(*txt_k_repeat))
|
txt_k_repeat = list(chain(*txt_k_repeat))
|
||||||
txt_k_repeat, _ = flatten(txt_k_repeat)
|
txt_k_repeat, _ = flatten(txt_k_repeat)
|
||||||
txt_k_repeat = rearrange(txt_k_repeat, "l (h d) -> l h d", h=num_h)
|
txt_k_repeat = rearrange(txt_k_repeat, "l (h d) -> l h d", h=num_h)
|
||||||
|
|
||||||
vid_q, vid_k, txt_q, txt_k = self.rope(
|
vid_q, vid_k, txt_q, txt_k = self.rope(
|
||||||
vid_q, vid_k, window_shape, txt_q_repeat, txt_k_repeat, txt_shape_repeat, cache_win
|
vid_q, vid_k, window_shape, txt_q_repeat, txt_k_repeat, txt_shape_repeat, cache_win
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win)
|
||||||
|
else:
|
||||||
|
if self.rope:
|
||||||
vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win)
|
vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win)
|
||||||
|
|
||||||
out = optimized_attention(
|
out = optimized_attention(
|
||||||
@ -899,6 +904,7 @@ class NaMMSRTransformerBlock(nn.Module):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
version = kwargs.get("version", False)
|
||||||
dim = MMArg(vid_dim, txt_dim)
|
dim = MMArg(vid_dim, txt_dim)
|
||||||
self.attn_norm = MMModule(norm, normalized_shape=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, device=device, dtype=dtype)
|
self.attn_norm = MMModule(norm, normalized_shape=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, device=device, dtype=dtype)
|
||||||
|
|
||||||
@ -915,6 +921,7 @@ class NaMMSRTransformerBlock(nn.Module):
|
|||||||
shared_weights=shared_weights,
|
shared_weights=shared_weights,
|
||||||
window=kwargs.pop("window", None),
|
window=kwargs.pop("window", None),
|
||||||
window_method=kwargs.pop("window_method", None),
|
window_method=kwargs.pop("window_method", None),
|
||||||
|
version=version,
|
||||||
device=device, dtype=dtype, operations=operations
|
device=device, dtype=dtype, operations=operations
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -929,6 +936,7 @@ class NaMMSRTransformerBlock(nn.Module):
|
|||||||
)
|
)
|
||||||
self.ada = MMModule(ada, dim=dim, emb_dim=emb_dim, layers=["attn", "mlp"], shared_weights=shared_weights, vid_only=is_last_layer, device=device, dtype=dtype)
|
self.ada = MMModule(ada, dim=dim, emb_dim=emb_dim, layers=["attn", "mlp"], shared_weights=shared_weights, vid_only=is_last_layer, device=device, dtype=dtype)
|
||||||
self.is_last_layer = is_last_layer
|
self.is_last_layer = is_last_layer
|
||||||
|
self.version = version
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -1260,6 +1268,7 @@ class NaDiT(nn.Module):
|
|||||||
operations = None,
|
operations = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
self._7b_version = vid_dim == 3072
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
factory_kwargs = {"device": device, "dtype": dtype}
|
factory_kwargs = {"device": device, "dtype": dtype}
|
||||||
window_method = num_layers // 2 * ["720pwin_by_size_bysize","720pswin_by_size_bysize"]
|
window_method = num_layers // 2 * ["720pwin_by_size_bysize","720pswin_by_size_bysize"]
|
||||||
@ -1333,6 +1342,7 @@ class NaDiT(nn.Module):
|
|||||||
shared_weights=not (
|
shared_weights=not (
|
||||||
(i < mm_layers) if isinstance(mm_layers, int) else mm_layers[i]
|
(i < mm_layers) if isinstance(mm_layers, int) else mm_layers[i]
|
||||||
),
|
),
|
||||||
|
version = self._7b_version,
|
||||||
operations = operations,
|
operations = operations,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
**factory_kwargs
|
**factory_kwargs
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user