LTX audio vae novram fixes. (#12796)

This commit is contained in:
comfyanonymous 2026-03-05 13:31:28 -08:00 committed by GitHub
parent 8befce5c7b
commit 17b43c2b87
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -82,7 +82,7 @@ class LowPassFilter1d(nn.Module):
_, C, _ = x.shape
if self.padding:
x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
return F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
return F.conv1d(x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C)
class UpSample1d(nn.Module):
@ -191,7 +191,7 @@ class Snake(nn.Module):
self.eps = 1e-9
def forward(self, x):
a = self.alpha.unsqueeze(0).unsqueeze(-1)
a = comfy.model_management.cast_to(self.alpha.unsqueeze(0).unsqueeze(-1), dtype=x.dtype, device=x.device)
if self.alpha_logscale:
a = torch.exp(a)
return x + (1.0 / (a + self.eps)) * torch.sin(x * a).pow(2)
@ -218,8 +218,8 @@ class SnakeBeta(nn.Module):
self.eps = 1e-9
def forward(self, x):
a = self.alpha.unsqueeze(0).unsqueeze(-1)
b = self.beta.unsqueeze(0).unsqueeze(-1)
a = comfy.model_management.cast_to(self.alpha.unsqueeze(0).unsqueeze(-1), dtype=x.dtype, device=x.device)
b = comfy.model_management.cast_to(self.beta.unsqueeze(0).unsqueeze(-1), dtype=x.dtype, device=x.device)
if self.alpha_logscale:
a = torch.exp(a)
b = torch.exp(b)
@ -597,7 +597,7 @@ class _STFTFn(nn.Module):
y = y.unsqueeze(1) # (B, 1, T)
left_pad = max(0, self.win_length - self.hop_length) # causal: left-only
y = F.pad(y, (left_pad, 0))
spec = F.conv1d(y, self.forward_basis, stride=self.hop_length, padding=0)
spec = F.conv1d(y, comfy.model_management.cast_to(self.forward_basis, dtype=y.dtype, device=y.device), stride=self.hop_length, padding=0)
n_freqs = spec.shape[1] // 2
real, imag = spec[:, :n_freqs], spec[:, n_freqs:]
magnitude = torch.sqrt(real ** 2 + imag ** 2)
@ -648,7 +648,7 @@ class MelSTFT(nn.Module):
"""
magnitude, phase = self.stft_fn(y)
energy = torch.norm(magnitude, dim=1)
mel = torch.matmul(self.mel_basis.to(magnitude.dtype), magnitude)
mel = torch.matmul(comfy.model_management.cast_to(self.mel_basis, dtype=magnitude.dtype, device=y.device), magnitude)
log_mel = torch.log(torch.clamp(mel, min=1e-5))
return log_mel, magnitude, phase, energy