diff --git a/.gitignore b/.gitignore index 7961356d8..d0c92c909 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ output/ models/checkpoints models/vae models/embeddings +gfpgan/ diff --git a/nodes.py b/nodes.py index 26dad5729..da08f852b 100644 --- a/nodes.py +++ b/nodes.py @@ -823,6 +823,91 @@ class ImageScale: s = s.movedim(1,-1) return (s,) +class ESRGAN: + models = ["RealESRGAN_x4plus", "RealESRGAN_x4plus_anime_6B", "RealESRGAN_x2plus", "realesr-animevideov3", "realesr-general-x4v3"] + + @classmethod + def INPUT_TYPES(s): + return {"required": { "image": ("IMAGE",), + "model": (s.models,), + "scale": ("FLOAT", {"default": 2.0, "min": 2.0, "max": 4.0, "step": 2.0}), + "face_restore": ("FACE_RESTORE_MODEL",) + }} + RETURN_TYPES = ("IMAGE",) + FUNCTION = "upscale" + CATEGORY = "image" + + def upscale(self, image, model, scale, face_restore): + from realesrgan import RealESRGANer + + net = self.get_net(model=model) + upsampler = RealESRGANer( + scale = net.scale if hasattr(net, "scale") else net.upscale, + model_path = self.get_path(model=model), + model = net + ) + if face_restore is not None: + return face_restore(image, upsampler, scale) + res, _ = upsampler.enhance(255. * image[0].numpy(), outscale = scale) + return (torch.from_numpy(res.astype(np.float32) / 255.0)[None,],) + + def get_net(self, model): + from realesrgan.archs.srvgg_arch import SRVGGNetCompact + from basicsr.archs.rrdbnet_arch import RRDBNet + + match model: + case "RealESRGAN_x4plus" | "RealESRNet_x4plus": + return RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4) + case "RealESRGAN_x4plus_anime_6B": + return RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4) + case "RealESRGAN_x2plus": + return RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2) + case "realesr-animevideov3": + return SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu') + case "realesr-general-x4v3": + return SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu') + + def get_path(self, model): + match model: + case "RealESRGAN_x4plus": + return "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth" + case "RealESRNet_x4plus": + return "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth" + case "RealESRGAN_x4plus_anime_6B": + return "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth" + case "RealESRGAN_x2plus": + return "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth" + case "realesr-animevideov3": + return "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth" + case "realesr-general-x4v3": + return "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth" + +class GFPGAN: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model_path": ("STRING", {"default": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth"}), + "weight": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}) + }} + RETURN_TYPES = ("FACE_RESTORE_MODEL",) + FUNCTION = "create" + CATEGORY = "image" + + def create(self, model_path, weight): + return (lambda image, upscaler, scale: self.callback(image, upscaler, scale, model_path, weight),) + + def callback(self, image, upscaler, scale, model_path, weight): + from gfpgan import GFPGANer + + enhancer = GFPGANer( + model_path=model_path, + upscale=scale, + arch='clean', + channel_multiplier=2, + bg_upsampler=upscaler, + ) + _, _, res = enhancer.enhance(255. * image[0].numpy(), paste_back=True, weight=weight) + return (torch.from_numpy(res.astype(np.float32) / 255.0)[None,],) + class ImageInvert: @classmethod @@ -855,6 +940,8 @@ NODE_CLASS_MAPPINGS = { "LoadImage": LoadImage, "LoadImageMask": LoadImageMask, "ImageScale": ImageScale, + "ESRGAN": ESRGAN, + "GFPGAN": GFPGAN, "ImageInvert": ImageInvert, "ConditioningCombine": ConditioningCombine, "ConditioningSetArea": ConditioningSetArea, diff --git a/requirements.txt b/requirements.txt index 45f2599d9..017e85286 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,3 +9,5 @@ safetensors pytorch_lightning aiohttp accelerate +realesrgan +gfpgan