mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 14:20:49 +08:00
tune up mmaudio, nodes_eps
This commit is contained in:
parent
f1016ef1c1
commit
dc94081155
@ -4,7 +4,8 @@
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn, sin, pow
|
from torch import nn, sin, pow
|
||||||
from torch.nn import Parameter
|
from torch.nn import Parameter
|
||||||
import comfy.model_management
|
from ....model_management import cast_to
|
||||||
|
|
||||||
|
|
||||||
class Snake(nn.Module):
|
class Snake(nn.Module):
|
||||||
'''
|
'''
|
||||||
@ -22,6 +23,7 @@ class Snake(nn.Module):
|
|||||||
>>> x = torch.randn(256)
|
>>> x = torch.randn(256)
|
||||||
>>> x = a1(x)
|
>>> x = a1(x)
|
||||||
'''
|
'''
|
||||||
|
|
||||||
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
|
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
|
||||||
'''
|
'''
|
||||||
Initialization.
|
Initialization.
|
||||||
@ -51,7 +53,7 @@ class Snake(nn.Module):
|
|||||||
Applies the function to the input elementwise.
|
Applies the function to the input elementwise.
|
||||||
Snake ∶= x + 1/a * sin^2 (xa)
|
Snake ∶= x + 1/a * sin^2 (xa)
|
||||||
'''
|
'''
|
||||||
alpha = comfy.model_management.cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
|
alpha = cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
|
||||||
if self.alpha_logscale:
|
if self.alpha_logscale:
|
||||||
alpha = torch.exp(alpha)
|
alpha = torch.exp(alpha)
|
||||||
x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
|
x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
|
||||||
@ -76,6 +78,7 @@ class SnakeBeta(nn.Module):
|
|||||||
>>> x = torch.randn(256)
|
>>> x = torch.randn(256)
|
||||||
>>> x = a1(x)
|
>>> x = a1(x)
|
||||||
'''
|
'''
|
||||||
|
|
||||||
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
|
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
|
||||||
'''
|
'''
|
||||||
Initialization.
|
Initialization.
|
||||||
@ -110,8 +113,8 @@ class SnakeBeta(nn.Module):
|
|||||||
Applies the function to the input elementwise.
|
Applies the function to the input elementwise.
|
||||||
SnakeBeta ∶= x + 1/b * sin^2 (xa)
|
SnakeBeta ∶= x + 1/b * sin^2 (xa)
|
||||||
'''
|
'''
|
||||||
alpha = comfy.model_management.cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
|
alpha = cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
|
||||||
beta = comfy.model_management.cast_to(self.beta, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1)
|
beta = cast_to(self.beta, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1)
|
||||||
if self.alpha_logscale:
|
if self.alpha_logscale:
|
||||||
alpha = torch.exp(alpha)
|
alpha = torch.exp(alpha)
|
||||||
beta = torch.exp(beta)
|
beta = torch.exp(beta)
|
||||||
|
|||||||
@ -2,7 +2,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import math
|
import math
|
||||||
import comfy.model_management
|
from ....model_management import cast_to
|
||||||
|
|
||||||
if 'sinc' in dir(torch):
|
if 'sinc' in dir(torch):
|
||||||
sinc = torch.sinc
|
sinc = torch.sinc
|
||||||
@ -23,17 +23,17 @@ else:
|
|||||||
# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
|
# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
|
||||||
# https://adefossez.github.io/julius/julius/lowpass.html
|
# https://adefossez.github.io/julius/julius/lowpass.html
|
||||||
# LICENSE is in incl_licenses directory.
|
# LICENSE is in incl_licenses directory.
|
||||||
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
|
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
|
||||||
even = (kernel_size % 2 == 0)
|
even = (kernel_size % 2 == 0)
|
||||||
half_size = kernel_size // 2
|
half_size = kernel_size // 2
|
||||||
|
|
||||||
#For kaiser window
|
# For kaiser window
|
||||||
delta_f = 4 * half_width
|
delta_f = 4 * half_width
|
||||||
A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
|
A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
|
||||||
if A > 50.:
|
if A > 50.:
|
||||||
beta = 0.1102 * (A - 8.7)
|
beta = 0.1102 * (A - 8.7)
|
||||||
elif A >= 21.:
|
elif A >= 21.:
|
||||||
beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
|
beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.)
|
||||||
else:
|
else:
|
||||||
beta = 0.
|
beta = 0.
|
||||||
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
|
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
|
||||||
@ -80,14 +80,14 @@ class LowPassFilter1d(nn.Module):
|
|||||||
filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
|
filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
|
||||||
self.register_buffer("filter", filter)
|
self.register_buffer("filter", filter)
|
||||||
|
|
||||||
#input [B, C, T]
|
# input [B, C, T]
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
_, C, _ = x.shape
|
_, C, _ = x.shape
|
||||||
|
|
||||||
if self.padding:
|
if self.padding:
|
||||||
x = F.pad(x, (self.pad_left, self.pad_right),
|
x = F.pad(x, (self.pad_left, self.pad_right),
|
||||||
mode=self.padding_mode)
|
mode=self.padding_mode)
|
||||||
out = F.conv1d(x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device),
|
out = F.conv1d(x, cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device),
|
||||||
stride=self.stride, groups=C)
|
stride=self.stride, groups=C)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
@ -113,7 +113,7 @@ class UpSample1d(nn.Module):
|
|||||||
|
|
||||||
x = F.pad(x, (self.pad, self.pad), mode='replicate')
|
x = F.pad(x, (self.pad, self.pad), mode='replicate')
|
||||||
x = self.ratio * F.conv_transpose1d(
|
x = self.ratio * F.conv_transpose1d(
|
||||||
x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C)
|
x, cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C)
|
||||||
x = x[..., self.pad_left:-self.pad_right]
|
x = x[..., self.pad_left:-self.pad_right]
|
||||||
|
|
||||||
return x
|
return x
|
||||||
@ -134,6 +134,7 @@ class DownSample1d(nn.Module):
|
|||||||
|
|
||||||
return xx
|
return xx
|
||||||
|
|
||||||
|
|
||||||
class Activation1d(nn.Module):
|
class Activation1d(nn.Module):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
activation,
|
activation,
|
||||||
|
|||||||
@ -8,10 +8,6 @@ from .vae import VAE_16k
|
|||||||
from .bigvgan import BigVGANVocoder
|
from .bigvgan import BigVGANVocoder
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
try:
|
|
||||||
import torchaudio
|
|
||||||
except:
|
|
||||||
logging.warning("torchaudio missing, MMAudio VAE model will be broken")
|
|
||||||
|
|
||||||
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5, *, norm_fn):
|
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5, *, norm_fn):
|
||||||
return norm_fn(torch.clamp(x, min=clip_val) * C)
|
return norm_fn(torch.clamp(x, min=clip_val) * C)
|
||||||
@ -21,19 +17,20 @@ def spectral_normalize_torch(magnitudes, norm_fn):
|
|||||||
output = dynamic_range_compression_torch(magnitudes, norm_fn=norm_fn)
|
output = dynamic_range_compression_torch(magnitudes, norm_fn=norm_fn)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
class MelConverter(nn.Module):
|
class MelConverter(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
sampling_rate: float,
|
sampling_rate: float,
|
||||||
n_fft: int,
|
n_fft: int,
|
||||||
num_mels: int,
|
num_mels: int,
|
||||||
hop_size: int,
|
hop_size: int,
|
||||||
win_size: int,
|
win_size: int,
|
||||||
fmin: float,
|
fmin: float,
|
||||||
fmax: float,
|
fmax: float,
|
||||||
norm_fn,
|
norm_fn,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.sampling_rate = sampling_rate
|
self.sampling_rate = sampling_rate
|
||||||
@ -89,26 +86,27 @@ class MelConverter(nn.Module):
|
|||||||
|
|
||||||
return spec
|
return spec
|
||||||
|
|
||||||
|
|
||||||
class AudioAutoencoder(nn.Module):
|
class AudioAutoencoder(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
# ckpt_path: str,
|
# ckpt_path: str,
|
||||||
mode=Literal['16k', '44k'],
|
mode=Literal['16k', '44k'],
|
||||||
need_vae_encoder: bool = True,
|
need_vae_encoder: bool = True,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
assert mode == "16k", "Only 16k mode is supported currently."
|
assert mode == "16k", "Only 16k mode is supported currently."
|
||||||
self.mel_converter = MelConverter(sampling_rate=16_000,
|
self.mel_converter = MelConverter(sampling_rate=16_000,
|
||||||
n_fft=1024,
|
n_fft=1024,
|
||||||
num_mels=80,
|
num_mels=80,
|
||||||
hop_size=256,
|
hop_size=256,
|
||||||
win_size=1024,
|
win_size=1024,
|
||||||
fmin=0,
|
fmin=0,
|
||||||
fmax=8_000,
|
fmax=8_000,
|
||||||
norm_fn=torch.log10)
|
norm_fn=torch.log10)
|
||||||
|
|
||||||
self.vae = VAE_16k().eval()
|
self.vae = VAE_16k().eval()
|
||||||
|
|
||||||
@ -145,11 +143,13 @@ class AudioAutoencoder(nn.Module):
|
|||||||
mel_decoded = self.vae.decode(z)
|
mel_decoded = self.vae.decode(z)
|
||||||
audio = self.vocoder(mel_decoded)
|
audio = self.vocoder(mel_decoded)
|
||||||
|
|
||||||
|
import torchaudio
|
||||||
audio = torchaudio.functional.resample(audio, 16000, 44100)
|
audio = torchaudio.functional.resample(audio, 16000, 44100)
|
||||||
return audio
|
return audio
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def encode(self, audio):
|
def encode(self, audio):
|
||||||
|
import torchaudio
|
||||||
audio = audio.mean(dim=1)
|
audio = audio.mean(dim=1)
|
||||||
audio = torchaudio.functional.resample(audio, 44100, 16000)
|
audio = torchaudio.functional.resample(audio, 44100, 16000)
|
||||||
dist = self.encode_audio(audio)
|
dist = self.encode_audio(audio)
|
||||||
|
|||||||
@ -9,12 +9,13 @@ import torch.nn as nn
|
|||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from . import activations
|
from . import activations
|
||||||
from .alias_free_torch import Activation1d
|
from .alias_free_torch import Activation1d
|
||||||
import comfy.ops
|
from ....ops import disable_weight_init as ops
|
||||||
ops = comfy.ops.disable_weight_init
|
|
||||||
|
|
||||||
def get_padding(kernel_size, dilation=1):
|
def get_padding(kernel_size, dilation=1):
|
||||||
return int((kernel_size * dilation - dilation) / 2)
|
return int((kernel_size * dilation - dilation) / 2)
|
||||||
|
|
||||||
|
|
||||||
class AMPBlock1(torch.nn.Module):
|
class AMPBlock1(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), activation=None):
|
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), activation=None):
|
||||||
@ -22,19 +23,19 @@ class AMPBlock1(torch.nn.Module):
|
|||||||
self.h = h
|
self.h = h
|
||||||
|
|
||||||
self.convs1 = nn.ModuleList([
|
self.convs1 = nn.ModuleList([
|
||||||
ops.Conv1d(channels,
|
ops.Conv1d(channels,
|
||||||
channels,
|
channels,
|
||||||
kernel_size,
|
kernel_size,
|
||||||
1,
|
1,
|
||||||
dilation=dilation[0],
|
dilation=dilation[0],
|
||||||
padding=get_padding(kernel_size, dilation[0])),
|
padding=get_padding(kernel_size, dilation[0])),
|
||||||
ops.Conv1d(channels,
|
ops.Conv1d(channels,
|
||||||
channels,
|
channels,
|
||||||
kernel_size,
|
kernel_size,
|
||||||
1,
|
1,
|
||||||
dilation=dilation[1],
|
dilation=dilation[1],
|
||||||
padding=get_padding(kernel_size, dilation[1])),
|
padding=get_padding(kernel_size, dilation[1])),
|
||||||
ops.Conv1d(channels,
|
ops.Conv1d(channels,
|
||||||
channels,
|
channels,
|
||||||
kernel_size,
|
kernel_size,
|
||||||
1,
|
1,
|
||||||
@ -43,19 +44,19 @@ class AMPBlock1(torch.nn.Module):
|
|||||||
])
|
])
|
||||||
|
|
||||||
self.convs2 = nn.ModuleList([
|
self.convs2 = nn.ModuleList([
|
||||||
ops.Conv1d(channels,
|
ops.Conv1d(channels,
|
||||||
channels,
|
channels,
|
||||||
kernel_size,
|
kernel_size,
|
||||||
1,
|
1,
|
||||||
dilation=1,
|
dilation=1,
|
||||||
padding=get_padding(kernel_size, 1)),
|
padding=get_padding(kernel_size, 1)),
|
||||||
ops.Conv1d(channels,
|
ops.Conv1d(channels,
|
||||||
channels,
|
channels,
|
||||||
kernel_size,
|
kernel_size,
|
||||||
1,
|
1,
|
||||||
dilation=1,
|
dilation=1,
|
||||||
padding=get_padding(kernel_size, 1)),
|
padding=get_padding(kernel_size, 1)),
|
||||||
ops.Conv1d(channels,
|
ops.Conv1d(channels,
|
||||||
channels,
|
channels,
|
||||||
kernel_size,
|
kernel_size,
|
||||||
1,
|
1,
|
||||||
@ -101,13 +102,13 @@ class AMPBlock2(torch.nn.Module):
|
|||||||
self.h = h
|
self.h = h
|
||||||
|
|
||||||
self.convs = nn.ModuleList([
|
self.convs = nn.ModuleList([
|
||||||
ops.Conv1d(channels,
|
ops.Conv1d(channels,
|
||||||
channels,
|
channels,
|
||||||
kernel_size,
|
kernel_size,
|
||||||
1,
|
1,
|
||||||
dilation=dilation[0],
|
dilation=dilation[0],
|
||||||
padding=get_padding(kernel_size, dilation[0])),
|
padding=get_padding(kernel_size, dilation[0])),
|
||||||
ops.Conv1d(channels,
|
ops.Conv1d(channels,
|
||||||
channels,
|
channels,
|
||||||
kernel_size,
|
kernel_size,
|
||||||
1,
|
1,
|
||||||
@ -165,8 +166,8 @@ class BigVGANVocoder(torch.nn.Module):
|
|||||||
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
|
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
|
||||||
self.ups.append(
|
self.ups.append(
|
||||||
nn.ModuleList([
|
nn.ModuleList([
|
||||||
ops.ConvTranspose1d(h.upsample_initial_channel // (2**i),
|
ops.ConvTranspose1d(h.upsample_initial_channel // (2 ** i),
|
||||||
h.upsample_initial_channel // (2**(i + 1)),
|
h.upsample_initial_channel // (2 ** (i + 1)),
|
||||||
k,
|
k,
|
||||||
u,
|
u,
|
||||||
padding=(k - u) // 2)
|
padding=(k - u) // 2)
|
||||||
@ -175,7 +176,7 @@ class BigVGANVocoder(torch.nn.Module):
|
|||||||
# residual blocks using anti-aliased multi-periodicity composition modules (AMP)
|
# residual blocks using anti-aliased multi-periodicity composition modules (AMP)
|
||||||
self.resblocks = nn.ModuleList()
|
self.resblocks = nn.ModuleList()
|
||||||
for i in range(len(self.ups)):
|
for i in range(len(self.ups)):
|
||||||
ch = h.upsample_initial_channel // (2**(i + 1))
|
ch = h.upsample_initial_channel // (2 ** (i + 1))
|
||||||
for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
|
for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
|
||||||
self.resblocks.append(resblock(h, ch, k, d, activation=h.activation))
|
self.resblocks.append(resblock(h, ch, k, d, activation=h.activation))
|
||||||
|
|
||||||
@ -193,7 +194,6 @@ class BigVGANVocoder(torch.nn.Module):
|
|||||||
|
|
||||||
self.conv_post = ops.Conv1d(ch, 1, 7, 1, padding=3)
|
self.conv_post = ops.Conv1d(ch, 1, 7, 1, padding=3)
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
# pre conv
|
# pre conv
|
||||||
x = self.conv_pre(x)
|
x = self.conv_pre(x)
|
||||||
|
|||||||
@ -1,17 +1,14 @@
|
|||||||
import logging
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from .vae_modules import (AttnBlock1D, Downsample1D, ResnetBlock1D,
|
from .vae_modules import (AttnBlock1D, Downsample1D, ResnetBlock1D,
|
||||||
Upsample1D, nonlinearity)
|
Upsample1D, nonlinearity)
|
||||||
from .distributions import DiagonalGaussianDistribution
|
from .distributions import DiagonalGaussianDistribution
|
||||||
|
|
||||||
import comfy.ops
|
from ....ops import disable_weight_init as ops
|
||||||
ops = comfy.ops.disable_weight_init
|
|
||||||
|
|
||||||
log = logging.getLogger()
|
|
||||||
|
|
||||||
DATA_MEAN_80D = [
|
DATA_MEAN_80D = [
|
||||||
-1.6058, -1.3676, -1.2520, -1.2453, -1.2078, -1.2224, -1.2419, -1.2439, -1.2922, -1.2927,
|
-1.6058, -1.3676, -1.2520, -1.2453, -1.2078, -1.2224, -1.2419, -1.2439, -1.2922, -1.2927,
|
||||||
@ -68,11 +65,11 @@ DATA_STD_128D = [
|
|||||||
class VAE(nn.Module):
|
class VAE(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
data_dim: int,
|
data_dim: int,
|
||||||
embed_dim: int,
|
embed_dim: int,
|
||||||
hidden_dim: int,
|
hidden_dim: int,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -135,12 +132,12 @@ class VAE(nn.Module):
|
|||||||
return x * comfy.model_management.cast_to(self.data_std, dtype=x.dtype, device=x.device) + comfy.model_management.cast_to(self.data_mean, dtype=x.dtype, device=x.device)
|
return x * comfy.model_management.cast_to(self.data_std, dtype=x.dtype, device=x.device) + comfy.model_management.cast_to(self.data_mean, dtype=x.dtype, device=x.device)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
sample_posterior: bool = True,
|
sample_posterior: bool = True,
|
||||||
rng: Optional[torch.Generator] = None,
|
rng: Optional[torch.Generator] = None,
|
||||||
normalize: bool = True,
|
normalize: bool = True,
|
||||||
unnormalize: bool = True,
|
unnormalize: bool = True,
|
||||||
) -> tuple[torch.Tensor, DiagonalGaussianDistribution]:
|
) -> tuple[torch.Tensor, DiagonalGaussianDistribution]:
|
||||||
|
|
||||||
posterior = self.encode(x, normalize=normalize)
|
posterior = self.encode(x, normalize=normalize)
|
||||||
@ -190,7 +187,7 @@ class Encoder1D(nn.Module):
|
|||||||
self.attn_layers = attn_layers
|
self.attn_layers = attn_layers
|
||||||
self.conv_in = ops.Conv1d(in_dim, self.dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
|
self.conv_in = ops.Conv1d(in_dim, self.dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
|
||||||
|
|
||||||
in_ch_mult = (1, ) + tuple(ch_mult)
|
in_ch_mult = (1,) + tuple(ch_mult)
|
||||||
self.in_ch_mult = in_ch_mult
|
self.in_ch_mult = in_ch_mult
|
||||||
# downsampling
|
# downsampling
|
||||||
self.down = nn.ModuleList()
|
self.down = nn.ModuleList()
|
||||||
@ -229,8 +226,8 @@ class Encoder1D(nn.Module):
|
|||||||
|
|
||||||
# end
|
# end
|
||||||
self.conv_out = ops.Conv1d(block_in,
|
self.conv_out = ops.Conv1d(block_in,
|
||||||
2 * embed_dim if double_z else embed_dim,
|
2 * embed_dim if double_z else embed_dim,
|
||||||
kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
|
kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
|
||||||
|
|
||||||
self.learnable_gain = nn.Parameter(torch.zeros([]))
|
self.learnable_gain = nn.Parameter(torch.zeros([]))
|
||||||
|
|
||||||
@ -355,4 +352,3 @@ def get_my_vae(name: str, **kwargs) -> VAE:
|
|||||||
if name == '44k':
|
if name == '44k':
|
||||||
return VAE_44k(**kwargs)
|
return VAE_44k(**kwargs)
|
||||||
raise ValueError(f'Unknown model: {name}')
|
raise ValueError(f'Unknown model: {name}')
|
||||||
|
|
||||||
|
|||||||
@ -1,17 +1,19 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from comfy.ldm.modules.diffusionmodules.model import vae_attention
|
from ...modules.diffusionmodules.model import vae_attention
|
||||||
import math
|
import math
|
||||||
import comfy.ops
|
from ....ops import disable_weight_init as ops
|
||||||
ops = comfy.ops.disable_weight_init
|
|
||||||
|
|
||||||
def nonlinearity(x):
|
def nonlinearity(x):
|
||||||
# swish
|
# swish
|
||||||
return torch.nn.functional.silu(x) / 0.596
|
return torch.nn.functional.silu(x) / 0.596
|
||||||
|
|
||||||
|
|
||||||
def mp_sum(a, b, t=0.5):
|
def mp_sum(a, b, t=0.5):
|
||||||
return a.lerp(b, t) / math.sqrt((1 - t)**2 + t**2)
|
return a.lerp(b, t) / math.sqrt((1 - t) ** 2 + t ** 2)
|
||||||
|
|
||||||
|
|
||||||
def normalize(x, dim=None, eps=1e-4):
|
def normalize(x, dim=None, eps=1e-4):
|
||||||
if dim is None:
|
if dim is None:
|
||||||
@ -20,6 +22,7 @@ def normalize(x, dim=None, eps=1e-4):
|
|||||||
norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel()))
|
norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel()))
|
||||||
return x / norm.to(x.dtype)
|
return x / norm.to(x.dtype)
|
||||||
|
|
||||||
|
|
||||||
class ResnetBlock1D(nn.Module):
|
class ResnetBlock1D(nn.Module):
|
||||||
|
|
||||||
def __init__(self, *, in_dim, out_dim=None, conv_shortcut=False, kernel_size=3, use_norm=True):
|
def __init__(self, *, in_dim, out_dim=None, conv_shortcut=False, kernel_size=3, use_norm=True):
|
||||||
|
|||||||
@ -14,6 +14,7 @@ class EpsilonScaling(io.ComfyNode):
|
|||||||
which can significantly improve sample quality. This implementation uses the "uniform schedule"
|
which can significantly improve sample quality. This implementation uses the "uniform schedule"
|
||||||
recommended by the paper for its practicality and effectiveness.
|
recommended by the paper for its practicality and effectiveness.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
@ -66,7 +67,7 @@ class EpsilonScaling(io.ComfyNode):
|
|||||||
|
|
||||||
|
|
||||||
def compute_tsr_rescaling_factor(
|
def compute_tsr_rescaling_factor(
|
||||||
snr: torch.Tensor, tsr_k: float, tsr_variance: float
|
snr: torch.Tensor, tsr_k: float, tsr_variance: float
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Compute the rescaling score ratio in Temporal Score Rescaling.
|
"""Compute the rescaling score ratio in Temporal Score Rescaling.
|
||||||
|
|
||||||
@ -74,7 +75,7 @@ def compute_tsr_rescaling_factor(
|
|||||||
"""
|
"""
|
||||||
posinf_mask = torch.isposinf(snr)
|
posinf_mask = torch.isposinf(snr)
|
||||||
rescaling_factor = (snr * tsr_variance + 1) / (snr * tsr_variance / tsr_k + 1)
|
rescaling_factor = (snr * tsr_variance + 1) / (snr * tsr_variance / tsr_k + 1)
|
||||||
return torch.where(posinf_mask, tsr_k, rescaling_factor) # when snr → inf, r = tsr_k
|
return torch.where(posinf_mask, tsr_k, rescaling_factor) # when snr → inf, r = tsr_k
|
||||||
|
|
||||||
|
|
||||||
class TemporalScoreRescaling(io.ComfyNode):
|
class TemporalScoreRescaling(io.ComfyNode):
|
||||||
@ -125,7 +126,7 @@ class TemporalScoreRescaling(io.ComfyNode):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, model, tsr_k, tsr_sigma) -> io.NodeOutput:
|
def execute(cls, model, tsr_k, tsr_sigma) -> io.NodeOutput:
|
||||||
tsr_variance = tsr_sigma**2
|
tsr_variance = tsr_sigma ** 2
|
||||||
|
|
||||||
def temporal_score_rescaling(args):
|
def temporal_score_rescaling(args):
|
||||||
denoised = args["denoised"]
|
denoised = args["denoised"]
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user