Load models locally and add a configure script

This commit is contained in:
jekky 2023-03-04 13:45:43 +00:00
parent 4e3e46a323
commit 22191f7faa
4 changed files with 50 additions and 25 deletions

3
.gitignore vendored
View File

@ -4,4 +4,5 @@ output/
models/checkpoints models/checkpoints
models/vae models/vae
models/embeddings models/embeddings
gfpgan/ models/gfpgan
models/realesrgan

33
configure.py Normal file
View File

@ -0,0 +1,33 @@
import cmd, requests, os
class ComfyConfigure(cmd.Cmd):
intro = "Welcome to ComfyUI configure shell. Type help or ? to list commands.\n"
prompt = "(configure) "
file = None
def do_install_esrgan_deps(self, arg):
'Install base ESRGAN/GFPGAN model dependencies'
self.install_model('realesrgan', 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth')
self.install_model('realesrgan', 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth')
self.install_model('gfpgan', 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth')
print('done!')
def do_exit(self, arg):
'Exit the shell'
return True
def install_model(self, category, url):
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models", category)
if not os.path.isdir(models_dir):
os.mkdir(models_dir)
print('downloading {0}...'.format(url))
resp = requests.get(url)
if resp:
with open(os.path.join(models_dir, os.path.basename(url)), "wb") as file:
file.write(resp.content)
else:
print('failed to download {0}: {1}', url, resp.text)
if __name__ == '__main__':
ComfyConfigure().cmdloop()

View File

@ -824,12 +824,12 @@ class ImageScale:
return (s,) return (s,)
class ESRGAN: class ESRGAN:
models = ["RealESRGAN_x4plus", "RealESRGAN_x4plus_anime_6B", "RealESRGAN_x2plus", "realesr-animevideov3", "realesr-general-x4v3"] models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models", "realesrgan")
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "image": ("IMAGE",), return {"required": { "image": ("IMAGE",),
"model": (s.models,), "model": (filter_files_extensions(recursive_search(s.models_dir), '.pth'), ),
"scale": ("FLOAT", {"default": 2.0, "min": 2.0, "max": 4.0, "step": 2.0}), "scale": ("FLOAT", {"default": 2.0, "min": 2.0, "max": 4.0, "step": 2.0}),
"face_restore": ("FACE_RESTORE_MODEL",) "face_restore": ("FACE_RESTORE_MODEL",)
}} }}
@ -843,7 +843,7 @@ class ESRGAN:
net = self.get_net(model=model) net = self.get_net(model=model)
upsampler = RealESRGANer( upsampler = RealESRGANer(
scale = net.scale if hasattr(net, "scale") else net.upscale, scale = net.scale if hasattr(net, "scale") else net.upscale,
model_path = self.get_path(model=model), model_path = os.path.join(self.models_dir, model),
model = net model = net
) )
outputs = [] outputs = []
@ -859,7 +859,7 @@ class ESRGAN:
from realesrgan.archs.srvgg_arch import SRVGGNetCompact from realesrgan.archs.srvgg_arch import SRVGGNetCompact
from basicsr.archs.rrdbnet_arch import RRDBNet from basicsr.archs.rrdbnet_arch import RRDBNet
match model: match os.path.splitext(model)[0]:
case "RealESRGAN_x4plus" | "RealESRNet_x4plus": case "RealESRGAN_x4plus" | "RealESRNet_x4plus":
return RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4) return RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
case "RealESRGAN_x4plus_anime_6B": case "RealESRGAN_x4plus_anime_6B":
@ -870,40 +870,30 @@ class ESRGAN:
return SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu') return SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
case "realesr-general-x4v3": case "realesr-general-x4v3":
return SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu') return SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
case other:
def get_path(self, model): print('Unknown model {0}, defaulting to RRDBNET...'.format(other))
match model: return RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
case "RealESRGAN_x4plus":
return "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
case "RealESRNet_x4plus":
return "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth"
case "RealESRGAN_x4plus_anime_6B":
return "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth"
case "RealESRGAN_x2plus":
return "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth"
case "realesr-animevideov3":
return "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth"
case "realesr-general-x4v3":
return "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"
class GFPGAN: class GFPGAN:
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models", "gfpgan")
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "model_path": ("STRING", {"default": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth"}), return {"required": { "model": (filter_files_extensions(recursive_search(s.models_dir), '.pth'), ),
"weight": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}) "weight": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01})
}} }}
RETURN_TYPES = ("FACE_RESTORE_MODEL",) RETURN_TYPES = ("FACE_RESTORE_MODEL",)
FUNCTION = "create" FUNCTION = "create"
CATEGORY = "image" CATEGORY = "image"
def create(self, model_path, weight): def create(self, model, weight):
return (lambda image, upscaler, scale: self.callback(image, upscaler, scale, model_path, weight),) return (lambda image, upscaler, scale: self.callback(image, upscaler, scale, model, weight),)
def callback(self, image, upscaler, scale, model_path, weight): def callback(self, image, upscaler, scale, model, weight):
from gfpgan import GFPGANer from gfpgan import GFPGANer
enhancer = GFPGANer( enhancer = GFPGANer(
model_path=model_path, model_path=os.path.join(self.models_dir, model),
upscale=scale, upscale=scale,
arch='clean', arch='clean',
channel_multiplier=2, channel_multiplier=2,

View File

@ -11,3 +11,4 @@ aiohttp
accelerate accelerate
realesrgan realesrgan
gfpgan gfpgan
requests