mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 06:10:50 +08:00
tune up mmaudio, nodes_eps
This commit is contained in:
parent
f1016ef1c1
commit
dc94081155
@ -4,7 +4,8 @@
|
||||
import torch
|
||||
from torch import nn, sin, pow
|
||||
from torch.nn import Parameter
|
||||
import comfy.model_management
|
||||
from ....model_management import cast_to
|
||||
|
||||
|
||||
class Snake(nn.Module):
|
||||
'''
|
||||
@ -22,6 +23,7 @@ class Snake(nn.Module):
|
||||
>>> x = torch.randn(256)
|
||||
>>> x = a1(x)
|
||||
'''
|
||||
|
||||
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
|
||||
'''
|
||||
Initialization.
|
||||
@ -51,7 +53,7 @@ class Snake(nn.Module):
|
||||
Applies the function to the input elementwise.
|
||||
Snake ∶= x + 1/a * sin^2 (xa)
|
||||
'''
|
||||
alpha = comfy.model_management.cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
|
||||
alpha = cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
|
||||
if self.alpha_logscale:
|
||||
alpha = torch.exp(alpha)
|
||||
x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
|
||||
@ -76,6 +78,7 @@ class SnakeBeta(nn.Module):
|
||||
>>> x = torch.randn(256)
|
||||
>>> x = a1(x)
|
||||
'''
|
||||
|
||||
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
|
||||
'''
|
||||
Initialization.
|
||||
@ -110,8 +113,8 @@ class SnakeBeta(nn.Module):
|
||||
Applies the function to the input elementwise.
|
||||
SnakeBeta ∶= x + 1/b * sin^2 (xa)
|
||||
'''
|
||||
alpha = comfy.model_management.cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
|
||||
beta = comfy.model_management.cast_to(self.beta, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1)
|
||||
alpha = cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
|
||||
beta = cast_to(self.beta, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1)
|
||||
if self.alpha_logscale:
|
||||
alpha = torch.exp(alpha)
|
||||
beta = torch.exp(beta)
|
||||
|
||||
@ -2,7 +2,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
import comfy.model_management
|
||||
from ....model_management import cast_to
|
||||
|
||||
if 'sinc' in dir(torch):
|
||||
sinc = torch.sinc
|
||||
@ -23,17 +23,17 @@ else:
|
||||
# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
|
||||
# https://adefossez.github.io/julius/julius/lowpass.html
|
||||
# LICENSE is in incl_licenses directory.
|
||||
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
|
||||
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
|
||||
even = (kernel_size % 2 == 0)
|
||||
half_size = kernel_size // 2
|
||||
|
||||
#For kaiser window
|
||||
# For kaiser window
|
||||
delta_f = 4 * half_width
|
||||
A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
|
||||
if A > 50.:
|
||||
beta = 0.1102 * (A - 8.7)
|
||||
elif A >= 21.:
|
||||
beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
|
||||
beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.)
|
||||
else:
|
||||
beta = 0.
|
||||
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
|
||||
@ -80,14 +80,14 @@ class LowPassFilter1d(nn.Module):
|
||||
filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
|
||||
self.register_buffer("filter", filter)
|
||||
|
||||
#input [B, C, T]
|
||||
# input [B, C, T]
|
||||
def forward(self, x):
|
||||
_, C, _ = x.shape
|
||||
|
||||
if self.padding:
|
||||
x = F.pad(x, (self.pad_left, self.pad_right),
|
||||
mode=self.padding_mode)
|
||||
out = F.conv1d(x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device),
|
||||
out = F.conv1d(x, cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device),
|
||||
stride=self.stride, groups=C)
|
||||
|
||||
return out
|
||||
@ -113,7 +113,7 @@ class UpSample1d(nn.Module):
|
||||
|
||||
x = F.pad(x, (self.pad, self.pad), mode='replicate')
|
||||
x = self.ratio * F.conv_transpose1d(
|
||||
x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C)
|
||||
x, cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C)
|
||||
x = x[..., self.pad_left:-self.pad_right]
|
||||
|
||||
return x
|
||||
@ -134,6 +134,7 @@ class DownSample1d(nn.Module):
|
||||
|
||||
return xx
|
||||
|
||||
|
||||
class Activation1d(nn.Module):
|
||||
def __init__(self,
|
||||
activation,
|
||||
|
||||
@ -8,10 +8,6 @@ from .vae import VAE_16k
|
||||
from .bigvgan import BigVGANVocoder
|
||||
import logging
|
||||
|
||||
try:
|
||||
import torchaudio
|
||||
except:
|
||||
logging.warning("torchaudio missing, MMAudio VAE model will be broken")
|
||||
|
||||
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5, *, norm_fn):
|
||||
return norm_fn(torch.clamp(x, min=clip_val) * C)
|
||||
@ -21,19 +17,20 @@ def spectral_normalize_torch(magnitudes, norm_fn):
|
||||
output = dynamic_range_compression_torch(magnitudes, norm_fn=norm_fn)
|
||||
return output
|
||||
|
||||
|
||||
class MelConverter(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
sampling_rate: float,
|
||||
n_fft: int,
|
||||
num_mels: int,
|
||||
hop_size: int,
|
||||
win_size: int,
|
||||
fmin: float,
|
||||
fmax: float,
|
||||
norm_fn,
|
||||
self,
|
||||
*,
|
||||
sampling_rate: float,
|
||||
n_fft: int,
|
||||
num_mels: int,
|
||||
hop_size: int,
|
||||
win_size: int,
|
||||
fmin: float,
|
||||
fmax: float,
|
||||
norm_fn,
|
||||
):
|
||||
super().__init__()
|
||||
self.sampling_rate = sampling_rate
|
||||
@ -89,26 +86,27 @@ class MelConverter(nn.Module):
|
||||
|
||||
return spec
|
||||
|
||||
|
||||
class AudioAutoencoder(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
# ckpt_path: str,
|
||||
mode=Literal['16k', '44k'],
|
||||
need_vae_encoder: bool = True,
|
||||
self,
|
||||
*,
|
||||
# ckpt_path: str,
|
||||
mode=Literal['16k', '44k'],
|
||||
need_vae_encoder: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
assert mode == "16k", "Only 16k mode is supported currently."
|
||||
self.mel_converter = MelConverter(sampling_rate=16_000,
|
||||
n_fft=1024,
|
||||
num_mels=80,
|
||||
hop_size=256,
|
||||
win_size=1024,
|
||||
fmin=0,
|
||||
fmax=8_000,
|
||||
norm_fn=torch.log10)
|
||||
n_fft=1024,
|
||||
num_mels=80,
|
||||
hop_size=256,
|
||||
win_size=1024,
|
||||
fmin=0,
|
||||
fmax=8_000,
|
||||
norm_fn=torch.log10)
|
||||
|
||||
self.vae = VAE_16k().eval()
|
||||
|
||||
@ -145,11 +143,13 @@ class AudioAutoencoder(nn.Module):
|
||||
mel_decoded = self.vae.decode(z)
|
||||
audio = self.vocoder(mel_decoded)
|
||||
|
||||
import torchaudio
|
||||
audio = torchaudio.functional.resample(audio, 16000, 44100)
|
||||
return audio
|
||||
|
||||
@torch.no_grad()
|
||||
def encode(self, audio):
|
||||
import torchaudio
|
||||
audio = audio.mean(dim=1)
|
||||
audio = torchaudio.functional.resample(audio, 44100, 16000)
|
||||
dist = self.encode_audio(audio)
|
||||
|
||||
@ -9,12 +9,13 @@ import torch.nn as nn
|
||||
from types import SimpleNamespace
|
||||
from . import activations
|
||||
from .alias_free_torch import Activation1d
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
from ....ops import disable_weight_init as ops
|
||||
|
||||
|
||||
def get_padding(kernel_size, dilation=1):
|
||||
return int((kernel_size * dilation - dilation) / 2)
|
||||
|
||||
|
||||
class AMPBlock1(torch.nn.Module):
|
||||
|
||||
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), activation=None):
|
||||
@ -22,19 +23,19 @@ class AMPBlock1(torch.nn.Module):
|
||||
self.h = h
|
||||
|
||||
self.convs1 = nn.ModuleList([
|
||||
ops.Conv1d(channels,
|
||||
ops.Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0])),
|
||||
ops.Conv1d(channels,
|
||||
ops.Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1])),
|
||||
ops.Conv1d(channels,
|
||||
ops.Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
@ -43,19 +44,19 @@ class AMPBlock1(torch.nn.Module):
|
||||
])
|
||||
|
||||
self.convs2 = nn.ModuleList([
|
||||
ops.Conv1d(channels,
|
||||
ops.Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1)),
|
||||
ops.Conv1d(channels,
|
||||
ops.Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1)),
|
||||
ops.Conv1d(channels,
|
||||
ops.Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
@ -101,13 +102,13 @@ class AMPBlock2(torch.nn.Module):
|
||||
self.h = h
|
||||
|
||||
self.convs = nn.ModuleList([
|
||||
ops.Conv1d(channels,
|
||||
ops.Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0])),
|
||||
ops.Conv1d(channels,
|
||||
ops.Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
@ -165,8 +166,8 @@ class BigVGANVocoder(torch.nn.Module):
|
||||
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
|
||||
self.ups.append(
|
||||
nn.ModuleList([
|
||||
ops.ConvTranspose1d(h.upsample_initial_channel // (2**i),
|
||||
h.upsample_initial_channel // (2**(i + 1)),
|
||||
ops.ConvTranspose1d(h.upsample_initial_channel // (2 ** i),
|
||||
h.upsample_initial_channel // (2 ** (i + 1)),
|
||||
k,
|
||||
u,
|
||||
padding=(k - u) // 2)
|
||||
@ -175,7 +176,7 @@ class BigVGANVocoder(torch.nn.Module):
|
||||
# residual blocks using anti-aliased multi-periodicity composition modules (AMP)
|
||||
self.resblocks = nn.ModuleList()
|
||||
for i in range(len(self.ups)):
|
||||
ch = h.upsample_initial_channel // (2**(i + 1))
|
||||
ch = h.upsample_initial_channel // (2 ** (i + 1))
|
||||
for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
|
||||
self.resblocks.append(resblock(h, ch, k, d, activation=h.activation))
|
||||
|
||||
@ -193,7 +194,6 @@ class BigVGANVocoder(torch.nn.Module):
|
||||
|
||||
self.conv_post = ops.Conv1d(ch, 1, 7, 1, padding=3)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
# pre conv
|
||||
x = self.conv_pre(x)
|
||||
|
||||
@ -1,17 +1,14 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .vae_modules import (AttnBlock1D, Downsample1D, ResnetBlock1D,
|
||||
Upsample1D, nonlinearity)
|
||||
Upsample1D, nonlinearity)
|
||||
from .distributions import DiagonalGaussianDistribution
|
||||
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
from ....ops import disable_weight_init as ops
|
||||
|
||||
log = logging.getLogger()
|
||||
|
||||
DATA_MEAN_80D = [
|
||||
-1.6058, -1.3676, -1.2520, -1.2453, -1.2078, -1.2224, -1.2419, -1.2439, -1.2922, -1.2927,
|
||||
@ -68,11 +65,11 @@ DATA_STD_128D = [
|
||||
class VAE(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
data_dim: int,
|
||||
embed_dim: int,
|
||||
hidden_dim: int,
|
||||
self,
|
||||
*,
|
||||
data_dim: int,
|
||||
embed_dim: int,
|
||||
hidden_dim: int,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -135,12 +132,12 @@ class VAE(nn.Module):
|
||||
return x * comfy.model_management.cast_to(self.data_std, dtype=x.dtype, device=x.device) + comfy.model_management.cast_to(self.data_mean, dtype=x.dtype, device=x.device)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
sample_posterior: bool = True,
|
||||
rng: Optional[torch.Generator] = None,
|
||||
normalize: bool = True,
|
||||
unnormalize: bool = True,
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
sample_posterior: bool = True,
|
||||
rng: Optional[torch.Generator] = None,
|
||||
normalize: bool = True,
|
||||
unnormalize: bool = True,
|
||||
) -> tuple[torch.Tensor, DiagonalGaussianDistribution]:
|
||||
|
||||
posterior = self.encode(x, normalize=normalize)
|
||||
@ -190,7 +187,7 @@ class Encoder1D(nn.Module):
|
||||
self.attn_layers = attn_layers
|
||||
self.conv_in = ops.Conv1d(in_dim, self.dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
|
||||
|
||||
in_ch_mult = (1, ) + tuple(ch_mult)
|
||||
in_ch_mult = (1,) + tuple(ch_mult)
|
||||
self.in_ch_mult = in_ch_mult
|
||||
# downsampling
|
||||
self.down = nn.ModuleList()
|
||||
@ -229,8 +226,8 @@ class Encoder1D(nn.Module):
|
||||
|
||||
# end
|
||||
self.conv_out = ops.Conv1d(block_in,
|
||||
2 * embed_dim if double_z else embed_dim,
|
||||
kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
|
||||
2 * embed_dim if double_z else embed_dim,
|
||||
kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
|
||||
|
||||
self.learnable_gain = nn.Parameter(torch.zeros([]))
|
||||
|
||||
@ -355,4 +352,3 @@ def get_my_vae(name: str, **kwargs) -> VAE:
|
||||
if name == '44k':
|
||||
return VAE_44k(**kwargs)
|
||||
raise ValueError(f'Unknown model: {name}')
|
||||
|
||||
|
||||
@ -1,17 +1,19 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from comfy.ldm.modules.diffusionmodules.model import vae_attention
|
||||
from ...modules.diffusionmodules.model import vae_attention
|
||||
import math
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
from ....ops import disable_weight_init as ops
|
||||
|
||||
|
||||
def nonlinearity(x):
|
||||
# swish
|
||||
return torch.nn.functional.silu(x) / 0.596
|
||||
|
||||
|
||||
def mp_sum(a, b, t=0.5):
|
||||
return a.lerp(b, t) / math.sqrt((1 - t)**2 + t**2)
|
||||
return a.lerp(b, t) / math.sqrt((1 - t) ** 2 + t ** 2)
|
||||
|
||||
|
||||
def normalize(x, dim=None, eps=1e-4):
|
||||
if dim is None:
|
||||
@ -20,6 +22,7 @@ def normalize(x, dim=None, eps=1e-4):
|
||||
norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel()))
|
||||
return x / norm.to(x.dtype)
|
||||
|
||||
|
||||
class ResnetBlock1D(nn.Module):
|
||||
|
||||
def __init__(self, *, in_dim, out_dim=None, conv_shortcut=False, kernel_size=3, use_norm=True):
|
||||
|
||||
@ -14,6 +14,7 @@ class EpsilonScaling(io.ComfyNode):
|
||||
which can significantly improve sample quality. This implementation uses the "uniform schedule"
|
||||
recommended by the paper for its practicality and effectiveness.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
@ -66,7 +67,7 @@ class EpsilonScaling(io.ComfyNode):
|
||||
|
||||
|
||||
def compute_tsr_rescaling_factor(
|
||||
snr: torch.Tensor, tsr_k: float, tsr_variance: float
|
||||
snr: torch.Tensor, tsr_k: float, tsr_variance: float
|
||||
) -> torch.Tensor:
|
||||
"""Compute the rescaling score ratio in Temporal Score Rescaling.
|
||||
|
||||
@ -74,7 +75,7 @@ def compute_tsr_rescaling_factor(
|
||||
"""
|
||||
posinf_mask = torch.isposinf(snr)
|
||||
rescaling_factor = (snr * tsr_variance + 1) / (snr * tsr_variance / tsr_k + 1)
|
||||
return torch.where(posinf_mask, tsr_k, rescaling_factor) # when snr → inf, r = tsr_k
|
||||
return torch.where(posinf_mask, tsr_k, rescaling_factor) # when snr → inf, r = tsr_k
|
||||
|
||||
|
||||
class TemporalScoreRescaling(io.ComfyNode):
|
||||
@ -125,7 +126,7 @@ class TemporalScoreRescaling(io.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, tsr_k, tsr_sigma) -> io.NodeOutput:
|
||||
tsr_variance = tsr_sigma**2
|
||||
tsr_variance = tsr_sigma ** 2
|
||||
|
||||
def temporal_score_rescaling(args):
|
||||
denoised = args["denoised"]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user