Attempt to allow for multiple images from the batch

This commit is contained in:
Silversith 2023-04-19 10:41:56 +02:00
parent 946aa1e4f5
commit 0a11af628b
2 changed files with 49 additions and 40 deletions

1
.gitignore vendored
View File

@ -12,3 +12,4 @@ venv38/
.idea/
*.ckpt
*.safetensors
models/

View File

@ -114,61 +114,69 @@ class FaceRestoreWithModel:
self.face_helper = None
def restore_face(self, upscale_model, image, facedetection):
device = model_management.get_torch_device()
upscale_model.to(device)
if self.face_helper is None:
self.face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model=facedetection, save_ext='png', use_parse=True, device=device)
counter = 0
restored_img_list = []
for image_itm in image:
device = model_management.get_torch_device()
upscale_model.to(device)
if self.face_helper is None:
self.face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model=facedetection, save_ext='png', use_parse=True, device=device)
image_np = 255. * image.cpu().numpy().squeeze()
image_np = 255. * image_itm.cpu().numpy().squeeze()
image_np = image_np[:, :, ::-1]
image_np = image_np[:, :, ::-1]
original_resolution = image_np.shape[0:2]
original_resolution = image_np.shape[0:2]
if upscale_model is None or self.face_helper is None:
return image
if upscale_model is None or self.face_helper is None:
return image_itm
self.face_helper.clean_all()
self.face_helper.read_image(image_np)
self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
self.face_helper.align_warp_face()
self.face_helper.clean_all()
self.face_helper.read_image(image_np)
self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
self.face_helper.align_warp_face()
restored_face = None
for idx, cropped_face in enumerate(self.face_helper.cropped_faces):
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
restored_face = None
for idx, cropped_face in enumerate(self.face_helper.cropped_faces):
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
try:
with torch.no_grad():
#output = upscale_model(cropped_face_t, w=strength, adain=True)[0]
output = upscale_model(cropped_face_t)[0]
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
del output
torch.cuda.empty_cache()
except Exception as error:
print(f'\tFailed inference for CodeFormer: {error}', file=sys.stderr)
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
try:
with torch.no_grad():
#output = upscale_model(cropped_face_t, w=strength, adain=True)[0]
output = upscale_model(cropped_face_t)[0]
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
del output
torch.cuda.empty_cache()
except Exception as error:
print(f'\tFailed inference for CodeFormer: {error}', file=sys.stderr)
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
restored_face = restored_face.astype('uint8')
self.face_helper.add_restored_face(restored_face)
restored_face = restored_face.astype('uint8')
self.face_helper.add_restored_face(restored_face)
self.face_helper.get_inverse_affine(None)
self.face_helper.get_inverse_affine(None)
restored_img = self.face_helper.paste_faces_to_input_image()
restored_img = restored_img[:, :, ::-1]
restored_img = self.face_helper.paste_faces_to_input_image()
restored_img = restored_img[:, :, ::-1]
if original_resolution != restored_img.shape[0:2]:
restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR)
if original_resolution != restored_img.shape[0:2]:
restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR)
self.face_helper.clean_all()
self.face_helper.clean_all()
# restored_img = cv2.cvtColor(restored_face, cv2.COLOR_BGR2RGB)
# restored_img = cv2.cvtColor(restored_face, cv2.COLOR_BGR2RGB)
counter += 1
restored_img_np = np.array(restored_img).astype(np.float32) / 255.0
restored_img_tensor = torch.from_numpy(restored_img_np).unsqueeze(0)
restored_img_np = np.array(restored_img).astype(np.float32) / 255.0
restored_img_tensor = torch.from_numpy(restored_img_np).unsqueeze(0)
restored_img_list.append(restored_img_tensor)
return (restored_img_tensor,)
for i in range(len(image)):
image[i] = restored_img_list[i]
return (image,)
NODE_CLASS_MAPPINGS = {
"FaceRestoreWithModel": FaceRestoreWithModel,