diff --git a/comfy/ldm/modules/ema.py b/comfy/ldm/modules/ema.py index bded25019..e944d0a33 100644 --- a/comfy/ldm/modules/ema.py +++ b/comfy/ldm/modules/ema.py @@ -23,14 +23,13 @@ class LitEma(nn.Module): self.collected_params = [] def reset_num_updates(self): - del self.num_updates self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int)) def forward(self, model): decay = self.decay if self.num_updates >= 0: - self.num_updates += 1 + self.register_buffer('num_updates', torch.tensor(1 + self.num_updates, dtype=torch.int)) decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) one_minus_decay = 1.0 - decay