diff --git a/comfy/ldm/cascade/stage_a.py b/comfy/ldm/cascade/stage_a.py index ca8867eaf..1d1c23988 100644 --- a/comfy/ldm/cascade/stage_a.py +++ b/comfy/ldm/cascade/stage_a.py @@ -78,15 +78,18 @@ class VectorQuantize(nn.Module): return ((x + epsilon) / (n + x.size(0) * epsilon) * n) def _updateEMA(self, z_e_x, indices): - mask = nn.functional.one_hot(indices, self.ema_element_count.size(0)).float() - elem_count = mask.sum(dim=0) - weight_sum = torch.mm(mask.t(), z_e_x) + if self.ema_loss: + mask = nn.functional.one_hot(indices, self.ema_element_count.size(0)).float() + elem_count = mask.sum(dim=0) + weight_sum = torch.mm(mask.t(), z_e_x) - self.ema_element_count = (self.ema_decay * self.ema_element_count) + ((1-self.ema_decay) * elem_count) - self.ema_element_count = self._laplace_smoothing(self.ema_element_count, 1e-5) - self.ema_weight_sum = (self.ema_decay * self.ema_weight_sum) + ((1-self.ema_decay) * weight_sum) + self.register_buffer('ema_element_count', self._laplace_smoothing( + (self.ema_decay * self.ema_element_count) + ((1-self.ema_decay) * elem_count), + 1e-5) + ) + self.register_buffer('ema_weight_sum', (self.ema_decay * self.ema_weight_sum) + ((1-self.ema_decay) * weight_sum)) - self.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1) + self.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1) def idx2vq(self, idx, dim=-1): q_idx = self.codebook(idx)