From 22191f7faab3af1e09ad44f5e73a111b471ed182 Mon Sep 17 00:00:00 2001 From: jekky <11986158+jac3km4@users.noreply.github.com> Date: Sat, 4 Mar 2023 13:45:43 +0000 Subject: [PATCH] Load models locally and add a configure script --- .gitignore | 3 ++- configure.py | 33 +++++++++++++++++++++++++++++++++ nodes.py | 38 ++++++++++++++------------------------ requirements.txt | 1 + 4 files changed, 50 insertions(+), 25 deletions(-) create mode 100644 configure.py diff --git a/.gitignore b/.gitignore index d0c92c909..df0318285 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,5 @@ output/ models/checkpoints models/vae models/embeddings -gfpgan/ +models/gfpgan +models/realesrgan diff --git a/configure.py b/configure.py new file mode 100644 index 000000000..d9acd188e --- /dev/null +++ b/configure.py @@ -0,0 +1,33 @@ +import cmd, requests, os + +class ComfyConfigure(cmd.Cmd): + intro = "Welcome to ComfyUI configure shell. Type help or ? to list commands.\n" + prompt = "(configure) " + file = None + + def do_install_esrgan_deps(self, arg): + 'Install base ESRGAN/GFPGAN model dependencies' + self.install_model('realesrgan', 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth') + self.install_model('realesrgan', 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth') + self.install_model('gfpgan', 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth') + print('done!') + + def do_exit(self, arg): + 'Exit the shell' + return True + + def install_model(self, category, url): + models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models", category) + if not os.path.isdir(models_dir): + os.mkdir(models_dir) + + print('downloading {0}...'.format(url)) + resp = requests.get(url) + if resp: + with open(os.path.join(models_dir, os.path.basename(url)), "wb") as file: + file.write(resp.content) + else: + print('failed to download {0}: {1}', url, resp.text) + +if __name__ == '__main__': + ComfyConfigure().cmdloop() diff --git a/nodes.py b/nodes.py index 940538d78..41c1de5f2 100644 --- a/nodes.py +++ b/nodes.py @@ -824,12 +824,12 @@ class ImageScale: return (s,) class ESRGAN: - models = ["RealESRGAN_x4plus", "RealESRGAN_x4plus_anime_6B", "RealESRGAN_x2plus", "realesr-animevideov3", "realesr-general-x4v3"] + models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models", "realesrgan") @classmethod def INPUT_TYPES(s): return {"required": { "image": ("IMAGE",), - "model": (s.models,), + "model": (filter_files_extensions(recursive_search(s.models_dir), '.pth'), ), "scale": ("FLOAT", {"default": 2.0, "min": 2.0, "max": 4.0, "step": 2.0}), "face_restore": ("FACE_RESTORE_MODEL",) }} @@ -843,7 +843,7 @@ class ESRGAN: net = self.get_net(model=model) upsampler = RealESRGANer( scale = net.scale if hasattr(net, "scale") else net.upscale, - model_path = self.get_path(model=model), + model_path = os.path.join(self.models_dir, model), model = net ) outputs = [] @@ -859,7 +859,7 @@ class ESRGAN: from realesrgan.archs.srvgg_arch import SRVGGNetCompact from basicsr.archs.rrdbnet_arch import RRDBNet - match model: + match os.path.splitext(model)[0]: case "RealESRGAN_x4plus" | "RealESRNet_x4plus": return RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4) case "RealESRGAN_x4plus_anime_6B": @@ -870,40 +870,30 @@ class ESRGAN: return SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu') case "realesr-general-x4v3": return SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu') - - def get_path(self, model): - match model: - case "RealESRGAN_x4plus": - return "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth" - case "RealESRNet_x4plus": - return "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth" - case "RealESRGAN_x4plus_anime_6B": - return "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth" - case "RealESRGAN_x2plus": - return "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth" - case "realesr-animevideov3": - return "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth" - case "realesr-general-x4v3": - return "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth" + case other: + print('Unknown model {0}, defaulting to RRDBNET...'.format(other)) + return RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4) class GFPGAN: + models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models", "gfpgan") + @classmethod def INPUT_TYPES(s): - return {"required": { "model_path": ("STRING", {"default": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth"}), + return {"required": { "model": (filter_files_extensions(recursive_search(s.models_dir), '.pth'), ), "weight": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}) }} RETURN_TYPES = ("FACE_RESTORE_MODEL",) FUNCTION = "create" CATEGORY = "image" - def create(self, model_path, weight): - return (lambda image, upscaler, scale: self.callback(image, upscaler, scale, model_path, weight),) + def create(self, model, weight): + return (lambda image, upscaler, scale: self.callback(image, upscaler, scale, model, weight),) - def callback(self, image, upscaler, scale, model_path, weight): + def callback(self, image, upscaler, scale, model, weight): from gfpgan import GFPGANer enhancer = GFPGANer( - model_path=model_path, + model_path=os.path.join(self.models_dir, model), upscale=scale, arch='clean', channel_multiplier=2, diff --git a/requirements.txt b/requirements.txt index 017e85286..f1ea393ee 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,3 +11,4 @@ aiohttp accelerate realesrgan gfpgan +requests