Add basic esrgan and gfpgan support

This commit is contained in:
jekky 2023-03-01 21:48:05 +00:00
parent 3ddff339f5
commit c33a948fa7
3 changed files with 90 additions and 0 deletions

1
.gitignore vendored
View File

@ -4,3 +4,4 @@ output/
models/checkpoints
models/vae
models/embeddings
gfpgan/

View File

@ -823,6 +823,91 @@ class ImageScale:
s = s.movedim(1,-1)
return (s,)
class ESRGAN:
models = ["RealESRGAN_x4plus", "RealESRGAN_x4plus_anime_6B", "RealESRGAN_x2plus", "realesr-animevideov3", "realesr-general-x4v3"]
@classmethod
def INPUT_TYPES(s):
return {"required": { "image": ("IMAGE",),
"model": (s.models,),
"scale": ("FLOAT", {"default": 2.0, "min": 2.0, "max": 4.0, "step": 2.0}),
"face_restore": ("FACE_RESTORE_MODEL",)
}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "upscale"
CATEGORY = "image"
def upscale(self, image, model, scale, face_restore):
from realesrgan import RealESRGANer
net = self.get_net(model=model)
upsampler = RealESRGANer(
scale = net.scale if hasattr(net, "scale") else net.upscale,
model_path = self.get_path(model=model),
model = net
)
if face_restore is not None:
return face_restore(image, upsampler, scale)
res, _ = upsampler.enhance(255. * image[0].numpy(), outscale = scale)
return (torch.from_numpy(res.astype(np.float32) / 255.0)[None,],)
def get_net(self, model):
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
from basicsr.archs.rrdbnet_arch import RRDBNet
match model:
case "RealESRGAN_x4plus" | "RealESRNet_x4plus":
return RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
case "RealESRGAN_x4plus_anime_6B":
return RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
case "RealESRGAN_x2plus":
return RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
case "realesr-animevideov3":
return SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
case "realesr-general-x4v3":
return SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
def get_path(self, model):
match model:
case "RealESRGAN_x4plus":
return "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
case "RealESRNet_x4plus":
return "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth"
case "RealESRGAN_x4plus_anime_6B":
return "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth"
case "RealESRGAN_x2plus":
return "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth"
case "realesr-animevideov3":
return "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth"
case "realesr-general-x4v3":
return "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"
class GFPGAN:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model_path": ("STRING", {"default": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth"}),
"weight": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01})
}}
RETURN_TYPES = ("FACE_RESTORE_MODEL",)
FUNCTION = "create"
CATEGORY = "image"
def create(self, model_path, weight):
return (lambda image, upscaler, scale: self.callback(image, upscaler, scale, model_path, weight),)
def callback(self, image, upscaler, scale, model_path, weight):
from gfpgan import GFPGANer
enhancer = GFPGANer(
model_path=model_path,
upscale=scale,
arch='clean',
channel_multiplier=2,
bg_upsampler=upscaler,
)
_, _, res = enhancer.enhance(255. * image[0].numpy(), paste_back=True, weight=weight)
return (torch.from_numpy(res.astype(np.float32) / 255.0)[None,],)
class ImageInvert:
@classmethod
@ -855,6 +940,8 @@ NODE_CLASS_MAPPINGS = {
"LoadImage": LoadImage,
"LoadImageMask": LoadImageMask,
"ImageScale": ImageScale,
"ESRGAN": ESRGAN,
"GFPGAN": GFPGAN,
"ImageInvert": ImageInvert,
"ConditioningCombine": ConditioningCombine,
"ConditioningSetArea": ConditioningSetArea,

View File

@ -9,3 +9,5 @@ safetensors
pytorch_lightning
aiohttp
accelerate
realesrgan
gfpgan