mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-24 13:20:19 +08:00
Load models locally and add a configure script
This commit is contained in:
parent
4e3e46a323
commit
22191f7faa
3
.gitignore
vendored
3
.gitignore
vendored
@ -4,4 +4,5 @@ output/
|
|||||||
models/checkpoints
|
models/checkpoints
|
||||||
models/vae
|
models/vae
|
||||||
models/embeddings
|
models/embeddings
|
||||||
gfpgan/
|
models/gfpgan
|
||||||
|
models/realesrgan
|
||||||
|
|||||||
33
configure.py
Normal file
33
configure.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
import cmd, requests, os
|
||||||
|
|
||||||
|
class ComfyConfigure(cmd.Cmd):
|
||||||
|
intro = "Welcome to ComfyUI configure shell. Type help or ? to list commands.\n"
|
||||||
|
prompt = "(configure) "
|
||||||
|
file = None
|
||||||
|
|
||||||
|
def do_install_esrgan_deps(self, arg):
|
||||||
|
'Install base ESRGAN/GFPGAN model dependencies'
|
||||||
|
self.install_model('realesrgan', 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth')
|
||||||
|
self.install_model('realesrgan', 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth')
|
||||||
|
self.install_model('gfpgan', 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth')
|
||||||
|
print('done!')
|
||||||
|
|
||||||
|
def do_exit(self, arg):
|
||||||
|
'Exit the shell'
|
||||||
|
return True
|
||||||
|
|
||||||
|
def install_model(self, category, url):
|
||||||
|
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models", category)
|
||||||
|
if not os.path.isdir(models_dir):
|
||||||
|
os.mkdir(models_dir)
|
||||||
|
|
||||||
|
print('downloading {0}...'.format(url))
|
||||||
|
resp = requests.get(url)
|
||||||
|
if resp:
|
||||||
|
with open(os.path.join(models_dir, os.path.basename(url)), "wb") as file:
|
||||||
|
file.write(resp.content)
|
||||||
|
else:
|
||||||
|
print('failed to download {0}: {1}', url, resp.text)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
ComfyConfigure().cmdloop()
|
||||||
38
nodes.py
38
nodes.py
@ -824,12 +824,12 @@ class ImageScale:
|
|||||||
return (s,)
|
return (s,)
|
||||||
|
|
||||||
class ESRGAN:
|
class ESRGAN:
|
||||||
models = ["RealESRGAN_x4plus", "RealESRGAN_x4plus_anime_6B", "RealESRGAN_x2plus", "realesr-animevideov3", "realesr-general-x4v3"]
|
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models", "realesrgan")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "image": ("IMAGE",),
|
return {"required": { "image": ("IMAGE",),
|
||||||
"model": (s.models,),
|
"model": (filter_files_extensions(recursive_search(s.models_dir), '.pth'), ),
|
||||||
"scale": ("FLOAT", {"default": 2.0, "min": 2.0, "max": 4.0, "step": 2.0}),
|
"scale": ("FLOAT", {"default": 2.0, "min": 2.0, "max": 4.0, "step": 2.0}),
|
||||||
"face_restore": ("FACE_RESTORE_MODEL",)
|
"face_restore": ("FACE_RESTORE_MODEL",)
|
||||||
}}
|
}}
|
||||||
@ -843,7 +843,7 @@ class ESRGAN:
|
|||||||
net = self.get_net(model=model)
|
net = self.get_net(model=model)
|
||||||
upsampler = RealESRGANer(
|
upsampler = RealESRGANer(
|
||||||
scale = net.scale if hasattr(net, "scale") else net.upscale,
|
scale = net.scale if hasattr(net, "scale") else net.upscale,
|
||||||
model_path = self.get_path(model=model),
|
model_path = os.path.join(self.models_dir, model),
|
||||||
model = net
|
model = net
|
||||||
)
|
)
|
||||||
outputs = []
|
outputs = []
|
||||||
@ -859,7 +859,7 @@ class ESRGAN:
|
|||||||
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
||||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||||
|
|
||||||
match model:
|
match os.path.splitext(model)[0]:
|
||||||
case "RealESRGAN_x4plus" | "RealESRNet_x4plus":
|
case "RealESRGAN_x4plus" | "RealESRNet_x4plus":
|
||||||
return RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
|
return RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
|
||||||
case "RealESRGAN_x4plus_anime_6B":
|
case "RealESRGAN_x4plus_anime_6B":
|
||||||
@ -870,40 +870,30 @@ class ESRGAN:
|
|||||||
return SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
|
return SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
|
||||||
case "realesr-general-x4v3":
|
case "realesr-general-x4v3":
|
||||||
return SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
|
return SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
|
||||||
|
case other:
|
||||||
def get_path(self, model):
|
print('Unknown model {0}, defaulting to RRDBNET...'.format(other))
|
||||||
match model:
|
return RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
|
||||||
case "RealESRGAN_x4plus":
|
|
||||||
return "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
|
|
||||||
case "RealESRNet_x4plus":
|
|
||||||
return "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth"
|
|
||||||
case "RealESRGAN_x4plus_anime_6B":
|
|
||||||
return "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth"
|
|
||||||
case "RealESRGAN_x2plus":
|
|
||||||
return "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth"
|
|
||||||
case "realesr-animevideov3":
|
|
||||||
return "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth"
|
|
||||||
case "realesr-general-x4v3":
|
|
||||||
return "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"
|
|
||||||
|
|
||||||
class GFPGAN:
|
class GFPGAN:
|
||||||
|
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models", "gfpgan")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "model_path": ("STRING", {"default": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth"}),
|
return {"required": { "model": (filter_files_extensions(recursive_search(s.models_dir), '.pth'), ),
|
||||||
"weight": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01})
|
"weight": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01})
|
||||||
}}
|
}}
|
||||||
RETURN_TYPES = ("FACE_RESTORE_MODEL",)
|
RETURN_TYPES = ("FACE_RESTORE_MODEL",)
|
||||||
FUNCTION = "create"
|
FUNCTION = "create"
|
||||||
CATEGORY = "image"
|
CATEGORY = "image"
|
||||||
|
|
||||||
def create(self, model_path, weight):
|
def create(self, model, weight):
|
||||||
return (lambda image, upscaler, scale: self.callback(image, upscaler, scale, model_path, weight),)
|
return (lambda image, upscaler, scale: self.callback(image, upscaler, scale, model, weight),)
|
||||||
|
|
||||||
def callback(self, image, upscaler, scale, model_path, weight):
|
def callback(self, image, upscaler, scale, model, weight):
|
||||||
from gfpgan import GFPGANer
|
from gfpgan import GFPGANer
|
||||||
|
|
||||||
enhancer = GFPGANer(
|
enhancer = GFPGANer(
|
||||||
model_path=model_path,
|
model_path=os.path.join(self.models_dir, model),
|
||||||
upscale=scale,
|
upscale=scale,
|
||||||
arch='clean',
|
arch='clean',
|
||||||
channel_multiplier=2,
|
channel_multiplier=2,
|
||||||
|
|||||||
@ -11,3 +11,4 @@ aiohttp
|
|||||||
accelerate
|
accelerate
|
||||||
realesrgan
|
realesrgan
|
||||||
gfpgan
|
gfpgan
|
||||||
|
requests
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user