Handle batches correctly

This commit is contained in:
jekky 2023-03-01 22:01:11 +00:00
parent c33a948fa7
commit 4e3e46a323

View File

@ -846,10 +846,14 @@ class ESRGAN:
model_path = self.get_path(model=model),
model = net
)
if face_restore is not None:
return face_restore(image, upsampler, scale)
res, _ = upsampler.enhance(255. * image[0].numpy(), outscale = scale)
return (torch.from_numpy(res.astype(np.float32) / 255.0)[None,],)
outputs = []
for img in image:
if face_restore is not None:
res = face_restore(255. * img.numpy(), upsampler, scale)
else:
res, _ = upsampler.enhance(255. * img.numpy(), outscale = scale)
outputs.append(torch.from_numpy(res.astype(np.float32) / 255.0))
return (outputs,)
def get_net(self, model):
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
@ -905,8 +909,8 @@ class GFPGAN:
channel_multiplier=2,
bg_upsampler=upscaler,
)
_, _, res = enhancer.enhance(255. * image[0].numpy(), paste_back=True, weight=weight)
return (torch.from_numpy(res.astype(np.float32) / 255.0)[None,],)
_, _, res = enhancer.enhance(image, paste_back=True, weight=weight)
return res
class ImageInvert: