tune up mmaudio, nodes_eps

This commit is contained in:
doctorpangloss 2025-10-20 13:47:58 -07:00
parent f1016ef1c1
commit dc94081155
7 changed files with 82 additions and 78 deletions

View File

@ -4,7 +4,8 @@
import torch import torch
from torch import nn, sin, pow from torch import nn, sin, pow
from torch.nn import Parameter from torch.nn import Parameter
import comfy.model_management from ....model_management import cast_to
class Snake(nn.Module): class Snake(nn.Module):
''' '''
@ -22,6 +23,7 @@ class Snake(nn.Module):
>>> x = torch.randn(256) >>> x = torch.randn(256)
>>> x = a1(x) >>> x = a1(x)
''' '''
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
''' '''
Initialization. Initialization.
@ -51,7 +53,7 @@ class Snake(nn.Module):
Applies the function to the input elementwise. Applies the function to the input elementwise.
Snake = x + 1/a * sin^2 (xa) Snake = x + 1/a * sin^2 (xa)
''' '''
alpha = comfy.model_management.cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] alpha = cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
if self.alpha_logscale: if self.alpha_logscale:
alpha = torch.exp(alpha) alpha = torch.exp(alpha)
x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2) x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
@ -76,6 +78,7 @@ class SnakeBeta(nn.Module):
>>> x = torch.randn(256) >>> x = torch.randn(256)
>>> x = a1(x) >>> x = a1(x)
''' '''
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
''' '''
Initialization. Initialization.
@ -110,8 +113,8 @@ class SnakeBeta(nn.Module):
Applies the function to the input elementwise. Applies the function to the input elementwise.
SnakeBeta = x + 1/b * sin^2 (xa) SnakeBeta = x + 1/b * sin^2 (xa)
''' '''
alpha = comfy.model_management.cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] alpha = cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
beta = comfy.model_management.cast_to(self.beta, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) beta = cast_to(self.beta, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1)
if self.alpha_logscale: if self.alpha_logscale:
alpha = torch.exp(alpha) alpha = torch.exp(alpha)
beta = torch.exp(beta) beta = torch.exp(beta)

View File

