mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 23:00:51 +08:00
Ensure ema is defined before operating on it
This commit is contained in:
parent
c364e42a11
commit
891154b79e
@ -78,15 +78,18 @@ class VectorQuantize(nn.Module):
|
||||
return ((x + epsilon) / (n + x.size(0) * epsilon) * n)
|
||||
|
||||
def _updateEMA(self, z_e_x, indices):
|
||||
mask = nn.functional.one_hot(indices, self.ema_element_count.size(0)).float()
|
||||
elem_count = mask.sum(dim=0)
|
||||
weight_sum = torch.mm(mask.t(), z_e_x)
|
||||
if self.ema_loss:
|
||||
mask = nn.functional.one_hot(indices, self.ema_element_count.size(0)).float()
|
||||
elem_count = mask.sum(dim=0)
|
||||
weight_sum = torch.mm(mask.t(), z_e_x)
|
||||
|
||||
self.ema_element_count = (self.ema_decay * self.ema_element_count) + ((1-self.ema_decay) * elem_count)
|
||||
self.ema_element_count = self._laplace_smoothing(self.ema_element_count, 1e-5)
|
||||
self.ema_weight_sum = (self.ema_decay * self.ema_weight_sum) + ((1-self.ema_decay) * weight_sum)
|
||||
self.register_buffer('ema_element_count', self._laplace_smoothing(
|
||||
(self.ema_decay * self.ema_element_count) + ((1-self.ema_decay) * elem_count),
|
||||
1e-5)
|
||||
)
|
||||
self.register_buffer('ema_weight_sum', (self.ema_decay * self.ema_weight_sum) + ((1-self.ema_decay) * weight_sum))
|
||||
|
||||
self.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1)
|
||||
self.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1)
|
||||
|
||||
def idx2vq(self, idx, dim=-1):
|
||||
q_idx = self.codebook(idx)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user