From 4e3e46a3238569728702f5f8a9f4c4100baee57d Mon Sep 17 00:00:00 2001 From: jekky <11986158+jac3km4@users.noreply.github.com> Date: Wed, 1 Mar 2023 22:01:11 +0000 Subject: [PATCH] Handle batches correctly --- nodes.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/nodes.py b/nodes.py index da08f852b..940538d78 100644 --- a/nodes.py +++ b/nodes.py @@ -846,10 +846,14 @@ class ESRGAN: model_path = self.get_path(model=model), model = net ) - if face_restore is not None: - return face_restore(image, upsampler, scale) - res, _ = upsampler.enhance(255. * image[0].numpy(), outscale = scale) - return (torch.from_numpy(res.astype(np.float32) / 255.0)[None,],) + outputs = [] + for img in image: + if face_restore is not None: + res = face_restore(255. * img.numpy(), upsampler, scale) + else: + res, _ = upsampler.enhance(255. * img.numpy(), outscale = scale) + outputs.append(torch.from_numpy(res.astype(np.float32) / 255.0)) + return (outputs,) def get_net(self, model): from realesrgan.archs.srvgg_arch import SRVGGNetCompact @@ -905,8 +909,8 @@ class GFPGAN: channel_multiplier=2, bg_upsampler=upscaler, ) - _, _, res = enhancer.enhance(255. * image[0].numpy(), paste_back=True, weight=weight) - return (torch.from_numpy(res.astype(np.float32) / 255.0)[None,],) + _, _, res = enhancer.enhance(image, paste_back=True, weight=weight) + return res class ImageInvert: