From ee44e3b1d7ac81e13104f2c6c374e71a9e947501 Mon Sep 17 00:00:00 2001 From: Max Tretikov Date: Fri, 14 Jun 2024 13:40:10 -0600 Subject: [PATCH] Fix errors in comfy.extra_samplers.uni_pc --- comfy/extra_samplers/uni_pc.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/comfy/extra_samplers/uni_pc.py b/comfy/extra_samplers/uni_pc.py index a30d1d03f..157f21acf 100644 --- a/comfy/extra_samplers/uni_pc.py +++ b/comfy/extra_samplers/uni_pc.py @@ -357,6 +357,7 @@ class UniPC: predict_x0=True, thresholding=False, max_val=1., + dynamic_thresholding_ratio=0.995, variant='bh1', ): """Construct a UniPC. @@ -369,6 +370,7 @@ class UniPC: self.predict_x0 = predict_x0 self.thresholding = thresholding self.max_val = max_val + self.dynamic_thresholding_ratio = dynamic_thresholding_ratio def dynamic_thresholding_fn(self, x0, t=None): """ @@ -377,7 +379,7 @@ class UniPC: dims = x0.dim() p = self.dynamic_thresholding_ratio s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) - s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims) + s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims) x0 = torch.clamp(x0, -s, s) / s return x0 @@ -634,16 +636,18 @@ class UniPC: # now predictor use_predictor = len(D1s) > 0 and x_t is None + if len(D1s) > 0: D1s = torch.stack(D1s, dim=1) # (B, K) - if x_t is None: - # for order 2, we use a simplified version - if order == 2: - rhos_p = torch.tensor([0.5], device=b.device) - else: - rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]) else: D1s = None + + if use_predictor: + # for order 2, we use a simplified version + if order == 2: + rhos_p = torch.tensor([0.5], device=b.device) + else: + rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]) if use_corrector: # print('using corrector')