mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-24 21:30:15 +08:00
Add basic esrgan and gfpgan support
This commit is contained in:
parent
3ddff339f5
commit
c33a948fa7
1
.gitignore
vendored
1
.gitignore
vendored
@ -4,3 +4,4 @@ output/
|
|||||||
models/checkpoints
|
models/checkpoints
|
||||||
models/vae
|
models/vae
|
||||||
models/embeddings
|
models/embeddings
|
||||||
|
gfpgan/
|
||||||
|
|||||||
87
nodes.py
87
nodes.py
@ -823,6 +823,91 @@ class ImageScale:
|
|||||||
s = s.movedim(1,-1)
|
s = s.movedim(1,-1)
|
||||||
return (s,)
|
return (s,)
|
||||||
|
|
||||||
|
class ESRGAN:
|
||||||
|
models = ["RealESRGAN_x4plus", "RealESRGAN_x4plus_anime_6B", "RealESRGAN_x2plus", "realesr-animevideov3", "realesr-general-x4v3"]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "image": ("IMAGE",),
|
||||||
|
"model": (s.models,),
|
||||||
|
"scale": ("FLOAT", {"default": 2.0, "min": 2.0, "max": 4.0, "step": 2.0}),
|
||||||
|
"face_restore": ("FACE_RESTORE_MODEL",)
|
||||||
|
}}
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
FUNCTION = "upscale"
|
||||||
|
CATEGORY = "image"
|
||||||
|
|
||||||
|
def upscale(self, image, model, scale, face_restore):
|
||||||
|
from realesrgan import RealESRGANer
|
||||||
|
|
||||||
|
net = self.get_net(model=model)
|
||||||
|
upsampler = RealESRGANer(
|
||||||
|
scale = net.scale if hasattr(net, "scale") else net.upscale,
|
||||||
|
model_path = self.get_path(model=model),
|
||||||
|
model = net
|
||||||
|
)
|
||||||
|
if face_restore is not None:
|
||||||
|
return face_restore(image, upsampler, scale)
|
||||||
|
res, _ = upsampler.enhance(255. * image[0].numpy(), outscale = scale)
|
||||||
|
return (torch.from_numpy(res.astype(np.float32) / 255.0)[None,],)
|
||||||
|
|
||||||
|
def get_net(self, model):
|
||||||
|
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
||||||
|
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||||
|
|
||||||
|
match model:
|
||||||
|
case "RealESRGAN_x4plus" | "RealESRNet_x4plus":
|
||||||
|
return RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
|
||||||
|
case "RealESRGAN_x4plus_anime_6B":
|
||||||
|
return RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
|
||||||
|
case "RealESRGAN_x2plus":
|
||||||
|
return RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
|
||||||
|
case "realesr-animevideov3":
|
||||||
|
return SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
|
||||||
|
case "realesr-general-x4v3":
|
||||||
|
return SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
|
||||||
|
|
||||||
|
def get_path(self, model):
|
||||||
|
match model:
|
||||||
|
case "RealESRGAN_x4plus":
|
||||||
|
return "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
|
||||||
|
case "RealESRNet_x4plus":
|
||||||
|
return "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth"
|
||||||
|
case "RealESRGAN_x4plus_anime_6B":
|
||||||
|
return "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth"
|
||||||
|
case "RealESRGAN_x2plus":
|
||||||
|
return "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth"
|
||||||
|
case "realesr-animevideov3":
|
||||||
|
return "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth"
|
||||||
|
case "realesr-general-x4v3":
|
||||||
|
return "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"
|
||||||
|
|
||||||
|
class GFPGAN:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "model_path": ("STRING", {"default": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth"}),
|
||||||
|
"weight": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01})
|
||||||
|
}}
|
||||||
|
RETURN_TYPES = ("FACE_RESTORE_MODEL",)
|
||||||
|
FUNCTION = "create"
|
||||||
|
CATEGORY = "image"
|
||||||
|
|
||||||
|
def create(self, model_path, weight):
|
||||||
|
return (lambda image, upscaler, scale: self.callback(image, upscaler, scale, model_path, weight),)
|
||||||
|
|
||||||
|
def callback(self, image, upscaler, scale, model_path, weight):
|
||||||
|
from gfpgan import GFPGANer
|
||||||
|
|
||||||
|
enhancer = GFPGANer(
|
||||||
|
model_path=model_path,
|
||||||
|
upscale=scale,
|
||||||
|
arch='clean',
|
||||||
|
channel_multiplier=2,
|
||||||
|
bg_upsampler=upscaler,
|
||||||
|
)
|
||||||
|
_, _, res = enhancer.enhance(255. * image[0].numpy(), paste_back=True, weight=weight)
|
||||||
|
return (torch.from_numpy(res.astype(np.float32) / 255.0)[None,],)
|
||||||
|
|
||||||
class ImageInvert:
|
class ImageInvert:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -855,6 +940,8 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"LoadImage": LoadImage,
|
"LoadImage": LoadImage,
|
||||||
"LoadImageMask": LoadImageMask,
|
"LoadImageMask": LoadImageMask,
|
||||||
"ImageScale": ImageScale,
|
"ImageScale": ImageScale,
|
||||||
|
"ESRGAN": ESRGAN,
|
||||||
|
"GFPGAN": GFPGAN,
|
||||||
"ImageInvert": ImageInvert,
|
"ImageInvert": ImageInvert,
|
||||||
"ConditioningCombine": ConditioningCombine,
|
"ConditioningCombine": ConditioningCombine,
|
||||||
"ConditioningSetArea": ConditioningSetArea,
|
"ConditioningSetArea": ConditioningSetArea,
|
||||||
|
|||||||
@ -9,3 +9,5 @@ safetensors
|
|||||||
pytorch_lightning
|
pytorch_lightning
|
||||||
aiohttp
|
aiohttp
|
||||||
accelerate
|
accelerate
|
||||||
|
realesrgan
|
||||||
|
gfpgan
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user