mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-25 22:00:19 +08:00
Ensure ema is defined before operating on it
This commit is contained in:
parent
c364e42a11
commit
891154b79e
@ -78,15 +78,18 @@ class VectorQuantize(nn.Module):
|
|||||||
return ((x + epsilon) / (n + x.size(0) * epsilon) * n)
|
return ((x + epsilon) / (n + x.size(0) * epsilon) * n)
|
||||||
|
|
||||||
def _updateEMA(self, z_e_x, indices):
|
def _updateEMA(self, z_e_x, indices):
|
||||||
mask = nn.functional.one_hot(indices, self.ema_element_count.size(0)).float()
|
if self.ema_loss:
|
||||||
elem_count = mask.sum(dim=0)
|
mask = nn.functional.one_hot(indices, self.ema_element_count.size(0)).float()
|
||||||
weight_sum = torch.mm(mask.t(), z_e_x)
|
elem_count = mask.sum(dim=0)
|
||||||
|
weight_sum = torch.mm(mask.t(), z_e_x)
|
||||||
|
|
||||||
self.ema_element_count = (self.ema_decay * self.ema_element_count) + ((1-self.ema_decay) * elem_count)
|
self.register_buffer('ema_element_count', self._laplace_smoothing(
|
||||||
self.ema_element_count = self._laplace_smoothing(self.ema_element_count, 1e-5)
|
(self.ema_decay * self.ema_element_count) + ((1-self.ema_decay) * elem_count),
|
||||||
self.ema_weight_sum = (self.ema_decay * self.ema_weight_sum) + ((1-self.ema_decay) * weight_sum)
|
1e-5)
|
||||||
|
)
|
||||||
|
self.register_buffer('ema_weight_sum', (self.ema_decay * self.ema_weight_sum) + ((1-self.ema_decay) * weight_sum))
|
||||||
|
|
||||||
self.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1)
|
self.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1)
|
||||||
|
|
||||||
def idx2vq(self, idx, dim=-1):
|
def idx2vq(self, idx, dim=-1):
|
||||||
q_idx = self.codebook(idx)
|
q_idx = self.codebook(idx)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user