mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 06:40:48 +08:00
Fix errors in comfy.extra_samplers.uni_pc
This commit is contained in:
parent
9ad840e614
commit
ee44e3b1d7
@ -357,6 +357,7 @@ class UniPC:
|
||||
predict_x0=True,
|
||||
thresholding=False,
|
||||
max_val=1.,
|
||||
dynamic_thresholding_ratio=0.995,
|
||||
variant='bh1',
|
||||
):
|
||||
"""Construct a UniPC.
|
||||
@ -369,6 +370,7 @@ class UniPC:
|
||||
self.predict_x0 = predict_x0
|
||||
self.thresholding = thresholding
|
||||
self.max_val = max_val
|
||||
self.dynamic_thresholding_ratio = dynamic_thresholding_ratio
|
||||
|
||||
def dynamic_thresholding_fn(self, x0, t=None):
|
||||
"""
|
||||
@ -377,7 +379,7 @@ class UniPC:
|
||||
dims = x0.dim()
|
||||
p = self.dynamic_thresholding_ratio
|
||||
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
|
||||
s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims)
|
||||
s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
|
||||
x0 = torch.clamp(x0, -s, s) / s
|
||||
return x0
|
||||
|
||||
@ -634,16 +636,18 @@ class UniPC:
|
||||
|
||||
# now predictor
|
||||
use_predictor = len(D1s) > 0 and x_t is None
|
||||
|
||||
if len(D1s) > 0:
|
||||
D1s = torch.stack(D1s, dim=1) # (B, K)
|
||||
if x_t is None:
|
||||
# for order 2, we use a simplified version
|
||||
if order == 2:
|
||||
rhos_p = torch.tensor([0.5], device=b.device)
|
||||
else:
|
||||
rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
|
||||
else:
|
||||
D1s = None
|
||||
|
||||
if use_predictor:
|
||||
# for order 2, we use a simplified version
|
||||
if order == 2:
|
||||
rhos_p = torch.tensor([0.5], device=b.device)
|
||||
else:
|
||||
rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
|
||||
|
||||
if use_corrector:
|
||||
# print('using corrector')
|
||||
|
||||
Loading…
Reference in New Issue
Block a user