final changes

Yousef Rafat 2025-12-26 02:08:59 +02:00
parent 7b2e5ef0af
commit 21bc67d7db
3 changed files with 125 additions and 89 deletions

View File

@@ -615,6 +615,7 @@ class NaMMAttention(nn.Module):
 rope_type: Optional[str],
 rope_dim: int,
 shared_weights: bool,
+device, dtype, operations,
 **kwargs,
 ):
 super().__init__()
@@ -624,15 +625,16 @@ class NaMMAttention(nn.Module):
 qkv_dim = inner_dim * 3
 self.head_dim = head_dim
 self.proj_qkv = MMModule(
-nn.Linear, dim, qkv_dim, bias=qk_bias, shared_weights=shared_weights
+operations.Linear, dim, qkv_dim, bias=qk_bias, shared_weights=shared_weights, device=device, dtype=dtype
 )
-self.proj_out = MMModule(nn.Linear, inner_dim, dim, shared_weights=shared_weights)
+self.proj_out = MMModule(operations.Linear, inner_dim, dim, shared_weights=shared_weights, device=device, dtype=dtype)
 self.norm_q = MMModule(
 qk_norm,
 normalized_shape=head_dim,
 eps=qk_norm_eps,
 elementwise_affine=True,
 shared_weights=shared_weights,
+device=device, dtype=dtype
 )
 self.norm_k = MMModule(
 qk_norm,
@@ -640,6 +642,7 @@ class NaMMAttention(nn.Module):
 eps=qk_norm_eps,
 elementwise_affine=True,
 shared_weights=shared_weights,
+device=device, dtype=dtype
 )
@@ -795,11 +798,12 @@ class MLP(nn.Module):
 self,
 dim: int,
 expand_ratio: int,
+device, dtype, operations
 ):
 super().__init__()
-self.proj_in = nn.Linear(dim, dim * expand_ratio)
+self.proj_in = operations.Linear(dim, dim * expand_ratio, device=device, dtype=dtype)
 self.act = nn.GELU("tanh")
-self.proj_out = nn.Linear(dim * expand_ratio, dim)
+self.proj_out = operations.Linear(dim * expand_ratio, dim, device=device, dtype=dtype)
 def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
 x = self.proj_in(x)
@@ -814,13 +818,14 @@ class SwiGLUMLP(nn.Module):
 dim: int,
 expand_ratio: int,
 multiple_of: int = 256,
+device=None, dtype=None, operations=None
 ):
 super().__init__()
 hidden_dim = int(2 * dim * expand_ratio / 3)
 hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
-self.proj_in_gate = nn.Linear(dim, hidden_dim, bias=False)
+self.proj_in_gate = operations.Linear(dim, hidden_dim, bias=False, device=device, dtype=dtype)
-self.proj_out = nn.Linear(hidden_dim, dim, bias=False)
+self.proj_out = operations.Linear(hidden_dim, dim, bias=False, device=device, dtype=dtype)
-self.proj_in = nn.Linear(dim, hidden_dim, bias=False)
+self.proj_in = operations.Linear(dim, hidden_dim, bias=False, device=device, dtype=dtype)
 def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
 x = x.to(next(self.proj_in.parameters()).device)
@@ -855,11 +860,12 @@ class NaMMSRTransformerBlock(nn.Module):
 rope_type: str,
 rope_dim: int,
 is_last_layer: bool,
+device, dtype, operations,
 **kwargs,
 ):
 super().__init__()
 dim = MMArg(vid_dim, txt_dim)
-self.attn_norm = MMModule(norm, normalized_shape=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights,)
+self.attn_norm = MMModule(norm, normalized_shape=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, device=device, dtype=dtype)
 self.attn = NaSwinAttention(
 vid_dim=vid_dim,
@@ -874,17 +880,19 @@ class NaMMSRTransformerBlock(nn.Module):
 shared_weights=shared_weights,
 window=kwargs.pop("window", None),
 window_method=kwargs.pop("window_method", None),
+device=device, dtype=dtype, operations=operations
 )
-self.mlp_norm = MMModule(norm, normalized_shape=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, vid_only=is_last_layer)
+self.mlp_norm = MMModule(norm, normalized_shape=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, vid_only=is_last_layer, device=device, dtype=dtype)
 self.mlp = MMModule(
 get_mlp(mlp_type),
 dim=dim,
 expand_ratio=expand_ratio,
 shared_weights=shared_weights,
-vid_only=is_last_layer
+vid_only=is_last_layer,
+device=device, dtype=dtype, operations=operations
 )
-self.ada = MMModule(ada, dim=dim, emb_dim=emb_dim, layers=["attn", "mlp"], shared_weights=shared_weights, vid_only=is_last_layer)
+self.ada = MMModule(ada, dim=dim, emb_dim=emb_dim, layers=["attn", "mlp"], shared_weights=shared_weights, vid_only=is_last_layer, device=device, dtype=dtype)
 self.is_last_layer = is_last_layer
 def forward(
@@ -933,11 +941,12 @@ class PatchOut(nn.Module):
 out_channels: int,
 patch_size: Union[int, Tuple[int, int, int]],
 dim: int,
+device, dtype, operations
 ):
 super().__init__()
 t, h, w = _triple(patch_size)
 self.patch_size = t, h, w
-self.proj = nn.Linear(dim, out_channels * t * h * w)
+self.proj = operations.Linear(dim, out_channels * t * h * w, device=device, dtype=dtype)
 def forward(
 self,
@@ -981,11 +990,12 @@ class PatchIn(nn.Module):
 in_channels: int,
 patch_size: Union[int, Tuple[int, int, int]],
 dim: int,
+device, dtype, operations
 ):
 super().__init__()
 t, h, w = _triple(patch_size)
 self.patch_size = t, h, w
-self.proj = nn.Linear(in_channels * t * h * w, dim)
+self.proj = operations.Linear(in_channels * t * h * w, dim, device=device, dtype=dtype)
 def forward(
 self,
@@ -1033,6 +1043,7 @@ class AdaSingle(nn.Module):
 emb_dim: int,
 layers: List[str],
 modes: List[str] = ["in", "out"],
+device = None, dtype = None,
 ):
 assert emb_dim == 6 * dim, "AdaSingle requires emb_dim == 6 * dim"
 super().__init__()
@@ -1041,12 +1052,12 @@ class AdaSingle(nn.Module):
 self.layers = layers
 for l in layers:
 if "in" in modes:
-self.register_parameter(f"{l}_shift", nn.Parameter(torch.randn(dim) / dim**0.5))
+self.register_parameter(f"{l}_shift", nn.Parameter(torch.randn(dim, device=device, dtype=dtype) / dim**0.5))
 self.register_parameter(
 f"{l}_scale", nn.Parameter(torch.randn(dim) / dim**0.5 + 1)
 )
 if "out" in modes:
-self.register_parameter(f"{l}_gate", nn.Parameter(torch.randn(dim) / dim**0.5))
+self.register_parameter(f"{l}_gate", nn.Parameter(torch.randn(dim, device=device, dtype=dtype) / dim**0.5))
 def forward(
 self,
@@ -1096,12 +1107,13 @@ class TimeEmbedding(nn.Module):
 sinusoidal_dim: int,
 hidden_dim: int,
 output_dim: int,
+device, dtype, operations
 ):
 super().__init__()
 self.sinusoidal_dim = sinusoidal_dim
-self.proj_in = nn.Linear(sinusoidal_dim, hidden_dim)
+self.proj_in = operations.Linear(sinusoidal_dim, hidden_dim, device=device, dtype=dtype)
-self.proj_hid = nn.Linear(hidden_dim, hidden_dim)
+self.proj_hid = operations.Linear(hidden_dim, hidden_dim, device=device, dtype=dtype)
-self.proj_out = nn.Linear(hidden_dim, output_dim)
+self.proj_out = operations.Linear(hidden_dim, output_dim, device=device, dtype=dtype)
 self.act = nn.SiLU()
 def forward(
@@ -1199,6 +1211,7 @@ class NaDiT(nn.Module):
 **kwargs,
 ):
 self.dtype = dtype
+factory_kwargs = {"device": device, "dtype": dtype}
 window_method = num_layers // 2 * ["720pwin_by_size_bysize","720pswin_by_size_bysize"]
 txt_dim = vid_dim
 emb_dim = vid_dim * 6
@@ -1212,15 +1225,16 @@ class NaDiT(nn.Module):
 elif len(block_type) != num_layers:
 raise ValueError("The ``block_type`` list should equal to ``num_layers``.")
 super().__init__()
-self.register_buffer("positive_conditioning", torch.empty((58, 5120)))
+self.register_buffer("positive_conditioning", torch.empty((58, 5120), device=device, dtype=dtype))
-self.register_buffer("negative_conditioning", torch.empty((64, 5120)))
+self.register_buffer("negative_conditioning", torch.empty((64, 5120), device=device, dtype=dtype))
 self.vid_in = NaPatchIn(
 in_channels=vid_in_channels,
 patch_size=patch_size,
 dim=vid_dim,
+device=device, dtype=dtype, operations=operations
 )
 self.txt_in = (
-nn.Linear(txt_in_dim, txt_dim)
+operations.Linear(txt_in_dim, txt_dim, **factory_kwargs)
 if txt_in_dim and txt_in_dim != txt_dim
 else nn.Identity()
 )
@@ -1228,6 +1242,7 @@ class NaDiT(nn.Module):
 sinusoidal_dim=256,
 hidden_dim=max(vid_dim, txt_dim),
 output_dim=emb_dim,
+device=device, dtype=dtype, operations=operations
 )
 if window is None or isinstance(window[0], int):
@@ -1268,7 +1283,9 @@ class NaDiT(nn.Module):
 shared_weights=not (
 (i < mm_layers) if isinstance(mm_layers, int) else mm_layers[i]
 ),
+operations = operations,
 **kwargs,
+**factory_kwargs
 )
 for i in range(num_layers)
 ]
@@ -1277,6 +1294,7 @@ class NaDiT(nn.Module):
 out_channels=vid_out_channels,
 patch_size=patch_size,
 dim=vid_dim,
+device=device, dtype=dtype, operations=operations
 )
 self.need_txt_repeat = block_type[0] in [
@@ -1291,12 +1309,14 @@ class NaDiT(nn.Module):
 normalized_shape=vid_dim,
 eps=norm_eps,
 elementwise_affine=True,
+device=device, dtype=dtype
 )
 self.vid_out_ada = ada(
 dim=vid_dim,
 emb_dim=emb_dim,
 layers=["out"],
 modes=["in"],
+device=device, dtype=dtype
 )
 self.stop_cfg_index = -1
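
Every constructor in the hunks above gains device, dtype, and an operations argument, so the caller decides which layer classes are instantiated and where their weights are allocated. A minimal sketch of that convention, with a made-up TinyBlock module standing in for the real classes (illustrative only, not code from this repository):

import torch
import torch.nn as nn

class TinyBlock(nn.Module):
    """Hypothetical module following the device/dtype/operations convention used above."""
    def __init__(self, dim, device=None, dtype=None, operations=nn):
        super().__init__()
        # operations.Linear may be torch.nn.Linear or a drop-in class that skips weight init
        self.proj = operations.Linear(dim, dim, device=device, dtype=dtype)

    def forward(self, x):
        return self.proj(x)

# Plain torch.nn works as the default operations namespace.
block = TinyBlock(8, device="cpu", dtype=torch.float32, operations=nn)
print(block(torch.randn(2, 8)).shape)  # torch.Size([2, 8])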

View File

@@ -16,6 +16,9 @@ import math
 from enum import Enum
 from comfy.ops import NVIDIA_MEMORY_CONV_BUG_WORKAROUND
+import comfy.ops
+ops = comfy.ops.disable_weight_init
 _NORM_LIMIT = float("inf")
@@ -89,9 +92,9 @@ class SpatialNorm(nn.Module):
 zq_channels: int,
 ):
 super().__init__()
-self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
+self.norm_layer = ops.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
-self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
+self.conv_y = ops.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
-self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
+self.conv_b = ops.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
 def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
 f_size = f.shape[-2:]
@@ -164,7 +167,7 @@ class Attention(nn.Module):
 self.only_cross_attention = only_cross_attention
 if norm_num_groups is not None:
-self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
+self.group_norm = ops.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
 else:
 self.group_norm = None
@@ -177,22 +180,22 @@ class Attention(nn.Module):
 self.norm_k = None
 self.norm_cross = None
-self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
+self.to_q = ops.Linear(query_dim, self.inner_dim, bias=bias)
 if not self.only_cross_attention:
 # only relevant for the `AddedKVProcessor` classes
-self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+self.to_k = ops.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
-self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+self.to_v = ops.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
 else:
 self.to_k = None
 self.to_v = None
 self.added_proj_bias = added_proj_bias
 if self.added_kv_proj_dim is not None:
-self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
+self.add_k_proj = ops.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
-self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
+self.add_v_proj = ops.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
 if self.context_pre_only is not None:
-self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+self.add_q_proj = ops.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
 else:
 self.add_q_proj = None
 self.add_k_proj = None
@@ -200,13 +203,13 @@ class Attention(nn.Module):
 if not self.pre_only:
 self.to_out = nn.ModuleList([])
-self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+self.to_out.append(ops.Linear(self.inner_dim, self.out_dim, bias=out_bias))
 self.to_out.append(nn.Dropout(dropout))
 else:
 self.to_out = None
 if self.context_pre_only is not None and not self.context_pre_only:
-self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias)
+self.to_add_out = ops.Linear(self.inner_dim, self.out_context_dim, bias=out_bias)
 else:
 self.to_add_out = None
@@ -325,7 +328,7 @@ def modify_state_dict(layer, state_dict, prefix, inflate_weight_fn, inflate_bias
 def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor:
 input_dtype = x.dtype
-if isinstance(norm_layer, (nn.LayerNorm, nn.RMSNorm)):
+if isinstance(norm_layer, (ops.LayerNorm, ops.RMSNorm)):
 if x.ndim == 4:
 x = rearrange(x, "b c h w -> b h w c")
 x = norm_layer(x)
@@ -336,14 +339,14 @@ def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor:
 x = norm_layer(x)
 x = rearrange(x, "b t h w c -> b c t h w")
 return x.to(input_dtype)
-if isinstance(norm_layer, (nn.GroupNorm, nn.BatchNorm2d, nn.SyncBatchNorm)):
+if isinstance(norm_layer, (ops.GroupNorm, nn.BatchNorm2d, nn.SyncBatchNorm)):
 if x.ndim <= 4:
 return norm_layer(x).to(input_dtype)
 if x.ndim == 5:
 t = x.size(2)
 x = rearrange(x, "b c t h w -> (b t) c h w")
 memory_occupy = x.numel() * x.element_size() / 1024**3
-if isinstance(norm_layer, nn.GroupNorm) and memory_occupy > float("inf"): # TODO: this may be set dynamically from the vae
+if isinstance(norm_layer, ops.GroupNorm) and memory_occupy > float("inf"): # TODO: this may be set dynamically from the vae
 num_chunks = min(4 if x.element_size() == 2 else 2, norm_layer.num_groups)
 assert norm_layer.num_groups % num_chunks == 0
 num_groups_per_chunk = norm_layer.num_groups // num_chunks
@@ -428,7 +431,7 @@ def cache_send_recv(tensor, cache_size, times, memory=None):
 return recv_buffer
-class InflatedCausalConv3d(torch.nn.Conv3d):
+class InflatedCausalConv3d(ops.Conv3d):
 def __init__(
 self,
 *args,
@@ -677,17 +680,16 @@ class Upsample3D(nn.Module):
 if use_conv_transpose:
 if kernel_size is None:
 kernel_size = 4
-self.conv = nn.ConvTranspose2d(
+self.conv = ops.ConvTranspose2d(
 channels, self.out_channels, kernel_size=kernel_size, stride=2, padding=padding, bias=bias
 )
 elif use_conv:
 if kernel_size is None:
 kernel_size = 3
-self.conv = nn.Conv2d(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias)
+self.conv = ops.Conv2d(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias)
 conv = self.conv if self.name == "conv" else self.Conv2d_0
-assert type(conv) is not nn.ConvTranspose2d
 # Note: lora_layer is not passed into constructor in the original implementation.
 # So we make a simplification.
 conv = InflatedCausalConv3d(
@@ -708,7 +710,7 @@ class Upsample3D(nn.Module):
 # [Override] MAGViT v2 implementation
 if not self.interpolate:
 upscale_ratio = (self.spatial_ratio**2) * self.temporal_ratio
-self.upscale_conv = nn.Conv3d(
+self.upscale_conv = ops.Conv3d(
 self.channels, self.channels * upscale_ratio, kernel_size=1, padding=0
 )
 identity = (
@@ -892,13 +894,13 @@ class ResnetBlock3D(nn.Module):
 self.skip_time_act = skip_time_act
 self.nonlinearity = nn.SiLU()
 if temb_channels is not None:
-self.time_emb_proj = nn.Linear(temb_channels, out_channels)
+self.time_emb_proj = ops.Linear(temb_channels, out_channels)
 else:
 self.time_emb_proj = None
-self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
+self.norm1 = ops.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
 if groups_out is None:
 groups_out = groups
-self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
+self.norm2 = ops.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
 self.use_in_shortcut = self.in_channels != out_channels
 self.dropout = torch.nn.Dropout(dropout)
 self.conv1 = InflatedCausalConv3d(
@@ -1342,7 +1344,7 @@ class Encoder3D(nn.Module):
 self.conv_extra_cond.append(
 zero_module(
-nn.Conv3d(extra_cond_dim, output_channel, kernel_size=1, stride=1, padding=0)
+ops.Conv3d(extra_cond_dim, output_channel, kernel_size=1, stride=1, padding=0)
 )
 if self.extra_cond_dim is not None and self.extra_cond_dim > 0
 else None
@@ -1364,7 +1366,7 @@ class Encoder3D(nn.Module):
 )
 # out
-self.conv_norm_out = nn.GroupNorm(
+self.conv_norm_out = ops.GroupNorm(
 num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6
 )
 self.conv_act = nn.SiLU()
@@ -1512,7 +1514,7 @@ class Decoder3D(nn.Module):
 if norm_type == "spatial":
 self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
 else:
-self.conv_norm_out = nn.GroupNorm(
+self.conv_norm_out = ops.GroupNorm(
 num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6
 )
 self.conv_act = nn.SiLU()
@@ -1894,6 +1896,7 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
 x = rearrange(x, "b c t h w -> (b t) c h w")
+input = input.to(x.device)
 x = wavelet_reconstruction(x, input)
 x = x.unsqueeze(0)
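
The ops = comfy.ops.disable_weight_init namespace added at the top of this file provides drop-in replacements for the swapped nn.* layers whose random parameter initialization is skipped, since every weight is overwritten by the checkpoint anyway. A rough, self-contained sketch of that idea (not ComfyUI's actual implementation; ops_sketch and the SkipInit* classes are stand-in names):

import torch.nn as nn

class SkipInitLinear(nn.Linear):
    def reset_parameters(self):
        # no-op: leave the freshly allocated tensors alone; the state dict fills them in
        return None

class SkipInitConv2d(nn.Conv2d):
    def reset_parameters(self):
        return None

class ops_sketch:  # stand-in for a namespace like comfy.ops.disable_weight_init
    Linear = SkipInitLinear
    Conv2d = SkipInitConv2d

# Used the same way as the swapped calls above, e.g. ops.Conv2d(...)
conv = ops_sketch.Conv2d(4, 8, kernel_size=1, stride=1, padding=0)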

View File

@@ -47,8 +47,7 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
 stride_h = max(1, ti_h - ov_h)
 stride_w = max(1, ti_w - ov_w)
-storage_device = torch.device("cpu")
+storage_device = vae_model.device
 result = None
 count = None
@@ -65,9 +64,9 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
 t_chunk = spatial_tile[:, :, i : i + input_chunk, :, :]
 if encode:
-out = vae_model.slicing_encode(t_chunk)
+out = vae_model.encode(t_chunk)
 else:
-out = vae_model.slicing_decode(t_chunk)
+out = vae_model.decode_(t_chunk)
 if isinstance(out, (tuple, list)): out = out[0]
@@ -150,6 +149,18 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
 return result
+def clear_vae_memory(vae_model):
+    for module in vae_model.modules():
+        if hasattr(module, "memory"):
+            module.memory = None
+    if hasattr(vae_model, "original_image_video"):
+        del vae_model.original_image_video
+    if hasattr(vae_model, "tiled_args"):
+        del vae_model.tiled_args
+    gc.collect()
+    torch.cuda.empty_cache()
 def expand_dims(tensor, ndim):
 shape = tensor.shape + (1,) * (ndim - tensor.ndim)
 return tensor.reshape(shape)
@@ -261,9 +272,9 @@ class SeedVR2InputProcessing(io.ComfyNode):
 io.Vae.Input("vae"),
 io.Int.Input("resolution_height", default = 1280, min = 120), # //
 io.Int.Input("resolution_width", default = 720, min = 120), # just non-zero value
-io.Int.Input("spatial_tile_size", default = 512, min = -1),
+io.Int.Input("spatial_tile_size", default = 512, min = 1),
-io.Int.Input("temporal_tile_size", default = 8, min = -1),
+io.Int.Input("temporal_tile_size", default = 8, min = 1),
-io.Int.Input("spatial_overlap", default = 64, min = -1),
+io.Int.Input("spatial_overlap", default = 64, min = 1),
 io.Boolean.Input("enable_tiling", default=False)
 ],
 outputs = [
@@ -305,7 +316,6 @@ class SeedVR2InputProcessing(io.ComfyNode):
 images = rearrange(images, "b t c h w -> b c t h w")
 images = images.to(device)
 vae_model = vae_model.to(device)
-vae_model.original_image_video = images
 args = {"tile_size": (spatial_tile_size, spatial_tile_size), "tile_overlap": (spatial_overlap, spatial_overlap),
 "temporal_size":temporal_tile_size}
@@ -314,11 +324,14 @@ class SeedVR2InputProcessing(io.ComfyNode):
 else:
 latent = vae_model.encode(images, orig_dims = [o_h, o_w])[0]
-#images = images.to(offload_device)
-#vae_model = vae_model.to(offload_device)
-vae_model.img_dims = [o_h, o_w]
+clear_vae_memory(vae_model)
 args["enable_tiling"] = enable_tiling
 vae_model.tiled_args = args
+vae_model.original_image_video = images
+vae_model = vae_model.to(offload_device)
+vae_model.img_dims = [o_h, o_w]
 latent = latent.unsqueeze(2) if latent.ndim == 4 else latent
 latent = rearrange(latent, "b c ... -> b ... c")
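
The spatial_tile_size / spatial_overlap inputs above drive the tiled VAE path, where overlapping tiles are encoded or decoded independently and then blended. A simplified, self-contained sketch of the accumulate-and-normalize blending (assumed from the tiled_vae signature; blend_overlapping_tiles is illustrative, not the node's exact code):

import torch

def blend_overlapping_tiles(h, w, tile=128, overlap=32):
    """Accumulate per-tile outputs plus a per-pixel count, then average the overlaps."""
    stride = max(1, tile - overlap)  # mirrors stride_h / stride_w in tiled_vae above
    result = torch.zeros(1, 1, h, w)
    count = torch.zeros(1, 1, h, w)
    for y in range(0, h, stride):
        for x in range(0, w, stride):
            ys = slice(y, min(y + tile, h))
            xs = slice(x, min(x + tile, w))
            tile_out = torch.ones(1, 1, ys.stop - ys.start, xs.stop - xs.start)  # stand-in for a VAE tile output
            result[:, :, ys, xs] += tile_out
            count[:, :, ys, xs] += 1
    return result / count  # each pixel becomes the average of the tiles covering it

print(blend_overlapping_tiles(256, 256).shape)  # torch.Size([1, 1, 256, 256])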