Fix linting errors, use register_buffer

This commit is contained in:
doctorpangloss 2025-10-22 12:16:09 -07:00
parent 427c21309c
commit 674b69c291
2 changed files with 8 additions and 7 deletions

View File

@@ -75,14 +75,14 @@ class VAE(nn.Module):
super().__init__()
if data_dim == 80:
self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_80D, dtype=torch.float32))
self.data_std = nn.Buffer(torch.tensor(DATA_STD_80D, dtype=torch.float32))
data_mean = torch.tensor(DATA_MEAN_80D, dtype=torch.float32)
data_std = torch.tensor(DATA_STD_80D, dtype=torch.float32)
elif data_dim == 128:
self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_128D, dtype=torch.float32))
self.data_std = nn.Buffer(torch.tensor(DATA_STD_128D, dtype=torch.float32))
data_mean = torch.tensor(DATA_MEAN_128D, dtype=torch.float32)
data_std = torch.tensor(DATA_STD_128D, dtype=torch.float32)
self.data_mean = self.data_mean.view(1, -1, 1)
self.data_std = self.data_std.view(1, -1, 1)
self.register_buffer('data_mean', data_mean.view(1, -1, 1))
self.register_buffer('data_std', data_std.view(1, -1, 1))
self.encoder = Encoder1D(
dim=hidden_dim,

View File

@@ -31,7 +31,8 @@ logger = logging.getLogger(__name__)
def run_every_op():
if torch.compiler.is_compiling():
# this is available on torch 2.3
if torch.compiler.is_compiling(): # pylint: disable=no-member
return
throw_exception_if_processing_interrupted()