tune up mmaudio, nodes_eps

This commit is contained in:
doctorpangloss 2025-10-20 13:47:58 -07:00
parent f1016ef1c1
commit dc94081155
7 changed files with 82 additions and 78 deletions

View File

@ -4,7 +4,8 @@
import torch import torch
from torch import nn, sin, pow from torch import nn, sin, pow
from torch.nn import Parameter from torch.nn import Parameter
import comfy.model_management from ....model_management import cast_to
class Snake(nn.Module): class Snake(nn.Module):
''' '''
@ -22,6 +23,7 @@ class Snake(nn.Module):
>>> x = torch.randn(256) >>> x = torch.randn(256)
>>> x = a1(x) >>> x = a1(x)
''' '''
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
''' '''
Initialization. Initialization.
@ -51,7 +53,7 @@ class Snake(nn.Module):
Applies the function to the input elementwise. Applies the function to the input elementwise.
Snake = x + 1/a * sin^2 (xa) Snake = x + 1/a * sin^2 (xa)
''' '''
alpha = comfy.model_management.cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] alpha = cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
if self.alpha_logscale: if self.alpha_logscale:
alpha = torch.exp(alpha) alpha = torch.exp(alpha)
x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2) x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
@ -76,6 +78,7 @@ class SnakeBeta(nn.Module):
>>> x = torch.randn(256) >>> x = torch.randn(256)
>>> x = a1(x) >>> x = a1(x)
''' '''
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
''' '''
Initialization. Initialization.
@ -110,8 +113,8 @@ class SnakeBeta(nn.Module):
Applies the function to the input elementwise. Applies the function to the input elementwise.
SnakeBeta = x + 1/b * sin^2 (xa) SnakeBeta = x + 1/b * sin^2 (xa)
''' '''
alpha = comfy.model_management.cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] alpha = cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
beta = comfy.model_management.cast_to(self.beta, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) beta = cast_to(self.beta, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1)
if self.alpha_logscale: if self.alpha_logscale:
alpha = torch.exp(alpha) alpha = torch.exp(alpha)
beta = torch.exp(beta) beta = torch.exp(beta)

View File

@ -2,7 +2,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import math import math
import comfy.model_management from ....model_management import cast_to
if 'sinc' in dir(torch): if 'sinc' in dir(torch):
sinc = torch.sinc sinc = torch.sinc
@ -87,7 +87,7 @@ class LowPassFilter1d(nn.Module):
if self.padding: if self.padding:
x = F.pad(x, (self.pad_left, self.pad_right), x = F.pad(x, (self.pad_left, self.pad_right),
mode=self.padding_mode) mode=self.padding_mode)
out = F.conv1d(x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), out = F.conv1d(x, cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device),
stride=self.stride, groups=C) stride=self.stride, groups=C)
return out return out
@ -113,7 +113,7 @@ class UpSample1d(nn.Module):
x = F.pad(x, (self.pad, self.pad), mode='replicate') x = F.pad(x, (self.pad, self.pad), mode='replicate')
x = self.ratio * F.conv_transpose1d( x = self.ratio * F.conv_transpose1d(
x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C) x, cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C)
x = x[..., self.pad_left:-self.pad_right] x = x[..., self.pad_left:-self.pad_right]
return x return x
@ -134,6 +134,7 @@ class DownSample1d(nn.Module):
return xx return xx
class Activation1d(nn.Module): class Activation1d(nn.Module):
def __init__(self, def __init__(self,
activation, activation,

View File

@ -8,10 +8,6 @@ from .vae import VAE_16k
from .bigvgan import BigVGANVocoder from .bigvgan import BigVGANVocoder
import logging import logging
try:
import torchaudio
except:
logging.warning("torchaudio missing, MMAudio VAE model will be broken")
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5, *, norm_fn): def dynamic_range_compression_torch(x, C=1, clip_val=1e-5, *, norm_fn):
return norm_fn(torch.clamp(x, min=clip_val) * C) return norm_fn(torch.clamp(x, min=clip_val) * C)
@ -21,6 +17,7 @@ def spectral_normalize_torch(magnitudes, norm_fn):
output = dynamic_range_compression_torch(magnitudes, norm_fn=norm_fn) output = dynamic_range_compression_torch(magnitudes, norm_fn=norm_fn)
return output return output
class MelConverter(nn.Module): class MelConverter(nn.Module):
def __init__( def __init__(
@ -89,6 +86,7 @@ class MelConverter(nn.Module):
return spec return spec
class AudioAutoencoder(nn.Module): class AudioAutoencoder(nn.Module):
def __init__( def __init__(
@ -145,11 +143,13 @@ class AudioAutoencoder(nn.Module):
mel_decoded = self.vae.decode(z) mel_decoded = self.vae.decode(z)
audio = self.vocoder(mel_decoded) audio = self.vocoder(mel_decoded)
import torchaudio
audio = torchaudio.functional.resample(audio, 16000, 44100) audio = torchaudio.functional.resample(audio, 16000, 44100)
return audio return audio
@torch.no_grad() @torch.no_grad()
def encode(self, audio): def encode(self, audio):
import torchaudio
audio = audio.mean(dim=1) audio = audio.mean(dim=1)
audio = torchaudio.functional.resample(audio, 44100, 16000) audio = torchaudio.functional.resample(audio, 44100, 16000)
dist = self.encode_audio(audio) dist = self.encode_audio(audio)

View File

@ -9,12 +9,13 @@ import torch.nn as nn
from types import SimpleNamespace from types import SimpleNamespace
from . import activations from . import activations
from .alias_free_torch import Activation1d from .alias_free_torch import Activation1d
import comfy.ops from ....ops import disable_weight_init as ops
ops = comfy.ops.disable_weight_init
def get_padding(kernel_size, dilation=1): def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2) return int((kernel_size * dilation - dilation) / 2)
class AMPBlock1(torch.nn.Module): class AMPBlock1(torch.nn.Module):
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), activation=None): def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), activation=None):
@ -193,7 +194,6 @@ class BigVGANVocoder(torch.nn.Module):
self.conv_post = ops.Conv1d(ch, 1, 7, 1, padding=3) self.conv_post = ops.Conv1d(ch, 1, 7, 1, padding=3)
def forward(self, x): def forward(self, x):
# pre conv # pre conv
x = self.conv_pre(x) x = self.conv_pre(x)

