mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-08 02:37:42 +08:00
LTX audio vae novram fixes. (#12796)
This commit is contained in:
parent
8befce5c7b
commit
17b43c2b87
@ -82,7 +82,7 @@ class LowPassFilter1d(nn.Module):
|
|||||||
_, C, _ = x.shape
|
_, C, _ = x.shape
|
||||||
if self.padding:
|
if self.padding:
|
||||||
x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
|
x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
|
||||||
return F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
|
return F.conv1d(x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C)
|
||||||
|
|
||||||
|
|
||||||
class UpSample1d(nn.Module):
|
class UpSample1d(nn.Module):
|
||||||
@ -191,7 +191,7 @@ class Snake(nn.Module):
|
|||||||
self.eps = 1e-9
|
self.eps = 1e-9
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
a = self.alpha.unsqueeze(0).unsqueeze(-1)
|
a = comfy.model_management.cast_to(self.alpha.unsqueeze(0).unsqueeze(-1), dtype=x.dtype, device=x.device)
|
||||||
if self.alpha_logscale:
|
if self.alpha_logscale:
|
||||||
a = torch.exp(a)
|
a = torch.exp(a)
|
||||||
return x + (1.0 / (a + self.eps)) * torch.sin(x * a).pow(2)
|
return x + (1.0 / (a + self.eps)) * torch.sin(x * a).pow(2)
|
||||||
@ -218,8 +218,8 @@ class SnakeBeta(nn.Module):
|
|||||||
self.eps = 1e-9
|
self.eps = 1e-9
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
a = self.alpha.unsqueeze(0).unsqueeze(-1)
|
a = comfy.model_management.cast_to(self.alpha.unsqueeze(0).unsqueeze(-1), dtype=x.dtype, device=x.device)
|
||||||
b = self.beta.unsqueeze(0).unsqueeze(-1)
|
b = comfy.model_management.cast_to(self.beta.unsqueeze(0).unsqueeze(-1), dtype=x.dtype, device=x.device)
|
||||||
if self.alpha_logscale:
|
if self.alpha_logscale:
|
||||||
a = torch.exp(a)
|
a = torch.exp(a)
|
||||||
b = torch.exp(b)
|
b = torch.exp(b)
|
||||||
@ -597,7 +597,7 @@ class _STFTFn(nn.Module):
|
|||||||
y = y.unsqueeze(1) # (B, 1, T)
|
y = y.unsqueeze(1) # (B, 1, T)
|
||||||
left_pad = max(0, self.win_length - self.hop_length) # causal: left-only
|
left_pad = max(0, self.win_length - self.hop_length) # causal: left-only
|
||||||
y = F.pad(y, (left_pad, 0))
|
y = F.pad(y, (left_pad, 0))
|
||||||
spec = F.conv1d(y, self.forward_basis, stride=self.hop_length, padding=0)
|
spec = F.conv1d(y, comfy.model_management.cast_to(self.forward_basis, dtype=y.dtype, device=y.device), stride=self.hop_length, padding=0)
|
||||||
n_freqs = spec.shape[1] // 2
|
n_freqs = spec.shape[1] // 2
|
||||||
real, imag = spec[:, :n_freqs], spec[:, n_freqs:]
|
real, imag = spec[:, :n_freqs], spec[:, n_freqs:]
|
||||||
magnitude = torch.sqrt(real ** 2 + imag ** 2)
|
magnitude = torch.sqrt(real ** 2 + imag ** 2)
|
||||||
@ -648,7 +648,7 @@ class MelSTFT(nn.Module):
|
|||||||
"""
|
"""
|
||||||
magnitude, phase = self.stft_fn(y)
|
magnitude, phase = self.stft_fn(y)
|
||||||
energy = torch.norm(magnitude, dim=1)
|
energy = torch.norm(magnitude, dim=1)
|
||||||
mel = torch.matmul(self.mel_basis.to(magnitude.dtype), magnitude)
|
mel = torch.matmul(comfy.model_management.cast_to(self.mel_basis, dtype=magnitude.dtype, device=y.device), magnitude)
|
||||||
log_mel = torch.log(torch.clamp(mel, min=1e-5))
|
log_mel = torch.log(torch.clamp(mel, min=1e-5))
|
||||||
return log_mel, magnitude, phase, energy
|
return log_mel, magnitude, phase, energy
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user