Mirror of https://github.com/comfyanonymous/ComfyUI.git
Synced 2026-01-11 06:40:48 +08:00

Commit 21bc67d7db ("final changes")
Parent 7b2e5ef0af
@@ -526,22 +526,22 @@ class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase):
         max_height = 0
         max_width = 0
         max_txt_len = 0

         for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()):
             max_temporal = max(max_temporal, l + f) # Need up to l+f for temporal
             max_height = max(max_height, h)
             max_width = max(max_width, w)
             max_txt_len = max(max_txt_len, l)

         # Compute frequencies for actual max dimensions needed
         # Add small buffer to improve cache hits across similar batches
         vid_freqs = self.get_axial_freqs(
             min(max_temporal + 16, 1024), # Cap at 1024, add small buffer
             min(max_height + 4, 128), # Cap at 128, add small buffer
             min(max_width + 4, 128) # Cap at 128, add small buffer
         )
         txt_freqs = self.get_axial_freqs(min(max_txt_len + 16, 1024))

         # Now slice as before
         vid_freq_list, txt_freq_list = [], []
         for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()):
@@ -615,6 +615,7 @@ class NaMMAttention(nn.Module):
         rope_type: Optional[str],
         rope_dim: int,
         shared_weights: bool,
+        device, dtype, operations,
         **kwargs,
     ):
         super().__init__()
@@ -624,15 +625,16 @@ class NaMMAttention(nn.Module):
         qkv_dim = inner_dim * 3
         self.head_dim = head_dim
         self.proj_qkv = MMModule(
-            nn.Linear, dim, qkv_dim, bias=qk_bias, shared_weights=shared_weights
+            operations.Linear, dim, qkv_dim, bias=qk_bias, shared_weights=shared_weights, device=device, dtype=dtype
         )
-        self.proj_out = MMModule(nn.Linear, inner_dim, dim, shared_weights=shared_weights)
+        self.proj_out = MMModule(operations.Linear, inner_dim, dim, shared_weights=shared_weights, device=device, dtype=dtype)
         self.norm_q = MMModule(
             qk_norm,
             normalized_shape=head_dim,
             eps=qk_norm_eps,
             elementwise_affine=True,
             shared_weights=shared_weights,
+            device=device, dtype=dtype
         )
         self.norm_k = MMModule(
             qk_norm,
@@ -640,6 +642,7 @@ class NaMMAttention(nn.Module):
             eps=qk_norm_eps,
             elementwise_affine=True,
             shared_weights=shared_weights,
+            device=device, dtype=dtype
         )


@@ -795,11 +798,12 @@ class MLP(nn.Module):
         self,
         dim: int,
         expand_ratio: int,
+        device, dtype, operations
     ):
         super().__init__()
-        self.proj_in = nn.Linear(dim, dim * expand_ratio)
+        self.proj_in = operations.Linear(dim, dim * expand_ratio, device=device, dtype=dtype)
         self.act = nn.GELU("tanh")
-        self.proj_out = nn.Linear(dim * expand_ratio, dim)
+        self.proj_out = operations.Linear(dim * expand_ratio, dim, device=device, dtype=dtype)

     def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
         x = self.proj_in(x)
@@ -814,13 +818,14 @@ class SwiGLUMLP(nn.Module):
         dim: int,
         expand_ratio: int,
         multiple_of: int = 256,
+        device=None, dtype=None, operations=None
     ):
         super().__init__()
         hidden_dim = int(2 * dim * expand_ratio / 3)
         hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
-        self.proj_in_gate = nn.Linear(dim, hidden_dim, bias=False)
-        self.proj_out = nn.Linear(hidden_dim, dim, bias=False)
-        self.proj_in = nn.Linear(dim, hidden_dim, bias=False)
+        self.proj_in_gate = operations.Linear(dim, hidden_dim, bias=False, device=device, dtype=dtype)
+        self.proj_out = operations.Linear(hidden_dim, dim, bias=False, device=device, dtype=dtype)
+        self.proj_in = operations.Linear(dim, hidden_dim, bias=False, device=device, dtype=dtype)

     def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
         x = x.to(next(self.proj_in.parameters()).device)
@@ -855,11 +860,12 @@ class NaMMSRTransformerBlock(nn.Module):
         rope_type: str,
         rope_dim: int,
         is_last_layer: bool,
+        device, dtype, operations,
         **kwargs,
     ):
         super().__init__()
         dim = MMArg(vid_dim, txt_dim)
-        self.attn_norm = MMModule(norm, normalized_shape=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights,)
+        self.attn_norm = MMModule(norm, normalized_shape=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, device=device, dtype=dtype)

         self.attn = NaSwinAttention(
             vid_dim=vid_dim,
@@ -874,17 +880,19 @@ class NaMMSRTransformerBlock(nn.Module):
             shared_weights=shared_weights,
             window=kwargs.pop("window", None),
             window_method=kwargs.pop("window_method", None),
+            device=device, dtype=dtype, operations=operations
         )

-        self.mlp_norm = MMModule(norm, normalized_shape=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, vid_only=is_last_layer)
+        self.mlp_norm = MMModule(norm, normalized_shape=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, vid_only=is_last_layer, device=device, dtype=dtype)
         self.mlp = MMModule(
             get_mlp(mlp_type),
             dim=dim,
             expand_ratio=expand_ratio,
             shared_weights=shared_weights,
-            vid_only=is_last_layer
+            vid_only=is_last_layer,
+            device=device, dtype=dtype, operations=operations
         )
-        self.ada = MMModule(ada, dim=dim, emb_dim=emb_dim, layers=["attn", "mlp"], shared_weights=shared_weights, vid_only=is_last_layer)
+        self.ada = MMModule(ada, dim=dim, emb_dim=emb_dim, layers=["attn", "mlp"], shared_weights=shared_weights, vid_only=is_last_layer, device=device, dtype=dtype)
         self.is_last_layer = is_last_layer

     def forward(
@@ -933,11 +941,12 @@ class PatchOut(nn.Module):
         out_channels: int,
         patch_size: Union[int, Tuple[int, int, int]],
         dim: int,
+        device, dtype, operations
     ):
         super().__init__()
         t, h, w = _triple(patch_size)
         self.patch_size = t, h, w
-        self.proj = nn.Linear(dim, out_channels * t * h * w)
+        self.proj = operations.Linear(dim, out_channels * t * h * w, device=device, dtype=dtype)

     def forward(
         self,
@@ -981,11 +990,12 @@ class PatchIn(nn.Module):
         in_channels: int,
         patch_size: Union[int, Tuple[int, int, int]],
         dim: int,
+        device, dtype, operations
     ):
         super().__init__()
         t, h, w = _triple(patch_size)
         self.patch_size = t, h, w
-        self.proj = nn.Linear(in_channels * t * h * w, dim)
+        self.proj = operations.Linear(in_channels * t * h * w, dim, device=device, dtype=dtype)

     def forward(
         self,
@@ -1033,6 +1043,7 @@ class AdaSingle(nn.Module):
         emb_dim: int,
         layers: List[str],
         modes: List[str] = ["in", "out"],
+        device = None, dtype = None,
     ):
         assert emb_dim == 6 * dim, "AdaSingle requires emb_dim == 6 * dim"
         super().__init__()
@@ -1041,12 +1052,12 @@ class AdaSingle(nn.Module):
         self.layers = layers
         for l in layers:
             if "in" in modes:
-                self.register_parameter(f"{l}_shift", nn.Parameter(torch.randn(dim) / dim**0.5))
+                self.register_parameter(f"{l}_shift", nn.Parameter(torch.randn(dim, device=device, dtype=dtype) / dim**0.5))
                 self.register_parameter(
                     f"{l}_scale", nn.Parameter(torch.randn(dim) / dim**0.5 + 1)
                 )
             if "out" in modes:
-                self.register_parameter(f"{l}_gate", nn.Parameter(torch.randn(dim) / dim**0.5))
+                self.register_parameter(f"{l}_gate", nn.Parameter(torch.randn(dim, device=device, dtype=dtype) / dim**0.5))

     def forward(
         self,
@@ -1096,12 +1107,13 @@ class TimeEmbedding(nn.Module):
         sinusoidal_dim: int,
         hidden_dim: int,
         output_dim: int,
+        device, dtype, operations
     ):
         super().__init__()
         self.sinusoidal_dim = sinusoidal_dim
-        self.proj_in = nn.Linear(sinusoidal_dim, hidden_dim)
-        self.proj_hid = nn.Linear(hidden_dim, hidden_dim)
-        self.proj_out = nn.Linear(hidden_dim, output_dim)
+        self.proj_in = operations.Linear(sinusoidal_dim, hidden_dim, device=device, dtype=dtype)
+        self.proj_hid = operations.Linear(hidden_dim, hidden_dim, device=device, dtype=dtype)
+        self.proj_out = operations.Linear(hidden_dim, output_dim, device=device, dtype=dtype)
         self.act = nn.SiLU()

     def forward(
@@ -1199,6 +1211,7 @@ class NaDiT(nn.Module):
         **kwargs,
     ):
         self.dtype = dtype
+        factory_kwargs = {"device": device, "dtype": dtype}
         window_method = num_layers // 2 * ["720pwin_by_size_bysize","720pswin_by_size_bysize"]
         txt_dim = vid_dim
         emb_dim = vid_dim * 6
@@ -1212,15 +1225,16 @@ class NaDiT(nn.Module):
         elif len(block_type) != num_layers:
             raise ValueError("The ``block_type`` list should equal to ``num_layers``.")
         super().__init__()
-        self.register_buffer("positive_conditioning", torch.empty((58, 5120)))
-        self.register_buffer("negative_conditioning", torch.empty((64, 5120)))
+        self.register_buffer("positive_conditioning", torch.empty((58, 5120), device=device, dtype=dtype))
+        self.register_buffer("negative_conditioning", torch.empty((64, 5120), device=device, dtype=dtype))
         self.vid_in = NaPatchIn(
             in_channels=vid_in_channels,
             patch_size=patch_size,
             dim=vid_dim,
+            device=device, dtype=dtype, operations=operations
         )
         self.txt_in = (
-            nn.Linear(txt_in_dim, txt_dim)
+            operations.Linear(txt_in_dim, txt_dim, **factory_kwargs)
             if txt_in_dim and txt_in_dim != txt_dim
             else nn.Identity()
         )
@@ -1228,6 +1242,7 @@ class NaDiT(nn.Module):
             sinusoidal_dim=256,
             hidden_dim=max(vid_dim, txt_dim),
             output_dim=emb_dim,
+            device=device, dtype=dtype, operations=operations
         )

         if window is None or isinstance(window[0], int):
@@ -1268,7 +1283,9 @@ class NaDiT(nn.Module):
                 shared_weights=not (
                     (i < mm_layers) if isinstance(mm_layers, int) else mm_layers[i]
                 ),
+                operations = operations,
                 **kwargs,
+                **factory_kwargs
             )
             for i in range(num_layers)
         ]
@@ -1277,6 +1294,7 @@ class NaDiT(nn.Module):
             out_channels=vid_out_channels,
             patch_size=patch_size,
             dim=vid_dim,
+            device=device, dtype=dtype, operations=operations
         )

         self.need_txt_repeat = block_type[0] in [
@@ -1291,12 +1309,14 @@ class NaDiT(nn.Module):
             normalized_shape=vid_dim,
             eps=norm_eps,
             elementwise_affine=True,
+            device=device, dtype=dtype
         )
         self.vid_out_ada = ada(
             dim=vid_dim,
             emb_dim=emb_dim,
             layers=["out"],
             modes=["in"],
+            device=device, dtype=dtype
         )

         self.stop_cfg_index = -1
@@ -16,6 +16,9 @@ import math
 from enum import Enum
 from comfy.ops import NVIDIA_MEMORY_CONV_BUG_WORKAROUND

+import comfy.ops
+ops = comfy.ops.disable_weight_init
+
 _NORM_LIMIT = float("inf")


@@ -89,9 +92,9 @@ class SpatialNorm(nn.Module):
         zq_channels: int,
     ):
         super().__init__()
-        self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
-        self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
-        self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
+        self.norm_layer = ops.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
+        self.conv_y = ops.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
+        self.conv_b = ops.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)

     def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
         f_size = f.shape[-2:]
@@ -164,7 +167,7 @@ class Attention(nn.Module):
         self.only_cross_attention = only_cross_attention

         if norm_num_groups is not None:
-            self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
+            self.group_norm = ops.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
         else:
             self.group_norm = None

@@ -177,22 +180,22 @@ class Attention(nn.Module):
             self.norm_k = None

         self.norm_cross = None
-        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
+        self.to_q = ops.Linear(query_dim, self.inner_dim, bias=bias)

         if not self.only_cross_attention:
             # only relevant for the `AddedKVProcessor` classes
-            self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
-            self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+            self.to_k = ops.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+            self.to_v = ops.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
         else:
             self.to_k = None
             self.to_v = None

         self.added_proj_bias = added_proj_bias
         if self.added_kv_proj_dim is not None:
-            self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
-            self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
+            self.add_k_proj = ops.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
+            self.add_v_proj = ops.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
             if self.context_pre_only is not None:
-                self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+                self.add_q_proj = ops.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
         else:
             self.add_q_proj = None
             self.add_k_proj = None
@@ -200,13 +203,13 @@ class Attention(nn.Module):

         if not self.pre_only:
             self.to_out = nn.ModuleList([])
-            self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+            self.to_out.append(ops.Linear(self.inner_dim, self.out_dim, bias=out_bias))
             self.to_out.append(nn.Dropout(dropout))
         else:
             self.to_out = None

         if self.context_pre_only is not None and not self.context_pre_only:
-            self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias)
+            self.to_add_out = ops.Linear(self.inner_dim, self.out_context_dim, bias=out_bias)
         else:
             self.to_add_out = None

@@ -325,7 +328,7 @@ def modify_state_dict(layer, state_dict, prefix, inflate_weight_fn, inflate_bias

 def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor:
     input_dtype = x.dtype
-    if isinstance(norm_layer, (nn.LayerNorm, nn.RMSNorm)):
+    if isinstance(norm_layer, (ops.LayerNorm, ops.RMSNorm)):
         if x.ndim == 4:
             x = rearrange(x, "b c h w -> b h w c")
             x = norm_layer(x)
@@ -336,14 +339,14 @@ def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor:
             x = norm_layer(x)
             x = rearrange(x, "b t h w c -> b c t h w")
         return x.to(input_dtype)
-    if isinstance(norm_layer, (nn.GroupNorm, nn.BatchNorm2d, nn.SyncBatchNorm)):
+    if isinstance(norm_layer, (ops.GroupNorm, nn.BatchNorm2d, nn.SyncBatchNorm)):
         if x.ndim <= 4:
             return norm_layer(x).to(input_dtype)
         if x.ndim == 5:
             t = x.size(2)
             x = rearrange(x, "b c t h w -> (b t) c h w")
             memory_occupy = x.numel() * x.element_size() / 1024**3
-            if isinstance(norm_layer, nn.GroupNorm) and memory_occupy > float("inf"): # TODO: this may be set dynamically from the vae
+            if isinstance(norm_layer, ops.GroupNorm) and memory_occupy > float("inf"): # TODO: this may be set dynamically from the vae
                 num_chunks = min(4 if x.element_size() == 2 else 2, norm_layer.num_groups)
                 assert norm_layer.num_groups % num_chunks == 0
                 num_groups_per_chunk = norm_layer.num_groups // num_chunks
@@ -428,7 +431,7 @@ def cache_send_recv(tensor, cache_size, times, memory=None):

     return recv_buffer

-class InflatedCausalConv3d(torch.nn.Conv3d):
+class InflatedCausalConv3d(ops.Conv3d):
     def __init__(
         self,
         *args,
@@ -677,17 +680,16 @@ class Upsample3D(nn.Module):
         if use_conv_transpose:
             if kernel_size is None:
                 kernel_size = 4
-            self.conv = nn.ConvTranspose2d(
+            self.conv = ops.ConvTranspose2d(
                 channels, self.out_channels, kernel_size=kernel_size, stride=2, padding=padding, bias=bias
             )
         elif use_conv:
             if kernel_size is None:
                 kernel_size = 3
-            self.conv = nn.Conv2d(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias)
+            self.conv = ops.Conv2d(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias)

         conv = self.conv if self.name == "conv" else self.Conv2d_0

-        assert type(conv) is not nn.ConvTranspose2d
         # Note: lora_layer is not passed into constructor in the original implementation.
         # So we make a simplification.
         conv = InflatedCausalConv3d(
@@ -708,7 +710,7 @@ class Upsample3D(nn.Module):
         # [Override] MAGViT v2 implementation
         if not self.interpolate:
             upscale_ratio = (self.spatial_ratio**2) * self.temporal_ratio
-            self.upscale_conv = nn.Conv3d(
+            self.upscale_conv = ops.Conv3d(
                 self.channels, self.channels * upscale_ratio, kernel_size=1, padding=0
             )
             identity = (
@@ -892,13 +894,13 @@ class ResnetBlock3D(nn.Module):
         self.skip_time_act = skip_time_act
         self.nonlinearity = nn.SiLU()
         if temb_channels is not None:
-            self.time_emb_proj = nn.Linear(temb_channels, out_channels)
+            self.time_emb_proj = ops.Linear(temb_channels, out_channels)
         else:
             self.time_emb_proj = None
-        self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
+        self.norm1 = ops.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
         if groups_out is None:
             groups_out = groups
-        self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
+        self.norm2 = ops.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
         self.use_in_shortcut = self.in_channels != out_channels
         self.dropout = torch.nn.Dropout(dropout)
         self.conv1 = InflatedCausalConv3d(
@@ -1342,7 +1344,7 @@ class Encoder3D(nn.Module):

             self.conv_extra_cond.append(
                 zero_module(
-                    nn.Conv3d(extra_cond_dim, output_channel, kernel_size=1, stride=1, padding=0)
+                    ops.Conv3d(extra_cond_dim, output_channel, kernel_size=1, stride=1, padding=0)
                 )
                 if self.extra_cond_dim is not None and self.extra_cond_dim > 0
                 else None
@@ -1364,7 +1366,7 @@ class Encoder3D(nn.Module):
         )

         # out
-        self.conv_norm_out = nn.GroupNorm(
+        self.conv_norm_out = ops.GroupNorm(
             num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6
         )
         self.conv_act = nn.SiLU()
@@ -1512,7 +1514,7 @@ class Decoder3D(nn.Module):
         if norm_type == "spatial":
             self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
         else:
-            self.conv_norm_out = nn.GroupNorm(
+            self.conv_norm_out = ops.GroupNorm(
                 num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6
             )
         self.conv_act = nn.SiLU()
@@ -1553,9 +1555,9 @@ def wavelet_blur(image: Tensor, radius):
     max_safe_radius = max(1, min(image.shape[-2:]) // 8)
     if radius > max_safe_radius:
         radius = max_safe_radius

     num_channels = image.shape[1]

     kernel_vals = [
         [0.0625, 0.125, 0.0625],
         [0.125, 0.25, 0.125],
@@ -1563,21 +1565,21 @@ def wavelet_blur(image: Tensor, radius):
     ]
     kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
     kernel = kernel[None, None].repeat(num_channels, 1, 1, 1)

     image = safe_pad_operation(image, (radius, radius, radius, radius), mode='replicate')
     output = F.conv2d(image, kernel, groups=num_channels, dilation=radius)

     return output

 def wavelet_decomposition(image: Tensor, levels: int = 5):
     high_freq = torch.zeros_like(image)

     for i in range(levels):
         radius = 2 ** i
         low_freq = wavelet_blur(image, radius)
         high_freq.add_(image).sub_(low_freq)
         image = low_freq

     return high_freq, low_freq

 def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor:
@@ -1587,19 +1589,19 @@ def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor:
     if len(content_feat.shape) >= 3:
         # safe_interpolate_operation handles FP16 conversion automatically
         style_feat = safe_interpolate_operation(
             style_feat,
             size=content_feat.shape[-2:],
             mode='bilinear',
             align_corners=False
         )

     # Decompose both features into frequency components
     content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
     del content_low_freq # Free memory immediately

     style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
     del style_high_freq # Free memory immediately

     if content_high_freq.shape != style_low_freq.shape:
         style_low_freq = safe_interpolate_operation(
             style_low_freq,
@@ -1607,9 +1609,9 @@ def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor:
             mode='bilinear',
             align_corners=False
         )

     content_high_freq.add_(style_low_freq)

     return content_high_freq.clamp_(-1.0, 1.0)

 class VideoAutoencoderKL(nn.Module):
@@ -1894,6 +1896,7 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):

         x = rearrange(x, "b c t h w -> (b t) c h w")

+        input = input.to(x.device)
         x = wavelet_reconstruction(x, input)

         x = x.unsqueeze(0)
@@ -24,7 +24,7 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
         x = x.unsqueeze(2)

     b, c, d, h, w = x.shape

     sf_s = getattr(vae_model, "spatial_downsample_factor", 8)
     sf_t = getattr(vae_model, "temporal_downsample_factor", 4)

@@ -39,7 +39,7 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
     ti_w = max(1, tile_size[1] // sf_s)
     ov_h = max(0, tile_overlap[0] // sf_s)
     ov_w = max(0, tile_overlap[1] // sf_s)

     target_d = d * sf_t
     target_h = h * sf_s
     target_w = w * sf_s
@@ -47,15 +47,14 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
     stride_h = max(1, ti_h - ov_h)
     stride_w = max(1, ti_w - ov_w)

-    storage_device = torch.device("cpu")
+    storage_device = vae_model.device

     result = None
     count = None

     def run_temporal_chunks(spatial_tile):
         chunk_results = []
         t_dim_size = spatial_tile.shape[2]

         if encode:
             input_chunk = temporal_size
         else:
@@ -63,18 +62,18 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora

         for i in range(0, t_dim_size, input_chunk):
             t_chunk = spatial_tile[:, :, i : i + input_chunk, :, :]

             if encode:
-                out = vae_model.slicing_encode(t_chunk)
+                out = vae_model.encode(t_chunk)
             else:
-                out = vae_model.slicing_decode(t_chunk)
+                out = vae_model.decode_(t_chunk)

             if isinstance(out, (tuple, list)): out = out[0]

             if out.ndim == 4: out = out.unsqueeze(2)

             chunk_results.append(out.to(storage_device))

         return torch.cat(chunk_results, dim=2)

     ramp_cache = {}
@@ -89,7 +88,7 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora

     for y_idx in range(0, h, stride_h):
         y_end = min(y_idx + ti_h, h)

         for x_idx in range(0, w, stride_w):
             x_end = min(x_idx + ti_w, w)

@@ -131,9 +130,9 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora

             valid_d = min(tile_out.shape[2], result.shape[2])
             tile_out = tile_out[:, :, :valid_d, :, :]

             tile_out.mul_(final_weight)

             result[:, :, :valid_d, ys:ye, xs:xe] += tile_out
             count[:, :, :, ys:ye, xs:xe] += final_weight

@@ -141,7 +140,7 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
             bar.update(1)

     result.div_(count.clamp(min=1e-6))

     if result.device != x.device:
         result = result.to(x.device).to(x.dtype)

@@ -150,6 +149,18 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora

     return result

+def clear_vae_memory(vae_model):
+    for module in vae_model.modules():
+        if hasattr(module, "memory"):
+            module.memory = None
+    if hasattr(vae_model, "original_image_video"):
+        del vae_model.original_image_video
+
+    if hasattr(vae_model, "tiled_args"):
+        del vae_model.tiled_args
+    gc.collect()
+    torch.cuda.empty_cache()
+
 def expand_dims(tensor, ndim):
     shape = tensor.shape + (1,) * (ndim - tensor.ndim)
     return tensor.reshape(shape)
@@ -261,9 +272,9 @@ class SeedVR2InputProcessing(io.ComfyNode):
                 io.Vae.Input("vae"),
                 io.Int.Input("resolution_height", default = 1280, min = 120), # //
                 io.Int.Input("resolution_width", default = 720, min = 120), # just non-zero value
-                io.Int.Input("spatial_tile_size", default = 512, min = -1),
-                io.Int.Input("temporal_tile_size", default = 8, min = -1),
-                io.Int.Input("spatial_overlap", default = 64, min = -1),
+                io.Int.Input("spatial_tile_size", default = 512, min = 1),
+                io.Int.Input("temporal_tile_size", default = 8, min = 1),
+                io.Int.Input("spatial_overlap", default = 64, min = 1),
                 io.Boolean.Input("enable_tiling", default=False)
             ],
             outputs = [
@@ -305,7 +316,6 @@ class SeedVR2InputProcessing(io.ComfyNode):
         images = rearrange(images, "b t c h w -> b c t h w")
         images = images.to(device)
         vae_model = vae_model.to(device)
-        vae_model.original_image_video = images

         args = {"tile_size": (spatial_tile_size, spatial_tile_size), "tile_overlap": (spatial_overlap, spatial_overlap),
                 "temporal_size":temporal_tile_size}
@@ -314,11 +324,14 @@ class SeedVR2InputProcessing(io.ComfyNode):
         else:
             latent = vae_model.encode(images, orig_dims = [o_h, o_w])[0]

+        clear_vae_memory(vae_model)
+        #images = images.to(offload_device)
+        #vae_model = vae_model.to(offload_device)

+        vae_model.img_dims = [o_h, o_w]
         args["enable_tiling"] = enable_tiling
         vae_model.tiled_args = args
+        vae_model.original_image_video = images
-        vae_model = vae_model.to(offload_device)
-        vae_model.img_dims = [o_h, o_w]

         latent = latent.unsqueeze(2) if latent.ndim == 4 else latent
         latent = rearrange(latent, "b c ... -> b ... c")
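The recurring edit across these hunks swaps direct nn.Linear / nn.Conv2d / nn.GroupNorm construction for an injected "operations" namespace (the VAE file binds ops = comfy.ops.disable_weight_init) and threads device/dtype through every constructor, so parameters are created on the target device and dtype. The following is only a minimal sketch of that pattern, not code from the commit: SketchMLP and the plain torch.nn fallback are illustrative assumptions.

    import torch
    import torch.nn as nn

    class SketchMLP(nn.Module):
        # "SketchMLP" is illustrative; the commit applies the same idea to its own modules.
        def __init__(self, dim, expand_ratio, device=None, dtype=None, operations=None):
            super().__init__()
            # Assumed fallback: use torch.nn when no operations namespace is injected.
            ops = operations if operations is not None else nn
            # device/dtype factory kwargs are forwarded so weights are allocated where they will be used.
            self.proj_in = ops.Linear(dim, dim * expand_ratio, device=device, dtype=dtype)
            self.act = nn.GELU("tanh")
            self.proj_out = ops.Linear(dim * expand_ratio, dim, device=device, dtype=dtype)

        def forward(self, x):
            return self.proj_out(self.act(self.proj_in(x)))

    mlp = SketchMLP(dim=8, expand_ratio=4, device="cpu", dtype=torch.float32)
    print(mlp(torch.randn(2, 8)).shape)  # torch.Size([2, 8])

When the injected namespace is something like comfy.ops.disable_weight_init, the same calls also skip default weight initialization, which keeps model construction cheap before checkpoint weights are loaded.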