Fix errors in comfy.extra_samplers.uni_pc

This commit is contained in:
Max Tretikov 2024-06-14 13:40:10 -06:00
parent 9ad840e614
commit ee44e3b1d7

View File

@ -357,6 +357,7 @@ class UniPC:
predict_x0=True,
thresholding=False,
max_val=1.,
dynamic_thresholding_ratio=0.995,
variant='bh1',
):
"""Construct a UniPC.
@ -369,6 +370,7 @@ class UniPC:
self.predict_x0 = predict_x0
self.thresholding = thresholding
self.max_val = max_val
self.dynamic_thresholding_ratio = dynamic_thresholding_ratio
def dynamic_thresholding_fn(self, x0, t=None):
"""
@ -377,7 +379,7 @@ class UniPC:
dims = x0.dim()
p = self.dynamic_thresholding_ratio
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims)
s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
x0 = torch.clamp(x0, -s, s) / s
return x0
@ -634,16 +636,18 @@ class UniPC:
# now predictor
use_predictor = len(D1s) > 0 and x_t is None
if len(D1s) > 0:
D1s = torch.stack(D1s, dim=1) # (B, K)
if x_t is None:
# for order 2, we use a simplified version
if order == 2:
rhos_p = torch.tensor([0.5], device=b.device)
else:
rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
else:
D1s = None
if use_predictor:
# for order 2, we use a simplified version
if order == 2:
rhos_p = torch.tensor([0.5], device=b.device)
else:
rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
if use_corrector:
# print('using corrector')