From dc9408115574eef905cd0739836c7c8fe19948da Mon Sep 17 00:00:00 2001
From: doctorpangloss <2229300+doctorpangloss@users.noreply.github.com>
Date: Mon, 20 Oct 2025 13:47:58 -0700
Subject: [PATCH] tune up mmaudio, nodes_eps

---
 comfy/ldm/mmaudio/vae/activations.py      | 11 +++--
 comfy/ldm/mmaudio/vae/alias_free_torch.py | 15 ++++---
 comfy/ldm/mmaudio/vae/autoencoder.py      | 52 +++++++++++------------
 comfy/ldm/mmaudio/vae/bigvgan.py          | 28 ++++++------
 comfy/ldm/mmaudio/vae/vae.py              | 36 +++++++---------
 comfy/ldm/mmaudio/vae/vae_modules.py      | 11 +++--
 comfy_extras/nodes/nodes_eps.py           |  7 +--
 7 files changed, 82 insertions(+), 78 deletions(-)

diff --git a/comfy/ldm/mmaudio/vae/activations.py b/comfy/ldm/mmaudio/vae/activations.py
index db9192e3e..81aa6ab3c 100644
--- a/comfy/ldm/mmaudio/vae/activations.py
+++ b/comfy/ldm/mmaudio/vae/activations.py
@@ -4,7 +4,8 @@
 import torch
 from torch import nn, sin, pow
 from torch.nn import Parameter
-import comfy.model_management
+from ....model_management import cast_to
+
 
 class Snake(nn.Module):
     '''
@@ -22,6 +23,7 @@ class Snake(nn.Module):
         >>> x = torch.randn(256)
         >>> x = a1(x)
     '''
+
     def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
         '''
         Initialization.
@@ -51,7 +53,7 @@ class Snake(nn.Module):
         Applies the function to the input elementwise.
         Snake ∶= x + 1/a * sin^2 (xa)
         '''
-        alpha = comfy.model_management.cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
+        alpha = cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
         if self.alpha_logscale:
             alpha = torch.exp(alpha)
         x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
@@ -76,6 +78,7 @@ class SnakeBeta(nn.Module):
         >>> x = torch.randn(256)
         >>> x = a1(x)
     '''
+
     def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
         '''
         Initialization.
@@ -110,8 +113,8 @@ class SnakeBeta(nn.Module):
         Applies the function to the input elementwise.
         SnakeBeta ∶= x + 1/b * sin^2 (xa)
         '''
-        alpha = comfy.model_management.cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
-        beta = comfy.model_management.cast_to(self.beta, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1)
+        alpha = cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
+        beta = cast_to(self.beta, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1)
         if self.alpha_logscale:
             alpha = torch.exp(alpha)
             beta = torch.exp(beta)
diff --git a/comfy/ldm/mmaudio/vae/alias_free_torch.py b/comfy/ldm/mmaudio/vae/alias_free_torch.py
index 35c70b897..6babd38a5 100644
--- a/comfy/ldm/mmaudio/vae/alias_free_torch.py
+++ b/comfy/ldm/mmaudio/vae/alias_free_torch.py
@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import math
-import comfy.model_management
+from ....model_management import cast_to
 
 if 'sinc' in dir(torch):
     sinc = torch.sinc
@@ -23,17 +23,17 @@ else:
 # This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
 # https://adefossez.github.io/julius/julius/lowpass.html
 # LICENSE is in incl_licenses directory.
-def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
+def kaiser_sinc_filter1d(cutoff, half_width, kernel_size):  # return filter [1,1,kernel_size]
     even = (kernel_size % 2 == 0)
     half_size = kernel_size // 2
 
-    #For kaiser window
+    # For kaiser window
     delta_f = 4 * half_width
     A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
     if A > 50.:
         beta = 0.1102 * (A - 8.7)
     elif A >= 21.:
-        beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
+        beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.)
     else:
         beta = 0.
     window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
@@ -80,14 +80,14 @@ class LowPassFilter1d(nn.Module):
         filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
         self.register_buffer("filter", filter)
 
-    #input [B, C, T]
+    # input [B, C, T]
     def forward(self, x):
         _, C, _ = x.shape
 
         if self.padding:
             x = F.pad(x, (self.pad_left, self.pad_right),
                       mode=self.padding_mode)
-        out = F.conv1d(x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device),
+        out = F.conv1d(x, cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device),
                        stride=self.stride, groups=C)
 
         return out
@@ -113,7 +113,7 @@ class UpSample1d(nn.Module):
 
         x = F.pad(x, (self.pad, self.pad), mode='replicate')
         x = self.ratio * F.conv_transpose1d(
-            x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C)
+            x, cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C)
         x = x[..., self.pad_left:-self.pad_right]
 
         return x
@@ -134,6 +134,7 @@ class DownSample1d(nn.Module):
 
         return xx
 
+
 class Activation1d(nn.Module):
     def __init__(self,
                  activation,
diff --git a/comfy/ldm/mmaudio/vae/autoencoder.py b/comfy/ldm/mmaudio/vae/autoencoder.py
index cbb9de302..f7f25b4df 100644
--- a/comfy/ldm/mmaudio/vae/autoencoder.py
+++ b/comfy/ldm/mmaudio/vae/autoencoder.py
@@ -8,10 +8,6 @@ from .vae import VAE_16k
 from .bigvgan import BigVGANVocoder
 import logging
 
-try:
-    import torchaudio
-except:
-    logging.warning("torchaudio missing, MMAudio VAE model will be broken")
 
 def dynamic_range_compression_torch(x, C=1, clip_val=1e-5, *, norm_fn):
     return norm_fn(torch.clamp(x, min=clip_val) * C)
@@ -21,19 +17,20 @@ def spectral_normalize_torch(magnitudes, norm_fn):
     output = dynamic_range_compression_torch(magnitudes, norm_fn=norm_fn)
     return output
 
+
 class MelConverter(nn.Module):
 
     def __init__(
-        self,
-        *,
-        sampling_rate: float,
-        n_fft: int,
-        num_mels: int,
-        hop_size: int,
-        win_size: int,
-        fmin: float,
-        fmax: float,
-        norm_fn,
+            self,
+            *,
+            sampling_rate: float,
+            n_fft: int,
+            num_mels: int,
+            hop_size: int,
+            win_size: int,
+            fmin: float,
+            fmax: float,
+            norm_fn,
     ):
         super().__init__()
         self.sampling_rate = sampling_rate
@@ -89,26 +86,27 @@ class MelConverter(nn.Module):
 
         return spec
 
+
 class AudioAutoencoder(nn.Module):
 
     def __init__(
-        self,
-        *,
-        # ckpt_path: str,
-        mode=Literal['16k', '44k'],
-        need_vae_encoder: bool = True,
+            self,
+            *,
+            # ckpt_path: str,
+            mode=Literal['16k', '44k'],
+            need_vae_encoder: bool = True,
     ):
         super().__init__()
         assert mode == "16k", "Only 16k mode is supported currently."
         self.mel_converter = MelConverter(sampling_rate=16_000,
-                                          n_fft=1024,
-                                          num_mels=80,
-                                          hop_size=256,
-                                          win_size=1024,
-                                          fmin=0,
-                                          fmax=8_000,
-                                          norm_fn=torch.log10)
+                                          n_fft=1024,
+                                          num_mels=80,
+                                          hop_size=256,
+                                          win_size=1024,
+                                          fmin=0,
+                                          fmax=8_000,
+                                          norm_fn=torch.log10)
 
         self.vae = VAE_16k().eval()
 
@@ -145,11 +143,13 @@ class AudioAutoencoder(nn.Module):
 
         mel_decoded = self.vae.decode(z)
         audio = self.vocoder(mel_decoded)
+        import torchaudio
         audio = torchaudio.functional.resample(audio, 16000, 44100)
         return audio
 
     @torch.no_grad()
     def encode(self, audio):
+        import torchaudio
         audio = audio.mean(dim=1)
         audio = torchaudio.functional.resample(audio, 44100, 16000)
         dist = self.encode_audio(audio)
diff --git a/comfy/ldm/mmaudio/vae/bigvgan.py b/comfy/ldm/mmaudio/vae/bigvgan.py
index 3a24337f6..24f9b0241 100644
--- a/comfy/ldm/mmaudio/vae/bigvgan.py
+++ b/comfy/ldm/mmaudio/vae/bigvgan.py
@@ -9,12 +9,13 @@ import torch.nn as nn
 from types import SimpleNamespace
 from . import activations
 from .alias_free_torch import Activation1d
-import comfy.ops
-ops = comfy.ops.disable_weight_init
+from ....ops import disable_weight_init as ops
+
 
 def get_padding(kernel_size, dilation=1):
     return int((kernel_size * dilation - dilation) / 2)
 
+
 class AMPBlock1(torch.nn.Module):
 
     def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), activation=None):
@@ -22,19 +23,19 @@ class AMPBlock1(torch.nn.Module):
         self.h = h
 
         self.convs1 = nn.ModuleList([
-            ops.Conv1d(channels,
+            ops.Conv1d(channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[0],
                        padding=get_padding(kernel_size, dilation[0])),
-            ops.Conv1d(channels,
+            ops.Conv1d(channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[1],
                        padding=get_padding(kernel_size, dilation[1])),
-            ops.Conv1d(channels,
+            ops.Conv1d(channels,
                        channels,
                        kernel_size,
                        1,
@@ -43,19 +44,19 @@
         ])
 
         self.convs2 = nn.ModuleList([
-            ops.Conv1d(channels,
+            ops.Conv1d(channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1)),
-            ops.Conv1d(channels,
+            ops.Conv1d(channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1)),
-            ops.Conv1d(channels,
+            ops.Conv1d(channels,
                        channels,
                        kernel_size,
                        1,
@@ -101,13 +102,13 @@ class AMPBlock2(torch.nn.Module):
         self.h = h
 
         self.convs = nn.ModuleList([
-            ops.Conv1d(channels,
+            ops.Conv1d(channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[0],
                        padding=get_padding(kernel_size, dilation[0])),
-            ops.Conv1d(channels,
+            ops.Conv1d(channels,
                        channels,
                        kernel_size,
                        1,
@@ -165,8 +166,8 @@
         for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
             self.ups.append(
                 nn.ModuleList([
-                    ops.ConvTranspose1d(h.upsample_initial_channel // (2**i),
-                                        h.upsample_initial_channel // (2**(i + 1)),
+                    ops.ConvTranspose1d(h.upsample_initial_channel // (2 ** i),
+                                        h.upsample_initial_channel // (2 ** (i + 1)),
                                         k,
                                         u,
                                         padding=(k - u) // 2)
@@ -175,7 +176,7 @@
         # residual blocks using anti-aliased multi-periodicity composition modules (AMP)
         self.resblocks = nn.ModuleList()
         for i in range(len(self.ups)):
-            ch = h.upsample_initial_channel // (2**(i + 1))
+            ch = h.upsample_initial_channel // (2 ** (i + 1))
             for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
                 self.resblocks.append(resblock(h, ch, k, d, activation=h.activation))
 
@@ -193,7 +194,6 @@
 
         self.conv_post = ops.Conv1d(ch, 1, 7, 1, padding=3)
 
-
     def forward(self, x):
         # pre conv
         x = self.conv_pre(x)
diff --git a/comfy/ldm/mmaudio/vae/vae.py b/comfy/ldm/mmaudio/vae/vae.py
index 62f24606c..fc8773d69 100644
--- a/comfy/ldm/mmaudio/vae/vae.py
+++ b/comfy/ldm/mmaudio/vae/vae.py
@@ -1,17 +1,14 @@
-import logging
 from typing import Optional
 
 import torch
 import torch.nn as nn
 
 from .vae_modules import (AttnBlock1D, Downsample1D, ResnetBlock1D,
-                          Upsample1D, nonlinearity)
+                          Upsample1D, nonlinearity)
 from .distributions import DiagonalGaussianDistribution
 
-import comfy.ops
-ops = comfy.ops.disable_weight_init
+from ....ops import disable_weight_init as ops
 
-log = logging.getLogger()
 
 DATA_MEAN_80D = [
     -1.6058, -1.3676, -1.2520, -1.2453, -1.2078, -1.2224, -1.2419, -1.2439, -1.2922, -1.2927,
@@ -68,11 +65,11 @@ DATA_STD_128D = [
 class VAE(nn.Module):
 
     def __init__(
-        self,
-        *,
-        data_dim: int,
-        embed_dim: int,
-        hidden_dim: int,
+            self,
+            *,
+            data_dim: int,
+            embed_dim: int,
+            hidden_dim: int,
     ):
         super().__init__()
 
@@ -135,12 +132,12 @@ class VAE(nn.Module):
         return x * comfy.model_management.cast_to(self.data_std, dtype=x.dtype, device=x.device) + comfy.model_management.cast_to(self.data_mean, dtype=x.dtype, device=x.device)
 
     def forward(
-        self,
-        x: torch.Tensor,
-        sample_posterior: bool = True,
-        rng: Optional[torch.Generator] = None,
-        normalize: bool = True,
-        unnormalize: bool = True,
+            self,
+            x: torch.Tensor,
+            sample_posterior: bool = True,
+            rng: Optional[torch.Generator] = None,
+            normalize: bool = True,
+            unnormalize: bool = True,
     ) -> tuple[torch.Tensor, DiagonalGaussianDistribution]:
         posterior = self.encode(x, normalize=normalize)
 
@@ -190,7 +187,7 @@ class Encoder1D(nn.Module):
         self.attn_layers = attn_layers
         self.conv_in = ops.Conv1d(in_dim, self.dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
 
-        in_ch_mult = (1, ) + tuple(ch_mult)
+        in_ch_mult = (1,) + tuple(ch_mult)
         self.in_ch_mult = in_ch_mult
         # downsampling
         self.down = nn.ModuleList()
@@ -229,8 +226,8 @@ class Encoder1D(nn.Module):
 
         # end
         self.conv_out = ops.Conv1d(block_in,
-                                   2 * embed_dim if double_z else embed_dim,
-                                   kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
+                                   2 * embed_dim if double_z else embed_dim,
+                                   kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
 
         self.learnable_gain = nn.Parameter(torch.zeros([]))
 
@@ -355,4 +352,3 @@ def get_my_vae(name: str, **kwargs) -> VAE:
     if name == '44k':
         return VAE_44k(**kwargs)
     raise ValueError(f'Unknown model: {name}')
-
diff --git a/comfy/ldm/mmaudio/vae/vae_modules.py b/comfy/ldm/mmaudio/vae/vae_modules.py
index 3ad05134b..7e4663b0b 100644
--- a/comfy/ldm/mmaudio/vae/vae_modules.py
+++ b/comfy/ldm/mmaudio/vae/vae_modules.py
@@ -1,17 +1,19 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from comfy.ldm.modules.diffusionmodules.model import vae_attention
+from ...modules.diffusionmodules.model import vae_attention
 import math
-import comfy.ops
-ops = comfy.ops.disable_weight_init
+from ....ops import disable_weight_init as ops
+
 
 def nonlinearity(x):
     # swish
     return torch.nn.functional.silu(x) / 0.596
 
+
 def mp_sum(a, b, t=0.5):
-    return a.lerp(b, t) / math.sqrt((1 - t)**2 + t**2)
+    return a.lerp(b, t) / math.sqrt((1 - t) ** 2 + t ** 2)
+
 
 def normalize(x, dim=None, eps=1e-4):
     if dim is None:
@@ -20,6 +22,7 @@ def normalize(x, dim=None, eps=1e-4):
     norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel()))
     return x / norm.to(x.dtype)
 
+
 class ResnetBlock1D(nn.Module):
 
     def __init__(self, *, in_dim, out_dim=None, conv_shortcut=False, kernel_size=3, use_norm=True):
diff --git a/comfy_extras/nodes/nodes_eps.py b/comfy_extras/nodes/nodes_eps.py
index 4d8061741..6c3abf1b3 100644
--- a/comfy_extras/nodes/nodes_eps.py
+++ b/comfy_extras/nodes/nodes_eps.py
@@ -14,6 +14,7 @@ class EpsilonScaling(io.ComfyNode):
     which can significantly improve sample quality. This implementation uses
     the "uniform schedule" recommended by the paper for its practicality and effectiveness.
     """
+
     @classmethod
     def define_schema(cls):
         return io.Schema(
@@ -66,7 +67,7 @@
 def compute_tsr_rescaling_factor(
-    snr: torch.Tensor, tsr_k: float, tsr_variance: float
+        snr: torch.Tensor, tsr_k: float, tsr_variance: float
 ) -> torch.Tensor:
     """Compute the rescaling score ratio in Temporal Score Rescaling.
     """
     posinf_mask = torch.isposinf(snr)
     rescaling_factor = (snr * tsr_variance + 1) / (snr * tsr_variance / tsr_k + 1)
-    return torch.where(posinf_mask, tsr_k, rescaling_factor) # when snr → inf, r = tsr_k
+    return torch.where(posinf_mask, tsr_k, rescaling_factor)  # when snr → inf, r = tsr_k
 
 
 class TemporalScoreRescaling(io.ComfyNode):
@@ -125,7 +126,7 @@ class TemporalScoreRescaling(io.ComfyNode):
 
     @classmethod
     def execute(cls, model, tsr_k, tsr_sigma) -> io.NodeOutput:
-        tsr_variance = tsr_sigma**2
+        tsr_variance = tsr_sigma ** 2
 
         def temporal_score_rescaling(args):
             denoised = args["denoised"]
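
Note on the rescaling factor touched above: compute_tsr_rescaling_factor evaluates r = (snr * sigma^2 + 1) / (snr * sigma^2 / tsr_k + 1), which stays near 1 at low SNR and approaches tsr_k as SNR grows, with the isposinf mask covering the inf/inf case. The snippet below is an illustrative sketch only and is not part of the patch; the standalone helper name and the sample inputs are assumptions made for demonstration.

    import torch

    def tsr_rescaling_factor(snr: torch.Tensor, tsr_k: float, tsr_variance: float) -> torch.Tensor:
        # Mirrors the patched helper: r -> 1 as snr -> 0, r -> tsr_k as snr -> inf.
        posinf_mask = torch.isposinf(snr)
        rescaling_factor = (snr * tsr_variance + 1) / (snr * tsr_variance / tsr_k + 1)
        return torch.where(posinf_mask, torch.full_like(rescaling_factor, tsr_k), rescaling_factor)

    snr = torch.tensor([0.0, 1.0, 100.0, float("inf")])
    print(tsr_rescaling_factor(snr, tsr_k=1.1, tsr_variance=0.5 ** 2))
    # tensor([1.0000, 1.0185, 1.0958, 1.1000])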