From 0a11af628b8f6b81467181a0ee2b7f00909395cf Mon Sep 17 00:00:00 2001
From: Silversith
Date: Wed, 19 Apr 2023 10:41:56 +0200
Subject: [PATCH] Attempt to allow for multiple images from the batch

---
 .gitignore                           |  1 +
 custom_nodes/facerestore/__init__.py | 88 +++++++++++++++-------------
 2 files changed, 49 insertions(+), 40 deletions(-)

diff --git a/.gitignore b/.gitignore
index b0ba33f67..2e2616b7f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -12,3 +12,4 @@ venv38/
 .idea/
 *.ckpt
 *.safetensors
+models/
diff --git a/custom_nodes/facerestore/__init__.py b/custom_nodes/facerestore/__init__.py
index feb59b447..6557306aa 100644
--- a/custom_nodes/facerestore/__init__.py
+++ b/custom_nodes/facerestore/__init__.py
@@ -114,61 +114,69 @@ class FaceRestoreWithModel:
         self.face_helper = None
 
     def restore_face(self, upscale_model, image, facedetection):
-        device = model_management.get_torch_device()
-        upscale_model.to(device)
-        if self.face_helper is None:
-            self.face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model=facedetection, save_ext='png', use_parse=True, device=device)
+        counter = 0
+        restored_img_list = []
+        for image_itm in image:
+            device = model_management.get_torch_device()
+            upscale_model.to(device)
+            if self.face_helper is None:
+                self.face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model=facedetection, save_ext='png', use_parse=True, device=device)
 
-        image_np = 255. * image.cpu().numpy().squeeze()
+            image_np = 255. * image_itm.cpu().numpy().squeeze()
 
-        image_np = image_np[:, :, ::-1]
+            image_np = image_np[:, :, ::-1]
 
-        original_resolution = image_np.shape[0:2]
+            original_resolution = image_np.shape[0:2]
 
-        if upscale_model is None or self.face_helper is None:
-            return image
+            if upscale_model is None or self.face_helper is None:
+                return image_itm
 
-        self.face_helper.clean_all()
-        self.face_helper.read_image(image_np)
-        self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
-        self.face_helper.align_warp_face()
+            self.face_helper.clean_all()
+            self.face_helper.read_image(image_np)
+            self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
+            self.face_helper.align_warp_face()
 
-        restored_face = None
-        for idx, cropped_face in enumerate(self.face_helper.cropped_faces):
-            cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
-            normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
-            cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
+            restored_face = None
+            for idx, cropped_face in enumerate(self.face_helper.cropped_faces):
+                cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
+                normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
+                cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
 
-            try:
-                with torch.no_grad():
-                    #output = upscale_model(cropped_face_t, w=strength, adain=True)[0]
-                    output = upscale_model(cropped_face_t)[0]
-                    restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
-                del output
-                torch.cuda.empty_cache()
-            except Exception as error:
-                print(f'\tFailed inference for CodeFormer: {error}', file=sys.stderr)
-                restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
+                try:
+                    with torch.no_grad():
+                        #output = upscale_model(cropped_face_t, w=strength, adain=True)[0]
+                        output = upscale_model(cropped_face_t)[0]
+                        restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
+                    del output
+                    torch.cuda.empty_cache()
+                except Exception as error:
+                    print(f'\tFailed inference for CodeFormer: {error}', file=sys.stderr)
+                    restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
 
-            restored_face = restored_face.astype('uint8')
-            self.face_helper.add_restored_face(restored_face)
+                restored_face = restored_face.astype('uint8')
+                self.face_helper.add_restored_face(restored_face)
 
-        self.face_helper.get_inverse_affine(None)
+            self.face_helper.get_inverse_affine(None)
 
-        restored_img = self.face_helper.paste_faces_to_input_image()
-        restored_img = restored_img[:, :, ::-1]
+            restored_img = self.face_helper.paste_faces_to_input_image()
+            restored_img = restored_img[:, :, ::-1]
 
-        if original_resolution != restored_img.shape[0:2]:
-            restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR)
+            if original_resolution != restored_img.shape[0:2]:
+                restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR)
 
-        self.face_helper.clean_all()
+            self.face_helper.clean_all()
 
-        # restored_img = cv2.cvtColor(restored_face, cv2.COLOR_BGR2RGB)
+            # restored_img = cv2.cvtColor(restored_face, cv2.COLOR_BGR2RGB)
+            counter += 1
 
-        restored_img_np = np.array(restored_img).astype(np.float32) / 255.0
-        restored_img_tensor = torch.from_numpy(restored_img_np).unsqueeze(0)
+            restored_img_np = np.array(restored_img).astype(np.float32) / 255.0
+            restored_img_tensor = torch.from_numpy(restored_img_np).unsqueeze(0)
+            restored_img_list.append(restored_img_tensor)
 
-        return (restored_img_tensor,)
+        for i in range(len(image)):
+            image[i] = restored_img_list[i]
+
+        return (image,)
 
 NODE_CLASS_MAPPINGS = {
     "FaceRestoreWithModel": FaceRestoreWithModel,
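
Note on the batch handling above: each `restored_img_tensor` is built with `unsqueeze(0)`, so it has shape (1, H, W, C), while ComfyUI passes `image` as a (B, H, W, C) batch. A minimal sketch of the same collect-then-rebuild pattern, written to concatenate the per-image results rather than assign them back into the input tensor; `restore_single` here is a hypothetical stand-in for the per-image body of `restore_face`, not part of the patch:

import torch

def restore_batch(image, restore_single):
    # `image`: (B, H, W, C) float tensor, the layout ComfyUI uses for IMAGE batches.
    # `restore_single`: hypothetical helper standing in for the per-image
    # face-restoration body of the patch; assumed to return a (1, H, W, C) tensor.
    restored_img_list = []
    for image_itm in image:  # iterating a tensor walks the batch dimension
        restored_img_list.append(restore_single(image_itm))
    # Each element already carries a leading batch dim from unsqueeze(0), so
    # torch.cat along dim 0 rebuilds the (B, H, W, C) batch directly.
    return (torch.cat(restored_img_list, dim=0),)

Rebuilding the output this way sidesteps any shape mismatch between the 4-D (1, H, W, C) results and the 3-D (H, W, C) slots that the final `image[i] = restored_img_list[i]` write-back loop assigns into.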