Load models locally and add a configure script

This commit is contained in:
jekky 2023-03-04 13:45:43 +00:00
parent 4e3e46a323
commit 22191f7faa
4 changed files with 50 additions and 25 deletions

3
.gitignore vendored
View File

@ -4,4 +4,5 @@ output/
models/checkpoints
models/vae
models/embeddings
gfpgan/
models/gfpgan
models/realesrgan

33
configure.py Normal file
View File

@ -0,0 +1,33 @@
import cmd, requests, os
class ComfyConfigure(cmd.Cmd):
intro = "Welcome to ComfyUI configure shell. Type help or ? to list commands.\n"
prompt = "(configure) "
file = None
def do_install_esrgan_deps(self, arg):
'Install base ESRGAN/GFPGAN model dependencies'
self.install_model('realesrgan', 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth')
self.install_model('realesrgan', 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth')
self.install_model('gfpgan', 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth')
print('done!')
def do_exit(self, arg):
'Exit the shell'
return True
def install_model(self, category, url):
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models", category)
if not os.path.isdir(models_dir):
os.mkdir(models_dir)
print('downloading {0}...'.format(url))
resp = requests.get(url)
if resp:
with open(os.path.join(models_dir, os.path.basename(url)), "wb") as file:
file.write(resp.content)
else:
print('failed to download {0}: {1}', url, resp.text)
if __name__ == '__main__':
ComfyConfigure().cmdloop()

View File

@ -824,12 +824,12 @@ class ImageScale:
return (s,)
class ESRGAN:
models = ["RealESRGAN_x4plus", "RealESRGAN_x4plus_anime_6B", "RealESRGAN_x2plus", "realesr-animevideov3", "realesr-general-x4v3"]
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models", "realesrgan")
@classmethod
def INPUT_TYPES(s):
return {"required": { "image": ("IMAGE",),
"model": (s.models,),
"model": (filter_files_extensions(recursive_search(s.models_dir), '.pth'), ),
"scale": ("FLOAT", {"default": 2.0, "min": 2.0, "max": 4.0, "step": 2.0}),
"face_restore": ("FACE_RESTORE_MODEL",)
}}
@ -843,7 +843,7 @@ class ESRGAN:
net = self.get_net(model=model)
upsampler = RealESRGANer(
scale = net.scale if hasattr(net, "scale") else net.upscale,
model_path = self.get_path(model=model),
model_path = os.path.join(self.models_dir, model),
model = net
)
outputs = []
@ -859,7 +859,7 @@ class ESRGAN:
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
from basicsr.archs.rrdbnet_arch import RRDBNet
match model:
match os.path.splitext(model)[0]:
case "RealESRGAN_x4plus" | "RealESRNet_x4plus":
return RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
case "RealESRGAN_x4plus_anime_6B":
@ -870,40 +870,30 @@ class ESRGAN:
return SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
case "realesr-general-x4v3":
return SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
def get_path(self, model):
match model:
case "RealESRGAN_x4plus":
return "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
case "RealESRNet_x4plus":
return "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth"
case "RealESRGAN_x4plus_anime_6B":
return "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth"
case "RealESRGAN_x2plus":
return "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth"
case "realesr-animevideov3":
return "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth"
case "realesr-general-x4v3":
return "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"
case other:
print('Unknown model {0}, defaulting to RRDBNET...'.format(other))
return RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
class GFPGAN:
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models", "gfpgan")
@classmethod
def INPUT_TYPES(s):
return {"required": { "model_path": ("STRING", {"default": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth"}),
return {"required": { "model": (filter_files_extensions(recursive_search(s.models_dir), '.pth'), ),
"weight": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01})
}}
RETURN_TYPES = ("FACE_RESTORE_MODEL",)
FUNCTION = "create"
CATEGORY = "image"
def create(self, model_path, weight):
return (lambda image, upscaler, scale: self.callback(image, upscaler, scale, model_path, weight),)
def create(self, model, weight):
return (lambda image, upscaler, scale: self.callback(image, upscaler, scale, model, weight),)
def callback(self, image, upscaler, scale, model_path, weight):
def callback(self, image, upscaler, scale, model, weight):
from gfpgan import GFPGANer
enhancer = GFPGANer(
model_path=model_path,
model_path=os.path.join(self.models_dir, model),
upscale=scale,
arch='clean',
channel_multiplier=2,

View File

@ -11,3 +11,4 @@ aiohttp
accelerate
realesrgan
gfpgan
requests