tune up mmaudio, nodes_eps

doctorpangloss 2025-10-20 13:47:58 -07:00
parent f1016ef1c1
commit dc94081155
7 changed files with 82 additions and 78 deletions

View File

@@ -4,7 +4,8 @@
import torch
from torch import nn, sin, pow
from torch.nn import Parameter
-import comfy.model_management
+from ....model_management import cast_to
class Snake(nn.Module):
'''
@@ -22,6 +23,7 @@ class Snake(nn.Module):
>>> x = torch.randn(256)
>>> x = a1(x)
'''
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
'''
Initialization.
@@ -51,7 +53,7 @@
Applies the function to the input elementwise.
Snake = x + 1/a * sin^2 (xa)
'''
-alpha = comfy.model_management.cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
+alpha = cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
if self.alpha_logscale:
alpha = torch.exp(alpha)
x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
@@ -76,6 +78,7 @@ class SnakeBeta(nn.Module):
>>> x = torch.randn(256)
>>> x = a1(x)
'''
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
'''
Initialization.
@@ -110,8 +113,8 @@
Applies the function to the input elementwise.
SnakeBeta = x + 1/b * sin^2 (xa)
'''
-alpha = comfy.model_management.cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
-beta = comfy.model_management.cast_to(self.beta, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1)
+alpha = cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
+beta = cast_to(self.beta, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1)
if self.alpha_logscale:
alpha = torch.exp(alpha)
beta = torch.exp(beta)
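Note: the hunks above only swap the fully qualified comfy.model_management.cast_to for a relative import; the activation math is unchanged. For reference, Snake computes x + (1/alpha) * sin^2(alpha * x) per channel. A minimal standalone sketch (the per-channel alpha shape and the 1e-9 guard value are assumptions, not taken from the repository):

import torch

def snake(x: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
    # x: [B, C, T]; alpha: [C]. Unsqueeze to [1, C, 1] so alpha broadcasts
    # over batch and time, matching .unsqueeze(0).unsqueeze(-1) in the diff.
    a = alpha.to(dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1)
    eps = 1e-9  # assumed value of the module's no_div_by_zero guard
    return x + (1.0 / (a + eps)) * torch.sin(x * a) ** 2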

View File

@@ -2,7 +2,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import math
-import comfy.model_management
+from ....model_management import cast_to
if 'sinc' in dir(torch):
sinc = torch.sinc
@@ -23,17 +23,17 @@ else:
# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
# https://adefossez.github.io/julius/julius/lowpass.html
# LICENSE is in incl_licenses directory.
-def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
+def kaiser_sinc_filter1d(cutoff, half_width, kernel_size):  # return filter [1,1,kernel_size]
even = (kernel_size % 2 == 0)
half_size = kernel_size // 2
-#For kaiser window
+# For kaiser window
delta_f = 4 * half_width
A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
if A > 50.:
beta = 0.1102 * (A - 8.7)
elif A >= 21.:
-beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
+beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.)
else:
beta = 0.
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
@@ -80,14 +80,14 @@ class LowPassFilter1d(nn.Module):
filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
self.register_buffer("filter", filter)
-#input [B, C, T]
+# input [B, C, T]
def forward(self, x):
_, C, _ = x.shape
if self.padding:
x = F.pad(x, (self.pad_left, self.pad_right),
mode=self.padding_mode)
-out = F.conv1d(x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device),
+out = F.conv1d(x, cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device),
stride=self.stride, groups=C)
return out
@@ -113,7 +113,7 @@ class UpSample1d(nn.Module):
x = F.pad(x, (self.pad, self.pad), mode='replicate')
x = self.ratio * F.conv_transpose1d(
-x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C)
+x, cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C)
x = x[..., self.pad_left:-self.pad_right]
return x
@@ -134,6 +134,7 @@ class DownSample1d(nn.Module):
return xx
class Activation1d(nn.Module):
def __init__(self,
activation,
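Note: the Kaiser design above estimates the stopband attenuation A from the transition band, then maps A to the window's beta piecewise. A self-contained check of that rule (parameter values are illustrative, not taken from the repository):

import math
import torch

def kaiser_beta(half_width: float, kernel_size: int) -> float:
    # Same piecewise attenuation-to-beta rule as in the hunk above.
    half_size = kernel_size // 2
    delta_f = 4 * half_width
    A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
    if A > 50.0:
        return 0.1102 * (A - 8.7)
    elif A >= 21.0:
        return 0.5842 * (A - 21.0) ** 0.4 + 0.07886 * (A - 21.0)
    return 0.0

# kernel_size=12, half_width=0.3 gives A ~ 51.0, so beta ~ 4.66
window = torch.kaiser_window(12, beta=kaiser_beta(0.3, 12), periodic=False)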

View File

@@ -8,10 +8,6 @@ from .vae import VAE_16k
from .bigvgan import BigVGANVocoder
import logging
-try:
-import torchaudio
-except:
-logging.warning("torchaudio missing, MMAudio VAE model will be broken")
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5, *, norm_fn):
return norm_fn(torch.clamp(x, min=clip_val) * C)
@@ -21,19 +17,20 @@ def spectral_normalize_torch(magnitudes, norm_fn):
output = dynamic_range_compression_torch(magnitudes, norm_fn=norm_fn)
return output
class MelConverter(nn.Module):
def __init__(
-self,
-*,
-sampling_rate: float,
-n_fft: int,
-num_mels: int,
-hop_size: int,
-win_size: int,
-fmin: float,
-fmax: float,
-norm_fn,
+self,
+*,
+sampling_rate: float,
+n_fft: int,
+num_mels: int,
+hop_size: int,
+win_size: int,
+fmin: float,
+fmax: float,
+norm_fn,
):
super().__init__()
self.sampling_rate = sampling_rate
@@ -89,26 +86,27 @@ class MelConverter(nn.Module):
return spec
class AudioAutoencoder(nn.Module):
def __init__(
-self,
-*,
-# ckpt_path: str,
-mode=Literal['16k', '44k'],
-need_vae_encoder: bool = True,
+self,
+*,
+# ckpt_path: str,
+mode=Literal['16k', '44k'],
+need_vae_encoder: bool = True,
):
super().__init__()
assert mode == "16k", "Only 16k mode is supported currently."
self.mel_converter = MelConverter(sampling_rate=16_000,
-n_fft=1024,
-num_mels=80,
-hop_size=256,
-win_size=1024,
-fmin=0,
-fmax=8_000,
-norm_fn=torch.log10)
+n_fft=1024,
+num_mels=80,
+hop_size=256,
+win_size=1024,
+fmin=0,
+fmax=8_000,
+norm_fn=torch.log10)
self.vae = VAE_16k().eval()
@@ -145,11 +143,13 @@
mel_decoded = self.vae.decode(z)
audio = self.vocoder(mel_decoded)
+import torchaudio
audio = torchaudio.functional.resample(audio, 16000, 44100)
return audio
@torch.no_grad()
def encode(self, audio):
+import torchaudio
audio = audio.mean(dim=1)
audio = torchaudio.functional.resample(audio, 44100, 16000)
dist = self.encode_audio(audio)
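Note: the module-level try/except around torchaudio (removed above) is replaced by function-local imports in decode and encode, so the module imports cleanly without torchaudio installed and the failure only surfaces when resampling is actually requested. The pattern in isolation (function name is hypothetical):

def decode_to_44k(audio_16k):
    import torchaudio  # deferred: only required once decoding is actually used
    return torchaudio.functional.resample(audio_16k, 16000, 44100)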

View File

@@ -9,12 +9,13 @@ import torch.nn as nn
from types import SimpleNamespace
from . import activations
from .alias_free_torch import Activation1d
-import comfy.ops
-ops = comfy.ops.disable_weight_init
+from ....ops import disable_weight_init as ops
def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
class AMPBlock1(torch.nn.Module):
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), activation=None):
@@ -22,19 +23,19 @@ class AMPBlock1(torch.nn.Module):
self.h = h
self.convs1 = nn.ModuleList([
-ops.Conv1d(channels,
+ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0])),
-ops.Conv1d(channels,
+ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1])),
-ops.Conv1d(channels,
+ops.Conv1d(channels,
channels,
kernel_size,
1,
@@ -43,19 +44,19 @@ class AMPBlock1(torch.nn.Module):
])
self.convs2 = nn.ModuleList([
-ops.Conv1d(channels,
+ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1)),
-ops.Conv1d(channels,
+ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1)),
-ops.Conv1d(channels,
+ops.Conv1d(channels,
channels,
kernel_size,
1,
@@ -101,13 +102,13 @@ class AMPBlock2(torch.nn.Module):
self.h = h
self.convs = nn.ModuleList([
-ops.Conv1d(channels,
+ops.Conv1d(channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0])),
-ops.Conv1d(channels,
+ops.Conv1d(channels,
channels,
kernel_size,
1,
@@ -165,8 +166,8 @@ class BigVGANVocoder(torch.nn.Module):
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
self.ups.append(
nn.ModuleList([
-ops.ConvTranspose1d(h.upsample_initial_channel // (2**i),
-h.upsample_initial_channel // (2**(i + 1)),
+ops.ConvTranspose1d(h.upsample_initial_channel // (2 ** i),
+h.upsample_initial_channel // (2 ** (i + 1)),
k,
u,
padding=(k - u) // 2)
@@ -175,7 +176,7 @@ class BigVGANVocoder(torch.nn.Module):
# residual blocks using anti-aliased multi-periodicity composition modules (AMP)
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
-ch = h.upsample_initial_channel // (2**(i + 1))
+ch = h.upsample_initial_channel // (2 ** (i + 1))
for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
self.resblocks.append(resblock(h, ch, k, d, activation=h.activation))
@@ -193,7 +194,6 @@ class BigVGANVocoder(torch.nn.Module):
self.conv_post = ops.Conv1d(ch, 1, 7, 1, padding=3)
def forward(self, x):
-# pre conv
x = self.conv_pre(x)
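Note: disable_weight_init provides drop-in layer variants (Conv1d, ConvTranspose1d, and so on) whose parameter initialization is a no-op, which avoids wasted work when the weights are immediately overwritten by a loaded state dict. A minimal sketch of the idea, not the repository's actual class:

import torch.nn as nn

class NoInitConv1d(nn.Conv1d):
    def reset_parameters(self) -> None:
        # Skip the default init; load_state_dict will supply the real weights.
        return None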

View File

@@ -1,17 +1,14 @@
import logging
from typing import Optional
import torch
import torch.nn as nn
from .vae_modules import (AttnBlock1D, Downsample1D, ResnetBlock1D,
-Upsample1D, nonlinearity)
+Upsample1D, nonlinearity)
from .distributions import DiagonalGaussianDistribution
-import comfy.ops
-ops = comfy.ops.disable_weight_init
+from ....ops import disable_weight_init as ops
log = logging.getLogger()
DATA_MEAN_80D = [
-1.6058, -1.3676, -1.2520, -1.2453, -1.2078, -1.2224, -1.2419, -1.2439, -1.2922, -1.2927,
@@ -68,11 +65,11 @@ DATA_STD_128D = [
class VAE(nn.Module):
def __init__(
-self,
-*,
-data_dim: int,
-embed_dim: int,
-hidden_dim: int,
+self,
+*,
+data_dim: int,
+embed_dim: int,
+hidden_dim: int,
):
super().__init__()
@@ -135,12 +132,12 @@
return x * comfy.model_management.cast_to(self.data_std, dtype=x.dtype, device=x.device) + comfy.model_management.cast_to(self.data_mean, dtype=x.dtype, device=x.device)
def forward(
-self,
-x: torch.Tensor,
-sample_posterior: bool = True,
-rng: Optional[torch.Generator] = None,
-normalize: bool = True,
-unnormalize: bool = True,
+self,
+x: torch.Tensor,
+sample_posterior: bool = True,
+rng: Optional[torch.Generator] = None,
+normalize: bool = True,
+unnormalize: bool = True,
) -> tuple[torch.Tensor, DiagonalGaussianDistribution]:
posterior = self.encode(x, normalize=normalize)
@@ -190,7 +187,7 @@ class Encoder1D(nn.Module):
self.attn_layers = attn_layers
self.conv_in = ops.Conv1d(in_dim, self.dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
-in_ch_mult = (1, ) + tuple(ch_mult)
+in_ch_mult = (1,) + tuple(ch_mult)
self.in_ch_mult = in_ch_mult
# downsampling
self.down = nn.ModuleList()
@@ -229,8 +226,8 @@ class Encoder1D(nn.Module):
# end
self.conv_out = ops.Conv1d(block_in,
-2 * embed_dim if double_z else embed_dim,
-kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
+2 * embed_dim if double_z else embed_dim,
+kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
self.learnable_gain = nn.Parameter(torch.zeros([]))
@@ -355,4 +352,3 @@ def get_my_vae(name: str, **kwargs) -> VAE:
if name == '44k':
return VAE_44k(**kwargs)
raise ValueError(f'Unknown model: {name}')
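Note on in_ch_mult = (1,) + tuple(ch_mult) above: prepending 1 pairs each resolution level's input width with the previous level's output width. A worked example with assumed values:

ch_mult = (1, 2, 4)                  # illustrative multipliers
in_ch_mult = (1,) + tuple(ch_mult)   # (1, 1, 2, 4)
dim = 64                             # illustrative base width
for i, mult in enumerate(ch_mult):
    print(f"level {i}: {dim * in_ch_mult[i]} -> {dim * mult} channels")
# level 0: 64 -> 64, level 1: 64 -> 128, level 2: 128 -> 256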

View File

@@ -1,17 +1,19 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-from comfy.ldm.modules.diffusionmodules.model import vae_attention
+from ...modules.diffusionmodules.model import vae_attention
import math
-import comfy.ops
-ops = comfy.ops.disable_weight_init
+from ....ops import disable_weight_init as ops
def nonlinearity(x):
# swish
return torch.nn.functional.silu(x) / 0.596
def mp_sum(a, b, t=0.5):
-return a.lerp(b, t) / math.sqrt((1 - t)**2 + t**2)
+return a.lerp(b, t) / math.sqrt((1 - t) ** 2 + t ** 2)
def normalize(x, dim=None, eps=1e-4):
if dim is None:
@@ -20,6 +22,7 @@ def normalize(x, dim=None, eps=1e-4):
norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel()))
return x / norm.to(x.dtype)
class ResnetBlock1D(nn.Module):
def __init__(self, *, in_dim, out_dim=None, conv_shortcut=False, kernel_size=3, use_norm=True):
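Note on mp_sum above: a.lerp(b, t) is (1 - t) * a + t * b, and dividing by sqrt((1 - t)^2 + t^2) restores unit variance when a and b are independent with unit variance; the / 0.596 in nonlinearity rescales SiLU the same way (0.596 is approximately the standard deviation of silu(x) for x ~ N(0, 1)). A quick numerical check:

import math
import torch

def mp_sum(a, b, t=0.5):
    # Magnitude-preserving sum, as in the hunk above.
    return a.lerp(b, t) / math.sqrt((1 - t) ** 2 + t ** 2)

a, b = torch.randn(1_000_000), torch.randn(1_000_000)
print(mp_sum(a, b).std())                           # ~1.0
print((torch.nn.functional.silu(a) / 0.596).std())  # ~1.0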

View File

@@ -14,6 +14,7 @@ class EpsilonScaling(io.ComfyNode):
which can significantly improve sample quality. This implementation uses the "uniform schedule"
recommended by the paper for its practicality and effectiveness.
"""
@classmethod
def define_schema(cls):
return io.Schema(
@@ -66,7 +67,7 @@
def compute_tsr_rescaling_factor(
-snr: torch.Tensor, tsr_k: float, tsr_variance: float
+snr: torch.Tensor, tsr_k: float, tsr_variance: float
) -> torch.Tensor:
"""Compute the rescaling score ratio in Temporal Score Rescaling.
@@ -74,7 +75,7 @@ def compute_tsr_rescaling_factor(
"""
posinf_mask = torch.isposinf(snr)
rescaling_factor = (snr * tsr_variance + 1) / (snr * tsr_variance / tsr_k + 1)
-return torch.where(posinf_mask, tsr_k, rescaling_factor) # when snr → inf, r = tsr_k
+return torch.where(posinf_mask, tsr_k, rescaling_factor)  # when snr → inf, r = tsr_k
class TemporalScoreRescaling(io.ComfyNode):
@@ -125,7 +126,7 @@ class TemporalScoreRescaling(io.ComfyNode):
@classmethod
def execute(cls, model, tsr_k, tsr_sigma) -> io.NodeOutput:
-tsr_variance = tsr_sigma**2
+tsr_variance = tsr_sigma ** 2
def temporal_score_rescaling(args):
denoised = args["denoised"]
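Note: for the rescaling factor r = (snr * sigma^2 + 1) / (snr * sigma^2 / k + 1) above, r -> 1 as snr -> 0 and r -> k as snr -> inf; the posinf mask is needed because inf/inf would otherwise produce NaN. A worked check with illustrative values:

import torch

snr = torch.tensor([0.0, 1.0, 100.0, float("inf")])
tsr_k, tsr_variance = 1.2, 0.5 ** 2  # illustrative k and sigma ** 2
r = (snr * tsr_variance + 1) / (snr * tsr_variance / tsr_k + 1)
r = torch.where(torch.isposinf(snr), torch.tensor(tsr_k), r)
print(r)  # ~[1.0000, 1.0345, 1.1908, 1.2000]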