mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-07 10:17:31 +08:00
LTX audio vae novram fixes. (#12796)
This commit is contained in:
parent
8befce5c7b
commit
17b43c2b87
@ -82,7 +82,7 @@ class LowPassFilter1d(nn.Module):
|
||||
_, C, _ = x.shape
|
||||
if self.padding:
|
||||
x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
|
||||
return F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
|
||||
return F.conv1d(x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C)
|
||||
|
||||
|
||||
class UpSample1d(nn.Module):
|
||||
@ -191,7 +191,7 @@ class Snake(nn.Module):
|
||||
self.eps = 1e-9
|
||||
|
||||
def forward(self, x):
|
||||
a = self.alpha.unsqueeze(0).unsqueeze(-1)
|
||||
a = comfy.model_management.cast_to(self.alpha.unsqueeze(0).unsqueeze(-1), dtype=x.dtype, device=x.device)
|
||||
if self.alpha_logscale:
|
||||
a = torch.exp(a)
|
||||
return x + (1.0 / (a + self.eps)) * torch.sin(x * a).pow(2)
|
||||
@ -218,8 +218,8 @@ class SnakeBeta(nn.Module):
|
||||
self.eps = 1e-9
|
||||
|
||||
def forward(self, x):
|
||||
a = self.alpha.unsqueeze(0).unsqueeze(-1)
|
||||
b = self.beta.unsqueeze(0).unsqueeze(-1)
|
||||
a = comfy.model_management.cast_to(self.alpha.unsqueeze(0).unsqueeze(-1), dtype=x.dtype, device=x.device)
|
||||
b = comfy.model_management.cast_to(self.beta.unsqueeze(0).unsqueeze(-1), dtype=x.dtype, device=x.device)
|
||||
if self.alpha_logscale:
|
||||
a = torch.exp(a)
|
||||
b = torch.exp(b)
|
||||
@ -597,7 +597,7 @@ class _STFTFn(nn.Module):
|
||||
y = y.unsqueeze(1) # (B, 1, T)
|
||||
left_pad = max(0, self.win_length - self.hop_length) # causal: left-only
|
||||
y = F.pad(y, (left_pad, 0))
|
||||
spec = F.conv1d(y, self.forward_basis, stride=self.hop_length, padding=0)
|
||||
spec = F.conv1d(y, comfy.model_management.cast_to(self.forward_basis, dtype=y.dtype, device=y.device), stride=self.hop_length, padding=0)
|
||||
n_freqs = spec.shape[1] // 2
|
||||
real, imag = spec[:, :n_freqs], spec[:, n_freqs:]
|
||||
magnitude = torch.sqrt(real ** 2 + imag ** 2)
|
||||
@ -648,7 +648,7 @@ class MelSTFT(nn.Module):
|
||||
"""
|
||||
magnitude, phase = self.stft_fn(y)
|
||||
energy = torch.norm(magnitude, dim=1)
|
||||
mel = torch.matmul(self.mel_basis.to(magnitude.dtype), magnitude)
|
||||
mel = torch.matmul(comfy.model_management.cast_to(self.mel_basis, dtype=magnitude.dtype, device=y.device), magnitude)
|
||||
log_mel = torch.log(torch.clamp(mel, min=1e-5))
|
||||
return log_mel, magnitude, phase, energy
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user