tune up mmaudio, nodes_eps

doctorpangloss 2025-10-20 13:47:58 -07:00
parent f1016ef1c1
commit dc94081155
7 changed files with 82 additions and 78 deletions


@@ -4,7 +4,8 @@
 import torch
 from torch import nn, sin, pow
 from torch.nn import Parameter
-import comfy.model_management
+from ....model_management import cast_to
+
 
 class Snake(nn.Module):
     '''
@@ -22,6 +23,7 @@ class Snake(nn.Module):
     >>> x = torch.randn(256)
     >>> x = a1(x)
     '''
+
     def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
         '''
         Initialization.
@@ -51,7 +53,7 @@
         Applies the function to the input elementwise.
         Snake = x + 1/a * sin^2 (xa)
         '''
-        alpha = comfy.model_management.cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
+        alpha = cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
         if self.alpha_logscale:
             alpha = torch.exp(alpha)
         x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
@@ -76,6 +78,7 @@ class SnakeBeta(nn.Module):
     >>> x = torch.randn(256)
     >>> x = a1(x)
     '''
+
     def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
         '''
         Initialization.
@@ -110,8 +113,8 @@
         Applies the function to the input elementwise.
         SnakeBeta = x + 1/b * sin^2 (xa)
         '''
-        alpha = comfy.model_management.cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
-        beta = comfy.model_management.cast_to(self.beta, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1)
+        alpha = cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
+        beta = cast_to(self.beta, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1)
         if self.alpha_logscale:
             alpha = torch.exp(alpha)
             beta = torch.exp(beta)
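
Both hunks in this file only reroute the cast helper: the inlined cast_to is the same function previously reached as comfy.model_management.cast_to, now imported package-relatively. For reference, a standalone sketch of what Snake.forward computes, with eps standing in for self.no_div_by_zero:

    import torch

    def snake(x, alpha, eps=1e-9):
        # x: [B, C, T]; per-channel alpha reshaped to [1, C, 1] so it
        # broadcasts over batch and time, matching the unsqueeze calls above
        alpha = alpha.reshape(1, -1, 1).to(dtype=x.dtype, device=x.device)
        return x + (1.0 / (alpha + eps)) * torch.sin(x * alpha) ** 2

    y = snake(torch.randn(2, 4, 256), torch.ones(4))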


@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import math
-import comfy.model_management
+from ....model_management import cast_to
 
 if 'sinc' in dir(torch):
     sinc = torch.sinc
@@ -27,13 +27,13 @@ def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,
     even = (kernel_size % 2 == 0)
     half_size = kernel_size // 2
 
-    #For kaiser window
+    # For kaiser window
     delta_f = 4 * half_width
     A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
     if A > 50.:
         beta = 0.1102 * (A - 8.7)
     elif A >= 21.:
-        beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
+        beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.)
     else:
         beta = 0.
     window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
@@ -80,14 +80,14 @@ class LowPassFilter1d(nn.Module):
         filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
         self.register_buffer("filter", filter)
 
-    #input [B, C, T]
+    # input [B, C, T]
     def forward(self, x):
         _, C, _ = x.shape
 
         if self.padding:
             x = F.pad(x, (self.pad_left, self.pad_right),
                       mode=self.padding_mode)
-        out = F.conv1d(x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device),
+        out = F.conv1d(x, cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device),
                        stride=self.stride, groups=C)
 
         return out
@@ -113,7 +113,7 @@ class UpSample1d(nn.Module):
         x = F.pad(x, (self.pad, self.pad), mode='replicate')
         x = self.ratio * F.conv_transpose1d(
-            x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C)
+            x, cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C)
         x = x[..., self.pad_left:-self.pad_right]
         return x
@@ -134,6 +134,7 @@ class DownSample1d(nn.Module):
         return xx
 
+
 class Activation1d(nn.Module):
     def __init__(self,
                  activation,
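
The kaiser_sinc_filter1d hunk follows the standard Kaiser window design rule: estimate the stopband attenuation A (in dB) from the transition width, then choose the window's beta piecewise. A standalone sketch of that rule, using the same constants as the code above:

    import torch

    def kaiser_beta(A):
        # empirical Kaiser formula: beta as a function of attenuation A (dB)
        if A > 50.0:
            return 0.1102 * (A - 8.7)
        elif A >= 21.0:
            return 0.5842 * (A - 21.0) ** 0.4 + 0.07886 * (A - 21.0)
        return 0.0

    window = torch.kaiser_window(12, beta=kaiser_beta(60.0), periodic=False)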


@@ -8,10 +8,6 @@ from .vae import VAE_16k
 from .bigvgan import BigVGANVocoder
 import logging
 
-try:
-    import torchaudio
-except:
-    logging.warning("torchaudio missing, MMAudio VAE model will be broken")
 
 def dynamic_range_compression_torch(x, C=1, clip_val=1e-5, *, norm_fn):
     return norm_fn(torch.clamp(x, min=clip_val) * C)
@@ -21,6 +17,7 @@ def spectral_normalize_torch(magnitudes, norm_fn):
     output = dynamic_range_compression_torch(magnitudes, norm_fn=norm_fn)
     return output
 
+
 class MelConverter(nn.Module):
     def __init__(
@@ -89,6 +86,7 @@ class MelConverter(nn.Module):
         return spec
 
+
 class AudioAutoencoder(nn.Module):
     def __init__(
@@ -145,11 +143,13 @@ class AudioAutoencoder(nn.Module):
         mel_decoded = self.vae.decode(z)
         audio = self.vocoder(mel_decoded)
+        import torchaudio
         audio = torchaudio.functional.resample(audio, 16000, 44100)
         return audio
 
     @torch.no_grad()
     def encode(self, audio):
+        import torchaudio
         audio = audio.mean(dim=1)
         audio = torchaudio.functional.resample(audio, 44100, 16000)
         dist = self.encode_audio(audio)
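
The module-level try/except guard around torchaudio is replaced with function-local imports: the module now imports cleanly when torchaudio is absent, and an ImportError surfaces only if encode or decode actually runs. The pattern in isolation, as a sketch:

    def resample_to_16k(audio):
        import torchaudio  # deferred: only required when this path executes
        return torchaudio.functional.resample(audio, 44100, 16000)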


@@ -9,12 +9,13 @@ import torch.nn as nn
 from types import SimpleNamespace
 from . import activations
 from .alias_free_torch import Activation1d
-import comfy.ops
-
-ops = comfy.ops.disable_weight_init
+from ....ops import disable_weight_init as ops
 
 
 def get_padding(kernel_size, dilation=1):
     return int((kernel_size * dilation - dilation) / 2)
 
+
 class AMPBlock1(torch.nn.Module):
     def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), activation=None):
@@ -165,8 +166,8 @@ class BigVGANVocoder(torch.nn.Module):
         for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
             self.ups.append(
                 nn.ModuleList([
-                    ops.ConvTranspose1d(h.upsample_initial_channel // (2**i),
-                                        h.upsample_initial_channel // (2**(i + 1)),
+                    ops.ConvTranspose1d(h.upsample_initial_channel // (2 ** i),
+                                        h.upsample_initial_channel // (2 ** (i + 1)),
                                         k,
                                         u,
                                         padding=(k - u) // 2)
@@ -175,7 +176,7 @@ class BigVGANVocoder(torch.nn.Module):
         # residual blocks using anti-aliased multi-periodicity composition modules (AMP)
         self.resblocks = nn.ModuleList()
         for i in range(len(self.ups)):
-            ch = h.upsample_initial_channel // (2**(i + 1))
+            ch = h.upsample_initial_channel // (2 ** (i + 1))
             for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
                 self.resblocks.append(resblock(h, ch, k, d, activation=h.activation))
@@ -193,7 +194,6 @@ class BigVGANVocoder(torch.nn.Module):
         self.conv_post = ops.Conv1d(ch, 1, 7, 1, padding=3)
 
-
     def forward(self, x):
         # pre conv
         x = self.conv_pre(x)
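
disable_weight_init supplies drop-in layer classes whose random initialization is skipped, which is safe here because the vocoder's weights are always overwritten by a checkpoint load; importing it directly as ops replaces the old two-step import comfy.ops / ops = comfy.ops.disable_weight_init. A minimal sketch of the idea, not ComfyUI's actual implementation:

    import torch.nn as nn

    class Conv1d(nn.Conv1d):
        def reset_parameters(self):
            # skip random init; load_state_dict() supplies the real weights
            return None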


@@ -1,4 +1,3 @@
-import logging
 from typing import Optional
 
 import torch
@@ -8,10 +7,8 @@ from .vae_modules import (AttnBlock1D, Downsample1D, ResnetBlock1D,
                           Upsample1D, nonlinearity)
 from .distributions import DiagonalGaussianDistribution
-import comfy.ops
-
-ops = comfy.ops.disable_weight_init
-
-log = logging.getLogger()
+from ....ops import disable_weight_init as ops
 
 DATA_MEAN_80D = [
     -1.6058, -1.3676, -1.2520, -1.2453, -1.2078, -1.2224, -1.2419, -1.2439, -1.2922, -1.2927,
@@ -190,7 +187,7 @@ class Encoder1D(nn.Module):
         self.attn_layers = attn_layers
         self.conv_in = ops.Conv1d(in_dim, self.dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
 
-        in_ch_mult = (1, ) + tuple(ch_mult)
+        in_ch_mult = (1,) + tuple(ch_mult)
         self.in_ch_mult = in_ch_mult
         # downsampling
         self.down = nn.ModuleList()
@@ -355,4 +352,3 @@ def get_my_vae(name: str, **kwargs) -> VAE:
     if name == '44k':
         return VAE_44k(**kwargs)
     raise ValueError(f'Unknown model: {name}')
-


@@ -1,17 +1,19 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from comfy.ldm.modules.diffusionmodules.model import vae_attention
+from ...modules.diffusionmodules.model import vae_attention
 import math
-import comfy.ops
-
-ops = comfy.ops.disable_weight_init
+from ....ops import disable_weight_init as ops
 
 
 def nonlinearity(x):
     # swish
     return torch.nn.functional.silu(x) / 0.596
 
+
 def mp_sum(a, b, t=0.5):
-    return a.lerp(b, t) / math.sqrt((1 - t)**2 + t**2)
+    return a.lerp(b, t) / math.sqrt((1 - t) ** 2 + t ** 2)
@@ -20,6 +22,7 @@ def normalize(x, dim=None, eps=1e-4):
     norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel()))
     return x / norm.to(x.dtype)
 
+
 class ResnetBlock1D(nn.Module):
     def __init__(self, *, in_dim, out_dim=None, conv_shortcut=False, kernel_size=3, use_norm=True):
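
mp_sum is a magnitude-preserving lerp: for independent unit-variance a and b, Var((1 - t) * a + t * b) = (1 - t)^2 + t^2, so dividing by sqrt((1 - t)^2 + t^2) restores unit variance. A quick numeric check, as a sketch:

    import math
    import torch

    def mp_sum(a, b, t=0.5):
        # a.lerp(b, t) == (1 - t) * a + t * b
        return a.lerp(b, t) / math.sqrt((1 - t) ** 2 + t ** 2)

    a, b = torch.randn(1_000_000), torch.randn(1_000_000)
    print(mp_sum(a, b, 0.3).std())  # close to 1.0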


@@ -14,6 +14,7 @@ class EpsilonScaling(io.ComfyNode):
     which can significantly improve sample quality. This implementation uses the "uniform schedule"
     recommended by the paper for its practicality and effectiveness.
     """
+
     @classmethod
     def define_schema(cls):
         return io.Schema(
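
Per the docstring, epsilon scaling counteracts exposure bias by shrinking the model's predicted noise; under the paper's "uniform schedule" the scale factor is a single constant applied at every step. A standalone sketch (lambda_ and its value are illustrative, not the node's actual parameter):

    import torch

    def scale_epsilon(noise_pred, lambda_=1.005):
        # uniform schedule: divide the noise prediction by a constant > 1
        return noise_pred / lambda_

    print(scale_epsilon(torch.randn(4, 16, 16)).shape)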
@@ -125,7 +126,7 @@ class TemporalScoreRescaling(io.ComfyNode):
     @classmethod
     def execute(cls, model, tsr_k, tsr_sigma) -> io.NodeOutput:
-        tsr_variance = tsr_sigma**2
+        tsr_variance = tsr_sigma ** 2
 
         def temporal_score_rescaling(args):
             denoised = args["denoised"]