testing the model
Commit: 44a5bf353a (parent: 4b9332cc21)
@@ -428,7 +428,7 @@ else:
     SDP_BATCH_LIMIT = 2**31
 
 
-def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=Falsez):
+def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
     if skip_reshape:
         b, _, _, dim_head = q.shape
     else:
@@ -145,56 +145,77 @@ def get_kwargs(key: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
     return {k: getattr(v, key) if isinstance(v, MMArg) else v for k, v in kwargs.items()}
 
 
-def make_720Pwindows(size, num_windows, shift = False):
+def get_window_op(name: str):
+    if name == "720pwin_by_size_bysize":
+        return make_720Pwindows_bysize
+    if name == "720pswin_by_size_bysize":
+        return make_shifted_720Pwindows_bysize
+    raise ValueError(f"Unknown windowing method: {name}")
+
+
+# -------------------------------- Windowing -------------------------------- #
+def make_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int, int, int]):
     t, h, w = size
     resized_nt, resized_nh, resized_nw = num_windows
-    scale = sqrt((45 * 80) / (h * w))
+    #cal windows under 720p
+    scale = math.sqrt((45 * 80) / (h * w))
     resized_h, resized_w = round(h * scale), round(w * scale)
+    wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) # window size.
+    wt = ceil(min(t, 30) / resized_nt) # window size.
+    nt, nh, nw = ceil(t / wt), ceil(h / wh), ceil(w / ww) # window size.
+    return [
+        (
+            slice(it * wt, min((it + 1) * wt, t)),
+            slice(ih * wh, min((ih + 1) * wh, h)),
+            slice(iw * ww, min((iw + 1) * ww, w)),
+        )
+        for iw in range(nw)
+        if min((iw + 1) * ww, w) > iw * ww
+        for ih in range(nh)
+        if min((ih + 1) * wh, h) > ih * wh
+        for it in range(nt)
+        if min((it + 1) * wt, t) > it * wt
+    ]
 
-    wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw)
-    wt = ceil(min(t, 30) / resized_nt)
+def make_shifted_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int, int, int]):
+    t, h, w = size
+    resized_nt, resized_nh, resized_nw = num_windows
+    #cal windows under 720p
+    scale = math.sqrt((45 * 80) / (h * w))
+    resized_h, resized_w = round(h * scale), round(w * scale)
+    wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) # window size.
+    wt = ceil(min(t, 30) / resized_nt) # window size.
 
-    st, sh, sw = (0.5 * shift if wt < t else 0,
-                  0.5 * shift if wh < h else 0,
-                  0.5 * shift if ww < w else 0)
-
-    nt, nh, nw = ceil((t - st) / wt), ceil((h - sh) / wh), ceil((w - sw) / ww)
-    if shift:
-        nt += 1 if st > 0 else 0
-        nh += 1 if sh > 0 else 0
-        nw += 1 if sw > 0 else 0
-
-    windows = []
-    for iw in range(nw):
-        w_start = max(int((iw - sw) * ww), 0)
-        w_end = min(int((iw - sw + 1) * ww), w)
-        if w_end <= w_start:
-            continue
-        for ih in range(nh):
-            h_start = max(int((ih - sh) * wh), 0)
-            h_end = min(int((ih - sh + 1) * wh), h)
-            if h_end <= h_start:
-                continue
-            for it in range(nt):
-                t_start = max(int((it - st) * wt), 0)
-                t_end = min(int((it - st + 1) * wt), t)
-                if t_end <= t_start:
-                    continue
-
-                windows.append((slice(t_start, t_end),
-                                slice(h_start, h_end),
-                                slice(w_start, w_end)))
-
-    return windows
+    st, sh, sw = (  # shift size.
+        0.5 if wt < t else 0,
+        0.5 if wh < h else 0,
+        0.5 if ww < w else 0,
+    )
+    nt, nh, nw = ceil((t - st) / wt), ceil((h - sh) / wh), ceil((w - sw) / ww) # window size.
+    nt, nh, nw = (  # number of window.
+        nt + 1 if st > 0 else 1,
+        nh + 1 if sh > 0 else 1,
+        nw + 1 if sw > 0 else 1,
+    )
+    return [
+        (
+            slice(max(int((it - st) * wt), 0), min(int((it - st + 1) * wt), t)),
+            slice(max(int((ih - sh) * wh), 0), min(int((ih - sh + 1) * wh), h)),
+            slice(max(int((iw - sw) * ww), 0), min(int((iw - sw + 1) * ww), w)),
+        )
+        for iw in range(nw)
+        if min(int((iw - sw + 1) * ww), w) > max(int((iw - sw) * ww), 0)
+        for ih in range(nh)
+        if min(int((ih - sh + 1) * wh), h) > max(int((ih - sh) * wh), 0)
+        for it in range(nt)
+        if min(int((it - st + 1) * wt), t) > max(int((it - st) * wt), 0)
+    ]
 
 
 class RotaryEmbedding(nn.Module):
     def __init__(
         self,
         dim,
-        custom_freqs,
+        custom_freqs = None,
         freqs_for = 'lang',
         theta = 10000,
         max_freq = 10,
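For readers following the new windowing helpers: the non-shifted variant rescales the spatial grid to a 720p-equivalent reference, derives a per-axis window size from the requested window counts, and returns only the non-empty 3D slices that cover the grid. Below is a minimal standalone sketch of that partitioning, not the module itself; the grid size (16, 45, 80) and the 1x3x3 window layout are illustrative assumptions.

    from math import ceil, sqrt

    def windows_bysize(size, num_windows):
        t, h, w = size
        nt_req, nh_req, nw_req = num_windows
        scale = sqrt((45 * 80) / (h * w))                # rescale spatial dims to the 45x80 (720p) latent reference
        rh, rw = round(h * scale), round(w * scale)
        wh, ww = ceil(rh / nh_req), ceil(rw / nw_req)    # window height / width, applied to the real grid
        wt = ceil(min(t, 30) / nt_req)                   # temporal window length, capped at 30 frames
        nt, nh, nw = ceil(t / wt), ceil(h / wh), ceil(w / ww)
        return [
            (slice(it * wt, min((it + 1) * wt, t)),
             slice(ih * wh, min((ih + 1) * wh, h)),
             slice(iw * ww, min((iw + 1) * ww, w)))
            for it in range(nt)
            for ih in range(nh)
            for iw in range(nw)
            if min((it + 1) * wt, t) > it * wt
            and min((ih + 1) * wh, h) > ih * wh
            and min((iw + 1) * ww, w) > iw * ww
        ]

    print(len(windows_bysize((16, 45, 80), (1, 3, 3))))  # 9 windows, each a (time, height, width) slice triple

The shifted variant above works the same way but offsets every window by half its size, which is why it may produce one extra (clipped) window per axis.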
@@ -566,6 +587,7 @@ class NaMMAttention(nn.Module):
     ):
         super().__init__()
         dim = MMArg(vid_dim, txt_dim)
+        self.heads = heads
         inner_dim = heads * head_dim
         qkv_dim = inner_dim * 3
         self.head_dim = head_dim
@@ -575,19 +597,20 @@ class NaMMAttention(nn.Module):
         self.proj_out = MMModule(nn.Linear, inner_dim, dim, shared_weights=shared_weights)
         self.norm_q = MMModule(
             qk_norm,
-            dim=head_dim,
+            normalized_shape=head_dim,
             eps=qk_norm_eps,
             elementwise_affine=True,
             shared_weights=shared_weights,
         )
         self.norm_k = MMModule(
             qk_norm,
-            dim=head_dim,
+            normalized_shape=head_dim,
             eps=qk_norm_eps,
             elementwise_affine=True,
             shared_weights=shared_weights,
         )
 
+
         self.rope = get_na_rope(rope_type=rope_type, dim=rope_dim)
 
     def forward(
@@ -634,7 +657,7 @@ class NaMMAttention(nn.Module):
 
         _, unconcat = cache("mm_pnp", lambda: concat_idx(vid_len, txt_len))
 
-        attn = optimized_attention(q, k, v, skip_reshape=True, skip_output_reshape=True)
+        attn = optimized_attention(q, k, v, heads = self.heads, skip_reshape=True, skip_output_reshape=True)
         attn = attn.flatten(0, 1) # to continue working with the rest of the code
 
         attn = rearrange(attn, "l h d -> l (h d)")
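The only change in this hunk is passing heads explicitly to optimized_attention. With skip_reshape=True and skip_output_reshape=True the tensors are expected to stay in (batch, heads, seq, head_dim) layout on both sides of the call. A small shape sketch, using plain PyTorch scaled_dot_product_attention as a stand-in and arbitrary assumed sizes:

    import torch
    import torch.nn.functional as F

    b, heads, seq, head_dim = 1, 20, 4096, 128
    q, k, v = (torch.randn(b, heads, seq, head_dim) for _ in range(3))

    attn = F.scaled_dot_product_attention(q, k, v)  # (b, heads, seq, head_dim), no reshaping on either side
    attn = attn.flatten(0, 1)                       # (b * heads, seq, head_dim), as in the line after the call
    print(attn.shape)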
@@ -682,7 +705,7 @@ class NaSwinAttention(NaMMAttention):
         self.window_method = window_method
         assert all(map(lambda v: isinstance(v, int) and v >= 0, self.window))
 
-        self.window_op = window_method
+        self.window_op = get_window_op(window_method)
 
     def forward(
         self,
@@ -754,20 +777,17 @@ class NaSwinAttention(NaMMAttention):
             )
         else:
             vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win)
 
-        out = self.attn(
-            q=concat_win(vid_q, txt_q).bfloat16(),
-            k=concat_win(vid_k, txt_k).bfloat16(),
-            v=concat_win(vid_v, txt_v).bfloat16(),
-            cu_seqlens_q=cache_win(
-                "vid_seqlens_q", lambda: safe_pad_operation(all_len_win.cumsum(0), (1, 0)).int()
-            ),
-            cu_seqlens_k=cache_win(
-                "vid_seqlens_k", lambda: safe_pad_operation(all_len_win.cumsum(0), (1, 0)).int()
-            ),
-            max_seqlen_q=cache_win("vid_max_seqlen_q", lambda: all_len_win.max().item()),
-            max_seqlen_k=cache_win("vid_max_seqlen_k", lambda: all_len_win.max().item()),
-        ).type_as(vid_q)
+        # TODO: continue testing
+        b = len(vid_len_win)
+        vq, vk, vv = [t.view(b, -1, *vid_q.shape[1:]) for t in (vid_q, vid_k, vid_v)]
+        tq, tk, tv = [t.view(b, -1, *txt_q.shape[1:]) for t in (txt_q, txt_v, txt_v)]
+        q = torch.cat([vq, tq], dim=1)
+        k = torch.cat([vk, tk], dim=1)
+        v = torch.cat([vv, tv], dim=1)
+        out = optimized_attention(q, k, v, heads=self.heads, skip_reshape=True, skip_output_reshape=True)
+        out = out.flatten(0, 1)
 
         # text pooling
         vid_out, txt_out = unconcat_win(out)
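The varlen flash-attention call is swapped here (still flagged TODO) for a plain batched call: when every window carries the same number of tokens, the windows can be folded into the batch dimension and the video and text tokens concatenated along the sequence axis before a single attention call. A rough standalone sketch of that folding, with assumed per-stream shapes of (windows * tokens, heads, head_dim) and plain SDPA standing in for optimized_attention; the sizes are illustrative only:

    import torch
    import torch.nn.functional as F

    num_win, vid_tok, txt_tok, heads, head_dim = 4, 256, 32, 20, 128
    vid_q = vid_k = vid_v = torch.randn(num_win * vid_tok, heads, head_dim)
    txt_q = txt_k = txt_v = torch.randn(num_win * txt_tok, heads, head_dim)

    def fold(t, tok):
        # (windows * tokens, heads, dim) -> (windows, heads, tokens, dim)
        return t.view(num_win, tok, heads, head_dim).transpose(1, 2)

    q = torch.cat([fold(vid_q, vid_tok), fold(txt_q, txt_tok)], dim=2)
    k = torch.cat([fold(vid_k, vid_tok), fold(txt_k, txt_tok)], dim=2)
    v = torch.cat([fold(vid_v, vid_tok), fold(txt_v, txt_tok)], dim=2)
    out = F.scaled_dot_product_attention(q, k, v)   # (windows, heads, vid_tok + txt_tok, dim)
    print(out.shape)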
@@ -847,7 +867,7 @@ class NaMMSRTransformerBlock(nn.Module):
     ):
         super().__init__()
         dim = MMArg(vid_dim, txt_dim)
-        self.attn_norm = MMModule(norm, dim=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights,)
+        self.attn_norm = MMModule(norm, normalized_shape=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights,)
 
         self.attn = NaSwinAttention(
             vid_dim=vid_dim,
@@ -864,7 +884,7 @@ class NaMMSRTransformerBlock(nn.Module):
             window_method=kwargs.pop("window_method", None),
         )
 
-        self.mlp_norm = MMModule(norm, dim=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, vid_only=is_last_layer)
+        self.mlp_norm = MMModule(norm, normalized_shape=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, vid_only=is_last_layer)
         self.mlp = MMModule(
             get_mlp(mlp_type),
             dim=dim,
@@ -1155,6 +1175,7 @@ class NaDiT(nn.Module):
         txt_in_dim = 5120,
         heads = 20,
         head_dim = 128,
+        mm_layers = 10,
         expand_ratio = 4,
         qk_bias = False,
         patch_size = [ 1,2,2 ],
@@ -1163,8 +1184,12 @@ class NaDiT(nn.Module):
         window_method: Optional[Tuple[str]] = None,
         temporal_window_size: int = None,
         temporal_shifted: bool = False,
+        rope_dim = 128,
+        rope_type = "mmrope3d",
+        vid_out_norm: Optional[str] = None,
         **kwargs,
     ):
+        window_method = num_layers // 2 * ["720pwin_by_size_bysize","720pswin_by_size_bysize"]
         txt_dim = vid_dim
         emb_dim = vid_dim * 6
         block_type = ["mmdit_sr"] * num_layers
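The constructor now overrides window_method with an alternating schedule: even-indexed layers use plain 720p windows and odd-indexed layers use the shifted variant, the usual Swin-style pattern. For illustration, assuming num_layers = 4:

    num_layers = 4
    window_method = num_layers // 2 * ["720pwin_by_size_bysize", "720pswin_by_size_bysize"]
    print(window_method)
    # ['720pwin_by_size_bysize', '720pswin_by_size_bysize', '720pwin_by_size_bysize', '720pswin_by_size_bysize']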
@@ -1202,6 +1227,7 @@ class NaDiT(nn.Module):
         if temporal_shifted is None or isinstance(temporal_shifted, bool):
             temporal_shifted = [temporal_shifted] * num_layers
 
+        rope_dim = rope_dim if rope_dim is not None else head_dim // 2
         self.blocks = nn.ModuleList(
             [
                 NaMMSRTransformerBlock(
@@ -1220,10 +1246,16 @@ class NaDiT(nn.Module):
                     shared_qkv=shared_qkv,
                     shared_mlp=shared_mlp,
                     mlp_type=mlp_type,
+                    rope_dim = rope_dim,
                     window=window[i],
                     window_method=window_method[i],
                     temporal_window_size=temporal_window_size[i],
                     temporal_shifted=temporal_shifted[i],
+                    is_last_layer=(i == num_layers - 1),
+                    rope_type = rope_type,
+                    shared_weights=not (
+                        (i < mm_layers) if isinstance(mm_layers, int) else mm_layers[i]
+                    ),
                     **kwargs,
                 )
                 for i in range(num_layers)
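The new shared_weights expression makes the first mm_layers blocks keep separate video/text weights (multimodal blocks) while the remaining blocks share them; mm_layers may also be a per-layer boolean sequence. A quick illustration of how the flag resolves, assuming num_layers = 6 and mm_layers = 2:

    num_layers, mm_layers = 6, 2
    shared = [
        not ((i < mm_layers) if isinstance(mm_layers, int) else mm_layers[i])
        for i in range(num_layers)
    ]
    print(shared)  # [False, False, True, True, True, True]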
@@ -1241,6 +1273,20 @@ class NaDiT(nn.Module):
             "mmdit_stwin_3d_spatial",
         ]
 
+        self.vid_out_norm = None
+        if vid_out_norm is not None:
+            self.vid_out_norm = RMSNorm(
+                normalized_shape=vid_dim,
+                eps=norm_eps,
+                elementwise_affine=True,
+            )
+            self.vid_out_ada = ada(
+                dim=vid_dim,
+                emb_dim=emb_dim,
+                layers=["out"],
+                modes=["in"],
+            )
+
     def forward(
         self,
         x,
@@ -1284,6 +1330,18 @@ class NaDiT(nn.Module):
             cache=cache,
         )
 
+        if self.vid_out_norm:
+            vid = self.vid_out_norm(vid)
+            vid = self.vid_out_ada(
+                vid,
+                emb=emb,
+                layer="out",
+                mode="in",
+                hid_len=cache("vid_len", lambda: vid_shape.prod(-1)),
+                cache=cache,
+                branch_tag="vid",
+            )
+
         vid, vid_shape = self.vid_out(vid, vid_shape, cache)
         vid = unflatten(vid, vid_shape)
         return vid
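The forward pass gains an optional output normalization followed by the ada modulation right before vid_out. A rough sketch of the shape flow, with a generic AdaLN-style scale/shift standing in for the repo's ada(...) module; the dimensions and the exact modulation form are assumptions, not the project's implementation:

    import torch
    import torch.nn as nn

    vid_dim, emb_dim, n_tokens = 1024, 1024 * 6, 256   # emb_dim = vid_dim * 6, as in the constructor
    vid = torch.randn(n_tokens, vid_dim)
    emb = torch.randn(emb_dim)

    out_norm = nn.RMSNorm(vid_dim, eps=1e-6, elementwise_affine=True)  # role of self.vid_out_norm (torch >= 2.4)
    to_scale_shift = nn.Linear(emb_dim, 2 * vid_dim)                   # hypothetical stand-in for the ada "out"/"in" branch
    scale, shift = to_scale_shift(emb).chunk(2, dim=-1)
    vid = out_norm(vid) * (1 + scale) + shift
    print(vid.shape)  # (n_tokens, vid_dim), then handed to self.vid_out as before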