Ensure ema is defined before operating on it

This commit is contained in:
Max Tretikov 2024-06-14 14:37:50 -06:00
parent c364e42a11
commit 891154b79e

View File

@ -78,15 +78,18 @@ class VectorQuantize(nn.Module):
return ((x + epsilon) / (n + x.size(0) * epsilon) * n)
def _updateEMA(self, z_e_x, indices):
    """Refresh the EMA codebook statistics from one batch of assignments.

    Does nothing unless ``self.ema_loss`` is truthy. Otherwise it blends
    the per-code hit counts and per-code sums of ``z_e_x`` into the running
    averages, then rewrites the codebook weights as
    ``ema_weight_sum / ema_element_count``.

    Args:
        z_e_x: 2-D tensor of encoder outputs, one row per entry of
            ``indices`` (it is multiplied by the transposed one-hot mask).
        indices: 1-D integer tensor of selected code ids, one per row of
            ``z_e_x``; values must be smaller than the number of codes.

    Returns:
        None. ``ema_element_count``, ``ema_weight_sum`` and
        ``codebook.weight.data`` are updated as a side effect.
    """
    if not self.ema_loss:
        # Presumably the EMA buffers are only created when ema_loss is
        # enabled, so they must not be touched here -- confirm in __init__.
        return

    decay = self.ema_decay
    num_codes = self.ema_element_count.size(0)

    # One-hot assignment matrix: row i marks the code chosen for z_e_x[i].
    mask = nn.functional.one_hot(indices, num_codes).float()
    elem_count = mask.sum(dim=0)
    weight_sum = torch.mm(mask.t(), z_e_x)

    # NOTE(review): if z_e_x carries autograd history, the running sums
    # will keep that graph alive across steps -- confirm the caller
    # passes a detached tensor.
    self.register_buffer(
        'ema_element_count',
        self._laplace_smoothing(
            (decay * self.ema_element_count) + ((1 - decay) * elem_count),
            1e-5,
        ),
    )
    self.register_buffer(
        'ema_weight_sum',
        (decay * self.ema_weight_sum) + ((1 - decay) * weight_sum),
    )

    # Each code vector becomes the running mean of the inputs assigned to it.
    self.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1)
def idx2vq(self, idx, dim=-1):
q_idx = self.codebook(idx)