View File

@ -1,4 +1,3 @@
import logging
from typing import Optional from typing import Optional
import torch import torch
@ -8,10 +7,8 @@ from .vae_modules import (AttnBlock1D, Downsample1D, ResnetBlock1D,
Upsample1D, nonlinearity) Upsample1D, nonlinearity)
from .distributions import DiagonalGaussianDistribution from .distributions import DiagonalGaussianDistribution
import comfy.ops from ....ops import disable_weight_init as ops
ops = comfy.ops.disable_weight_init
log = logging.getLogger()
DATA_MEAN_80D = [ DATA_MEAN_80D = [
-1.6058, -1.3676, -1.2520, -1.2453, -1.2078, -1.2224, -1.2419, -1.2439, -1.2922, -1.2927, -1.6058, -1.3676, -1.2520, -1.2453, -1.2078, -1.2224, -1.2419, -1.2439, -1.2922, -1.2927,
@ -355,4 +352,3 @@ def get_my_vae(name: str, **kwargs) -> VAE:
if name == '44k': if name == '44k':
return VAE_44k(**kwargs) return VAE_44k(**kwargs)
raise ValueError(f'Unknown model: {name}') raise ValueError(f'Unknown model: {name}')

View File

@ -1,18 +1,20 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from comfy.ldm.modules.diffusionmodules.model import vae_attention from ...modules.diffusionmodules.model import vae_attention
import math import math
import comfy.ops from ....ops import disable_weight_init as ops
ops = comfy.ops.disable_weight_init
def nonlinearity(x): def nonlinearity(x):
# swish # swish
return torch.nn.functional.silu(x) / 0.596 return torch.nn.functional.silu(x) / 0.596
def mp_sum(a, b, t=0.5): def mp_sum(a, b, t=0.5):
return a.lerp(b, t) / math.sqrt((1 - t) ** 2 + t ** 2) return a.lerp(b, t) / math.sqrt((1 - t) ** 2 + t ** 2)
def normalize(x, dim=None, eps=1e-4): def normalize(x, dim=None, eps=1e-4):
if dim is None: if dim is None:
dim = list(range(1, x.ndim)) dim = list(range(1, x.ndim))
@ -20,6 +22,7 @@ def normalize(x, dim=None, eps=1e-4):
norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel())) norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel()))
return x / norm.to(x.dtype) return x / norm.to(x.dtype)
class ResnetBlock1D(nn.Module): class ResnetBlock1D(nn.Module):
def __init__(self, *, in_dim, out_dim=None, conv_shortcut=False, kernel_size=3, use_norm=True): def __init__(self, *, in_dim, out_dim=None, conv_shortcut=False, kernel_size=3, use_norm=True):

View File

@ -14,6 +14,7 @@ class EpsilonScaling(io.ComfyNode):
which can significantly improve sample quality. This implementation uses the "uniform schedule" which can significantly improve sample quality. This implementation uses the "uniform schedule"
recommended by the paper for its practicality and effectiveness. recommended by the paper for its practicality and effectiveness.
""" """
@classmethod @classmethod
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(