mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 06:10:50 +08:00
Add basic esrgan and gfpgan support
This commit is contained in:
parent
3ddff339f5
commit
c33a948fa7
1
.gitignore
vendored
1
.gitignore
vendored
@ -4,3 +4,4 @@ output/
|
||||
models/checkpoints
|
||||
models/vae
|
||||
models/embeddings
|
||||
gfpgan/
|
||||
|
||||
87
nodes.py
87
nodes.py
@ -823,6 +823,91 @@ class ImageScale:
|
||||
s = s.movedim(1,-1)
|
||||
return (s,)
|
||||
|
||||
class ESRGAN:
|
||||
models = ["RealESRGAN_x4plus", "RealESRGAN_x4plus_anime_6B", "RealESRGAN_x2plus", "realesr-animevideov3", "realesr-general-x4v3"]
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "image": ("IMAGE",),
|
||||
"model": (s.models,),
|
||||
"scale": ("FLOAT", {"default": 2.0, "min": 2.0, "max": 4.0, "step": 2.0}),
|
||||
"face_restore": ("FACE_RESTORE_MODEL",)
|
||||
}}
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "upscale"
|
||||
CATEGORY = "image"
|
||||
|
||||
def upscale(self, image, model, scale, face_restore):
|
||||
from realesrgan import RealESRGANer
|
||||
|
||||
net = self.get_net(model=model)
|
||||
upsampler = RealESRGANer(
|
||||
scale = net.scale if hasattr(net, "scale") else net.upscale,
|
||||
model_path = self.get_path(model=model),
|
||||
model = net
|
||||
)
|
||||
if face_restore is not None:
|
||||
return face_restore(image, upsampler, scale)
|
||||
res, _ = upsampler.enhance(255. * image[0].numpy(), outscale = scale)
|
||||
return (torch.from_numpy(res.astype(np.float32) / 255.0)[None,],)
|
||||
|
||||
def get_net(self, model):
|
||||
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
|
||||
match model:
|
||||
case "RealESRGAN_x4plus" | "RealESRNet_x4plus":
|
||||
return RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
|
||||
case "RealESRGAN_x4plus_anime_6B":
|
||||
return RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
|
||||
case "RealESRGAN_x2plus":
|
||||
return RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
|
||||
case "realesr-animevideov3":
|
||||
return SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
|
||||
case "realesr-general-x4v3":
|
||||
return SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
|
||||
|
||||
def get_path(self, model):
|
||||
match model:
|
||||
case "RealESRGAN_x4plus":
|
||||
return "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
|
||||
case "RealESRNet_x4plus":
|
||||
return "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth"
|
||||
case "RealESRGAN_x4plus_anime_6B":
|
||||
return "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth"
|
||||
case "RealESRGAN_x2plus":
|
||||
return "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth"
|
||||
case "realesr-animevideov3":
|
||||
return "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth"
|
||||
case "realesr-general-x4v3":
|
||||
return "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"
|
||||
|
||||
class GFPGAN:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model_path": ("STRING", {"default": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth"}),
|
||||
"weight": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01})
|
||||
}}
|
||||
RETURN_TYPES = ("FACE_RESTORE_MODEL",)
|
||||
FUNCTION = "create"
|
||||
CATEGORY = "image"
|
||||
|
||||
def create(self, model_path, weight):
|
||||
return (lambda image, upscaler, scale: self.callback(image, upscaler, scale, model_path, weight),)
|
||||
|
||||
def callback(self, image, upscaler, scale, model_path, weight):
|
||||
from gfpgan import GFPGANer
|
||||
|
||||
enhancer = GFPGANer(
|
||||
model_path=model_path,
|
||||
upscale=scale,
|
||||
arch='clean',
|
||||
channel_multiplier=2,
|
||||
bg_upsampler=upscaler,
|
||||
)
|
||||
_, _, res = enhancer.enhance(255. * image[0].numpy(), paste_back=True, weight=weight)
|
||||
return (torch.from_numpy(res.astype(np.float32) / 255.0)[None,],)
|
||||
|
||||
class ImageInvert:
|
||||
|
||||
@classmethod
|
||||
@ -855,6 +940,8 @@ NODE_CLASS_MAPPINGS = {
|
||||
"LoadImage": LoadImage,
|
||||
"LoadImageMask": LoadImageMask,
|
||||
"ImageScale": ImageScale,
|
||||
"ESRGAN": ESRGAN,
|
||||
"GFPGAN": GFPGAN,
|
||||
"ImageInvert": ImageInvert,
|
||||
"ConditioningCombine": ConditioningCombine,
|
||||
"ConditioningSetArea": ConditioningSetArea,
|
||||
|
||||
@ -9,3 +9,5 @@ safetensors
|
||||
pytorch_lightning
|
||||
aiohttp
|
||||
accelerate
|
||||
realesrgan
|
||||
gfpgan
|
||||
|
||||
Loading…
Reference in New Issue
Block a user