@ -2,7 +2,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import math import math
import comfy.model_management from ....model_management import cast_to
if 'sinc' in dir(torch): if 'sinc' in dir(torch):
sinc = torch.sinc sinc = torch.sinc
@ -23,17 +23,17 @@ else:
# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License # This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
# https://adefossez.github.io/julius/julius/lowpass.html # https://adefossez.github.io/julius/julius/lowpass.html
# LICENSE is in incl_licenses directory. # LICENSE is in incl_licenses directory.
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size] def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
even = (kernel_size % 2 == 0) even = (kernel_size % 2 == 0)
half_size = kernel_size // 2 half_size = kernel_size // 2
#For kaiser window # For kaiser window
delta_f = 4 * half_width delta_f = 4 * half_width
A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
if A > 50.: if A > 50.:
beta = 0.1102 * (A - 8.7) beta = 0.1102 * (A - 8.7)
elif A >= 21.: elif A >= 21.:
beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.) beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.)
else: else:
beta = 0. beta = 0.
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
@ -80,14 +80,14 @@ class LowPassFilter1d(nn.Module):
filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
self.register_buffer("filter", filter) self.register_buffer("filter", filter)
#input [B, C, T] # input [B, C, T]
def forward(self, x): def forward(self, x):
_, C, _ = x.shape _, C, _ = x.shape
if self.padding: if self.padding:
x = F.pad(x, (self.pad_left, self.pad_right), x = F.pad(x, (self.pad_left, self.pad_right),
mode=self.padding_mode) mode=self.padding_mode)
out = F.conv1d(x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), out = F.conv1d(x, cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device),
stride=self.stride, groups=C) stride=self.stride, groups=C)
return out return out
@ -113,7 +113,7 @@ class UpSample1d(nn.Module):
x = F.pad(x, (self.pad, self.pad), mode='replicate') x = F.pad(x, (self.pad, self.pad), mode='replicate')
x = self.ratio * F.conv_transpose1d( x = self.ratio * F.conv_transpose1d(
x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C) x, cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C)
x = x[..., self.pad_left:-self.pad_right] x = x[..., self.pad_left:-self.pad_right]
return x return x
@ -134,6 +134,7 @@ class DownSample1d(nn.Module):
return xx return xx
class Activation1d(nn.Module): class Activation1d(nn.Module):
def __init__(self, def __init__(self,
activation, activation,

View File

@ -8,10 +8,6 @@ from .vae import VAE_16k
from .bigvgan import BigVGANVocoder from .bigvgan import BigVGANVocoder
import logging import logging
try:
import torchaudio
except:
logging.warning("torchaudio missing, MMAudio VAE model will be broken")
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5, *, norm_fn): def dynamic_range_compression_torch(x, C=1, clip_val=1e-5, *, norm_fn):
return norm_fn(torch.clamp(x, min=clip_val) * C) return norm_fn(torch.clamp(x, min=clip_val) * C)
@ -21,19 +17,20 @@ def spectral_normalize_torch(magnitudes, norm_fn):
output = dynamic_range_compression_torch(magnitudes, norm_fn=norm_fn) output = dynamic_range_compression_torch(magnitudes, norm_fn=norm_fn)
return output return output
class MelConverter(nn.Module): class MelConverter(nn.Module):
def __init__( def __init__(
self, self,
*, *,
sampling_rate: float, sampling_rate: float,
n_fft: int, n_fft: int,
num_mels: int, num_mels: int,
hop_size: int, hop_size: int,
win_size: int, win_size: int,
fmin: float, fmin: float,
fmax: float, fmax: float,
norm_fn, norm_fn,
): ):
super().__init__() super().__init__()
self.sampling_rate = sampling_rate self.sampling_rate = sampling_rate
@ -89,26 +86,27 @@ class MelConverter(nn.Module):
return spec return spec
class AudioAutoencoder(nn.Module): class AudioAutoencoder(nn.Module):
def __init__( def __init__(
self, self,
*, *,
# ckpt_path: str, # ckpt_path: str,
mode=Literal['16k', '44k'], mode=Literal['16k', '44k'],
need_vae_encoder: bool = True, need_vae_encoder: bool = True,
): ):
super().__init__() super().__init__()
assert mode == "16k", "Only 16k mode is supported currently." assert mode == "16k", "Only 16k mode is supported currently."
self.mel_converter = MelConverter(sampling_rate=16_000, self.mel_converter = MelConverter(sampling_rate=16_000,
n_fft=1024, n_fft=1024,
num_mels=80, num_mels=80,
hop_size=256, hop_size=256,
win_size=1024, win_size=1024,
fmin=0, fmin=0,
fmax=8_000, fmax=8_000,
norm_fn=torch.log10) norm_fn=torch.log10)
self.vae = VAE_16k().eval() self.vae = VAE_16k().eval()
@ -145,11 +143,13 @@ class AudioAutoencoder(nn.Module):
mel_decoded = self.vae.decode(z) mel_decoded = self.vae.decode(z)
audio = self.vocoder(mel_decoded) audio = self.vocoder(mel_decoded)
import torchaudio
audio = torchaudio.functional.resample(audio, 16000, 44100) audio = torchaudio.functional.resample(audio, 16000, 44100)
return audio return audio
@torch.no_grad() @torch.no_grad()
def encode(self, audio): def encode(self, audio):
import torchaudio
audio = audio.mean(dim=1) audio = audio.mean(dim=1)
audio = torchaudio.functional.resample(audio, 44100, 16000) audio = torchaudio.functional.resample(audio, 44100, 16000)
dist = self.encode_audio(audio) dist = self.encode_audio(audio)

View File

@ -9,12 +9,13 @@ import torch.nn as nn
from types import SimpleNamespace from types import SimpleNamespace
from . import activations from . import activations
from .alias_free_torch import Activation1d from .alias_free_torch import Activation1d
import comfy.ops from ....ops import disable_weight_init as ops
ops = comfy.ops.disable_weight_init
def get_padding(kernel_size, dilation=1): def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2) return int((kernel_size * dilation - dilation) / 2)
class AMPBlock1(torch.nn.Module): class AMPBlock1(torch.nn.Module):
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), activation=None): def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), activation=None):
@ -22,19 +23,19 @@ class AMPBlock1(torch.nn.Module):
self.h = h self.h = h
self.convs1 = nn.ModuleList([ self.convs1 = nn.ModuleList([
ops.Conv1d(channels, ops.Conv1d(channels,
channels, channels,
kernel_size, kernel_size,
1, 1,
dilation=dilation[0], dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0])), padding=get_padding(kernel_size, dilation[0])),
ops.Conv1d(channels, ops.Conv1d(channels,
channels, channels,
kernel_size, kernel_size,
1, 1,
dilation=dilation[1], dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1])), padding=get_padding(kernel_size, dilation[1])),
ops.Conv1d(channels, ops.Conv1d(channels,
channels, channels,
kernel_size, kernel_size,
1, 1,
@ -43,19 +44,19 @@ class AMPBlock1(torch.nn.Module):
]) ])
self.convs2 = nn.ModuleList([ self.convs2 = nn.ModuleList([
ops.Conv1d(channels, ops.Conv1d(channels,
channels, channels,
kernel_size, kernel_size,
1, 1,
dilation=1, dilation=1,
padding=get_padding(kernel_size, 1)), padding=get_padding(kernel_size, 1)),
ops.Conv1d(channels, ops.Conv1d(channels,
channels, channels,
kernel_size, kernel_size,
1, 1,
dilation=1, dilation=1,
padding=get_padding(kernel_size, 1)), padding=get_padding(kernel_size, 1)),
ops.Conv1d(channels, ops.Conv1d(channels,
channels, channels,
kernel_size, kernel_size,
1, 1,
@ -101,13 +102,13 @@ class AMPBlock2(torch.nn.Module):
self.h = h self.h = h
self.convs = nn.ModuleList([ self.convs = nn.ModuleList([
ops.Conv1d(channels, ops.Conv1d(channels,
channels, channels,
kernel_size, kernel_size,
1, 1,
dilation=dilation[0], dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0])), padding=get_padding(kernel_size, dilation[0])),
ops.Conv1d(channels, ops.Conv1d(channels,
channels, channels,
kernel_size, kernel_size,
1, 1,
@ -165,8 +166,8 @@ class BigVGANVocoder(torch.nn.Module):
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
self.ups.append( self.ups.append(
nn.ModuleList([ nn.ModuleList([
ops.ConvTranspose1d(h.upsample_initial_channel // (2**i), ops.ConvTranspose1d(h.upsample_initial_channel // (2 ** i),
h.upsample_initial_channel // (2**(i + 1)), h.upsample_initial_channel // (2 ** (i + 1)),
k, k,
u, u,
padding=(k - u) // 2) padding=(k - u) // 2)
@ -175,7 +176,7 @@ class BigVGANVocoder(torch.nn.Module):
# residual blocks using anti-aliased multi-periodicity composition modules (AMP) # residual blocks using anti-aliased multi-periodicity composition modules (AMP)
self.resblocks = nn.ModuleList() self.resblocks = nn.ModuleList()
for i in range(len(self.ups)): for i in range(len(self.ups)):
ch = h.upsample_initial_channel // (2**(i + 1)) ch = h.upsample_initial_channel // (2 ** (i + 1))
for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)): for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
self.resblocks.append(resblock(h, ch, k, d, activation=h.activation)) self.resblocks.append(resblock(h, ch, k, d, activation=h.activation))
@ -193,7 +194,6 @@ class BigVGANVocoder(torch.nn.Module):
self.conv_post = ops.Conv1d(ch, 1, 7, 1, padding=3) self.conv_post = ops.Conv1d(ch, 1, 7, 1, padding=3)
def forward(self, x): def forward(self, x):
# pre conv # pre conv
x = self.conv_pre(x) x = self.conv_pre(x)

View File

@ -1,17 +1,14 @@
import logging
from typing import Optional from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from .vae_modules import (AttnBlock1D, Downsample1D, ResnetBlock1D, from .vae_modules import (AttnBlock1D, Downsample1D, ResnetBlock1D,
Upsample1D, nonlinearity) Upsample1D, nonlinearity)
from .distributions import DiagonalGaussianDistribution from .distributions import DiagonalGaussianDistribution
import comfy.ops from ....ops import disable_weight_init as ops
ops = comfy.ops.disable_weight_init
log = logging.getLogger()
DATA_MEAN_80D = [ DATA_MEAN_80D = [
-1.6058, -1.3676, -1.2520, -1.2453, -1.2078, -1.2224, -1.2419, -1.2439, -1.2922, -1.2927, -1.6058, -1.3676, -1.2520, -1.2453, -1.2078, -1.2224, -1.2419, -1.2439, -1.2922, -1.2927,
@ -68,11 +65,11 @@ DATA_STD_128D = [
class VAE(nn.Module): class VAE(nn.Module):
def __init__( def __init__(
self, self,
*, *,
data_dim: int, data_dim: int,
embed_dim: int, embed_dim: int,
hidden_dim: int, hidden_dim: int,
): ):
super().__init__() super().__init__()
@ -135,12 +132,12 @@ class VAE(nn.Module):
return x * comfy.model_management.cast_to(self.data_std, dtype=x.dtype, device=x.device) + comfy.model_management.cast_to(self.data_mean, dtype=x.dtype, device=x.device) return x * comfy.model_management.cast_to(self.data_std, dtype=x.dtype, device=x.device) + comfy.model_management.cast_to(self.data_mean, dtype=x.dtype, device=x.device)
def forward( def forward(
self, self,
x: torch.Tensor, x: torch.Tensor,
sample_posterior: bool = True, sample_posterior: bool = True,
rng: Optional[torch.Generator] = None, rng: Optional[torch.Generator] = None,
normalize: bool = True, normalize: bool = True,
unnormalize: bool = True, unnormalize: bool = True,
) -> tuple[torch.Tensor, DiagonalGaussianDistribution]: ) -> tuple[torch.Tensor, DiagonalGaussianDistribution]:
posterior = self.encode(x, normalize=normalize) posterior = self.encode(x, normalize=normalize)
@ -190,7 +187,7 @@ class Encoder1D(nn.Module):
self.attn_layers = attn_layers self.attn_layers = attn_layers
self.conv_in = ops.Conv1d(in_dim, self.dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False) self.conv_in = ops.Conv1d(in_dim, self.dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
in_ch_mult = (1, ) + tuple(ch_mult) in_ch_mult = (1,) + tuple(ch_mult)
self.in_ch_mult = in_ch_mult self.in_ch_mult = in_ch_mult
# downsampling # downsampling
self.down = nn.ModuleList() self.down = nn.ModuleList()
@ -229,8 +226,8 @@ class Encoder1D(nn.Module):
# end # end
self.conv_out = ops.Conv1d(block_in, self.conv_out = ops.Conv1d(block_in,
2 * embed_dim if double_z else embed_dim, 2 * embed_dim if double_z else embed_dim,
kernel_size=kernel_size, padding=kernel_size // 2, bias=False) kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
self.learnable_gain = nn.Parameter(torch.zeros([])) self.learnable_gain = nn.Parameter(torch.zeros([]))
@ -355,4 +352,3 @@ def get_my_vae(name: str, **kwargs) -> VAE:
if name == '44k': if name == '44k':
return VAE_44k(**kwargs) return VAE_44k(**kwargs)
raise ValueError(f'Unknown model: {name}') raise ValueError(f'Unknown model: {name}')

View File

@ -1,17 +1,19 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from comfy.ldm.modules.diffusionmodules.model import vae_attention from ...modules.diffusionmodules.model import vae_attention
import math import math
import comfy.ops from ....ops import disable_weight_init as ops
ops = comfy.ops.disable_weight_init
def nonlinearity(x): def nonlinearity(x):
# swish # swish
return torch.nn.functional.silu(x) / 0.596 return torch.nn.functional.silu(x) / 0.596
def mp_sum(a, b, t=0.5): def mp_sum(a, b, t=0.5):
return a.lerp(b, t) / math.sqrt((1 - t)**2 + t**2) return a.lerp(b, t) / math.sqrt((1 - t) ** 2 + t ** 2)
def normalize(x, dim=None, eps=1e-4): def normalize(x, dim=None, eps=1e-4):
if dim is None: if dim is None:
@ -20,6 +22,7 @@ def normalize(x, dim=None, eps=1e-4):
norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel())) norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel()))
return x / norm.to(x.dtype) return x / norm.to(x.dtype)
class ResnetBlock1D(nn.Module): class ResnetBlock1D(nn.Module):
def __init__(self, *, in_dim, out_dim=None, conv_shortcut=False, kernel_size=3, use_norm=True): def __init__(self, *, in_dim, out_dim=None, conv_shortcut=False, kernel_size=3, use_norm=True):

View File

@ -14,6 +14,7 @@ class EpsilonScaling(io.ComfyNode):
which can significantly improve sample quality. This implementation uses the "uniform schedule" which can significantly improve sample quality. This implementation uses the "uniform schedule"
recommended by the paper for its practicality and effectiveness. recommended by the paper for its practicality and effectiveness.
""" """
@classmethod @classmethod
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
@ -66,7 +67,7 @@ class EpsilonScaling(io.ComfyNode):
def compute_tsr_rescaling_factor( def compute_tsr_rescaling_factor(
snr: torch.Tensor, tsr_k: float, tsr_variance: float snr: torch.Tensor, tsr_k: float, tsr_variance: float
) -> torch.Tensor: ) -> torch.Tensor:
"""Compute the rescaling score ratio in Temporal Score Rescaling. """Compute the rescaling score ratio in Temporal Score Rescaling.
@ -74,7 +75,7 @@ def compute_tsr_rescaling_factor(
""" """
posinf_mask = torch.isposinf(snr) posinf_mask = torch.isposinf(snr)
rescaling_factor = (snr * tsr_variance + 1) / (snr * tsr_variance / tsr_k + 1) rescaling_factor = (snr * tsr_variance + 1) / (snr * tsr_variance / tsr_k + 1)
return torch.where(posinf_mask, tsr_k, rescaling_factor) # when snr → inf, r = tsr_k return torch.where(posinf_mask, tsr_k, rescaling_factor) # when snr → inf, r = tsr_k
class TemporalScoreRescaling(io.ComfyNode): class TemporalScoreRescaling(io.ComfyNode):
@ -125,7 +126,7 @@ class TemporalScoreRescaling(io.ComfyNode):
@classmethod @classmethod
def execute(cls, model, tsr_k, tsr_sigma) -> io.NodeOutput: def execute(cls, model, tsr_k, tsr_sigma) -> io.NodeOutput:
tsr_variance = tsr_sigma**2 tsr_variance = tsr_sigma ** 2
def temporal_score_rescaling(args): def temporal_score_rescaling(args):
denoised = args["denoised"] denoised = args["denoised"]