From 2b85c4d3dd35cc86a6802f4bec73b71c5281dd67 Mon Sep 17 00:00:00 2001 From: franky Date: Sat, 18 Mar 2023 09:35:15 +0100 Subject: [PATCH] Created a PathServer and added a1111 path option to use models saved on a1111 directory --- main.py | 15 +++++++++- nodes.py | 78 ++++++++++++++++++------------------------------ path_server.py | 37 +++++++++++++++++++++++ requirements.txt | 2 ++ 4 files changed, 82 insertions(+), 50 deletions(-) create mode 100644 path_server.py diff --git a/main.py b/main.py index 3c03381d6..aad0a8242 100644 --- a/main.py +++ b/main.py @@ -5,6 +5,8 @@ import shutil import threading import asyncio +from path_server import PathServer + if os.name == "nt": import logging logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) @@ -83,7 +85,18 @@ if __name__ == "__main__": try: p_index = sys.argv.index('--port') port = int(sys.argv[p_index + 1]) - except: + except ValueError as e: + pass + + try: + print(f"{sys.argv = }") + a1111_index = sys.argv.index('--a1111') + print(f"{a1111_index = }") + a1111 = sys.argv[a1111_index + 1] + print(f"{a1111 = }") + PathServer().set_a1111_path(a1111) + print(f"{PathServer().get('checkpoints') = }") + except ValueError as e: pass if '--quick-test-for-ci' in sys.argv: diff --git a/nodes.py b/nodes.py index 9a878b441..b1c6f76c6 100644 --- a/nodes.py +++ b/nodes.py @@ -23,6 +23,8 @@ import comfy_extras.clip_vision import model_management import importlib +from path_server import PathServer + supported_ckpt_extensions = ['.ckpt', '.pth'] supported_pt_extensions = ['.ckpt', '.pt', '.bin', '.pth'] try: @@ -207,32 +209,25 @@ class VAEEncodeForInpaint: return ({"samples":t, "noise_mask": (mask_erosion[0][:x,:y].round())}, ) class CheckpointLoader: - models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") - config_dir = os.path.join(models_dir, "configs") - ckpt_dir = os.path.join(models_dir, "checkpoints") - embedding_directory = os.path.join(models_dir, "embeddings") - @classmethod def INPUT_TYPES(s): - return {"required": { "config_name": (filter_files_extensions(recursive_search(s.config_dir), '.yaml'), ), - "ckpt_name": (filter_files_extensions(recursive_search(s.ckpt_dir), supported_ckpt_extensions), )}} + return {"required": { "config_name": (filter_files_extensions(recursive_search(PathServer().get('configs')), '.yaml'), ), + "ckpt_name": (filter_files_extensions(recursive_search(PathServer().get('checkpoints')), supported_ckpt_extensions), )}} RETURN_TYPES = ("MODEL", "CLIP", "VAE") FUNCTION = "load_checkpoint" CATEGORY = "loaders" def load_checkpoint(self, config_name, ckpt_name, output_vae=True, output_clip=True): - config_path = os.path.join(self.config_dir, config_name) - ckpt_path = os.path.join(self.ckpt_dir, ckpt_name) - return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=self.embedding_directory) + config_path = os.path.join(PathServer().get('configs'), config_name) + ckpt_path = os.path.join(PathServer().get('checkpoints'), ckpt_name) + return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=PathServer().get('embeddings')) class CheckpointLoaderSimple: - models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") - ckpt_dir = os.path.join(models_dir, "checkpoints") - @classmethod def INPUT_TYPES(s): - return {"required": { "ckpt_name": (filter_files_extensions(recursive_search(s.ckpt_dir), supported_ckpt_extensions), ), + print(f"{PathServer().get('checkpoints') = }") + return {"required": { "ckpt_name": (filter_files_extensions(recursive_search(PathServer().get('checkpoints')), supported_ckpt_extensions), ), }} RETURN_TYPES = ("MODEL", "CLIP", "VAE") FUNCTION = "load_checkpoint" @@ -240,8 +235,9 @@ class CheckpointLoaderSimple: CATEGORY = "loaders" def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True): - ckpt_path = os.path.join(self.ckpt_dir, ckpt_name) - out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=CheckpointLoader.embedding_directory) + print(f"{PathServer().get('checkpoints') = }") + ckpt_path = os.path.join(PathServer().get('checkpoints'), ckpt_name) + out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=PathServer().get('embeddings')) return out class CLIPSetLastLayer: @@ -261,13 +257,11 @@ class CLIPSetLastLayer: return (clip,) class LoraLoader: - models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") - lora_dir = os.path.join(models_dir, "loras") @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), "clip": ("CLIP", ), - "lora_name": (filter_files_extensions(recursive_search(s.lora_dir), supported_pt_extensions), ), + "lora_name": (filter_files_extensions(recursive_search(PathServer().get('loras')), supported_pt_extensions), ), "strength_model": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), "strength_clip": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), }} @@ -277,16 +271,14 @@ class LoraLoader: CATEGORY = "loaders" def load_lora(self, model, clip, lora_name, strength_model, strength_clip): - lora_path = os.path.join(self.lora_dir, lora_name) + lora_path = os.path.join(PathServer().get('loras'), lora_name) model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip) return (model_lora, clip_lora) class VAELoader: - models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") - vae_dir = os.path.join(models_dir, "vae") @classmethod def INPUT_TYPES(s): - return {"required": { "vae_name": (filter_files_extensions(recursive_search(s.vae_dir), supported_pt_extensions), )}} + return {"required": { "vae_name": (filter_files_extensions(recursive_search(PathServer().get('vae')), supported_pt_extensions), )}} RETURN_TYPES = ("VAE",) FUNCTION = "load_vae" @@ -294,16 +286,14 @@ class VAELoader: #TODO: scale factor? def load_vae(self, vae_name): - vae_path = os.path.join(self.vae_dir, vae_name) + vae_path = os.path.join(PathServer().get('vae'), vae_name) vae = comfy.sd.VAE(ckpt_path=vae_path) return (vae,) class ControlNetLoader: - models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") - controlnet_dir = os.path.join(models_dir, "controlnet") @classmethod def INPUT_TYPES(s): - return {"required": { "control_net_name": (filter_files_extensions(recursive_search(s.controlnet_dir), supported_pt_extensions), )}} + return {"required": { "control_net_name": (filter_files_extensions(recursive_search(PathServer().get('controlnet')), supported_pt_extensions), )}} RETURN_TYPES = ("CONTROL_NET",) FUNCTION = "load_controlnet" @@ -311,17 +301,15 @@ class ControlNetLoader: CATEGORY = "loaders" def load_controlnet(self, control_net_name): - controlnet_path = os.path.join(self.controlnet_dir, control_net_name) + controlnet_path = os.path.join(PathServer().get('controlnet'), control_net_name) controlnet = comfy.sd.load_controlnet(controlnet_path) return (controlnet,) class DiffControlNetLoader: - models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") - controlnet_dir = os.path.join(models_dir, "controlnet") @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), - "control_net_name": (filter_files_extensions(recursive_search(s.controlnet_dir), supported_pt_extensions), )}} + "control_net_name": (filter_files_extensions(recursive_search(PathServer().get('controlnet')), supported_pt_extensions), )}} RETURN_TYPES = ("CONTROL_NET",) FUNCTION = "load_controlnet" @@ -329,7 +317,7 @@ class DiffControlNetLoader: CATEGORY = "loaders" def load_controlnet(self, model, control_net_name): - controlnet_path = os.path.join(self.controlnet_dir, control_net_name) + controlnet_path = os.path.join(PathServer().get('controlnet'), control_net_name) controlnet = comfy.sd.load_controlnet(controlnet_path, model) return (controlnet,) @@ -361,11 +349,9 @@ class ControlNetApply: return (c, ) class T2IAdapterLoader: - models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") - t2i_adapter_dir = os.path.join(models_dir, "t2i_adapter") @classmethod def INPUT_TYPES(s): - return {"required": { "t2i_adapter_name": (filter_files_extensions(recursive_search(s.t2i_adapter_dir), supported_pt_extensions), )}} + return {"required": { "t2i_adapter_name": (filter_files_extensions(recursive_search(PathServer().get('t2i_adapter')), supported_pt_extensions), )}} RETURN_TYPES = ("CONTROL_NET",) FUNCTION = "load_t2i_adapter" @@ -373,16 +359,14 @@ class T2IAdapterLoader: CATEGORY = "loaders" def load_t2i_adapter(self, t2i_adapter_name): - t2i_path = os.path.join(self.t2i_adapter_dir, t2i_adapter_name) + t2i_path = os.path.join(PathServer().get('t2i_adapter'), t2i_adapter_name) t2i_adapter = comfy.sd.load_t2i_adapter(t2i_path) return (t2i_adapter,) class CLIPLoader: - models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") - clip_dir = os.path.join(models_dir, "clip") @classmethod def INPUT_TYPES(s): - return {"required": { "clip_name": (filter_files_extensions(recursive_search(s.clip_dir), supported_pt_extensions), ), + return {"required": { "clip_name": (filter_files_extensions(recursive_search(PathServer().get('clip')), supported_pt_extensions), ), }} RETURN_TYPES = ("CLIP",) FUNCTION = "load_clip" @@ -390,16 +374,14 @@ class CLIPLoader: CATEGORY = "loaders" def load_clip(self, clip_name): - clip_path = os.path.join(self.clip_dir, clip_name) - clip = comfy.sd.load_clip(ckpt_path=clip_path, embedding_directory=CheckpointLoader.embedding_directory) + clip_path = os.path.join(PathServer().get('clip'), clip_name) + clip = comfy.sd.load_clip(ckpt_path=clip_path, embedding_directory=PathServer().get('embeddings')) return (clip,) class CLIPVisionLoader: - models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") - clip_dir = os.path.join(models_dir, "clip_vision") @classmethod def INPUT_TYPES(s): - return {"required": { "clip_name": (filter_files_extensions(recursive_search(s.clip_dir), supported_pt_extensions), ), + return {"required": { "clip_name": (filter_files_extensions(recursive_search(PathServer().get('clip_vision')), supported_pt_extensions), ), }} RETURN_TYPES = ("CLIP_VISION",) FUNCTION = "load_clip" @@ -407,7 +389,7 @@ class CLIPVisionLoader: CATEGORY = "loaders" def load_clip(self, clip_name): - clip_path = os.path.join(self.clip_dir, clip_name) + clip_path = os.path.join(PathServer().get('clip_vision'), clip_name) clip_vision = comfy_extras.clip_vision.load(clip_path) return (clip_vision,) @@ -427,11 +409,9 @@ class CLIPVisionEncode: return (output,) class StyleModelLoader: - models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") - style_model_dir = os.path.join(models_dir, "style_models") @classmethod def INPUT_TYPES(s): - return {"required": { "style_model_name": (filter_files_extensions(recursive_search(s.style_model_dir), supported_pt_extensions), )}} + return {"required": { "style_model_name": (filter_files_extensions(recursive_search(PathServer().get('style_models')), supported_pt_extensions), )}} RETURN_TYPES = ("STYLE_MODEL",) FUNCTION = "load_style_model" @@ -439,7 +419,7 @@ class StyleModelLoader: CATEGORY = "loaders" def load_style_model(self, style_model_name): - style_model_path = os.path.join(self.style_model_dir, style_model_name) + style_model_path = os.path.join(PathServer().get('style_models'), style_model_name) style_model = comfy.sd.load_style_model(style_model_path) return (style_model,) diff --git a/path_server.py b/path_server.py new file mode 100644 index 000000000..57c37273c --- /dev/null +++ b/path_server.py @@ -0,0 +1,37 @@ +from singleton_decorator import singleton +import os + +@singleton +class PathServer(): + + def __init__(self): + self.paths = { + 'checkpoints': os.path.join(os.path.dirname(os.path.realpath(__file__)), 'models', 'checkpoints'), + 'clip': os.path.join(os.path.dirname(os.path.realpath(__file__)), 'models', 'clip'), + 'clip_vision': os.path.join(os.path.dirname(os.path.realpath(__file__)), 'models', 'clip_vision'), + 'configs': os.path.join(os.path.dirname(os.path.realpath(__file__)), 'models', 'configs'), + 'controlnet': os.path.join(os.path.dirname(os.path.realpath(__file__)), 'models', 'controlnet'), + 'embeddings': os.path.join(os.path.dirname(os.path.realpath(__file__)), 'models', 'embeddings'), + 'loras': os.path.join(os.path.dirname(os.path.realpath(__file__)), 'models', 'loras'), + 'style_models': os.path.join(os.path.dirname(os.path.realpath(__file__)), 'models', 'style_models'), + 't2i_adapter': os.path.join(os.path.dirname(os.path.realpath(__file__)), 'models', 't2i_adapter'), + 'upscale_models': os.path.join(os.path.dirname(os.path.realpath(__file__)), 'models', 'upscale_models'), + 'vae': os.path.join(os.path.dirname(os.path.realpath(__file__)), 'models', 'vae'), + } + + def set_a1111_path(self, a1111_path): + self.paths['checkpoints'] = os.path.join(a1111_path, 'models', 'Stable-diffusion') + self.paths['clip'] = os.path.join(a1111_path, 'models', 'clip-interrogator') + # self.paths['clip_vision'] = os.path.join(a1111_path, 'models', '') + self.paths['configs'] = os.path.join(a1111_path, 'models', 'Stable-diffusion') + self.paths['controlnet'] = os.path.join(a1111_path, 'models', 'ControlNet') + self.paths['embeddings'] = os.path.join(a1111_path, 'embeddings') + self.paths['loras'] = os.path.join(a1111_path, 'models', 'Lora') + #self.paths['style_models'] = os.path.join(a1111_path, 'models', '') + #self.paths['t2i_adapter'] = os.path.join(a1111_path, 'models', '') + self.paths['upscale_models'] = os.path.join(a1111_path, 'models', 'ESRGAN') + self.paths['vae'] = os.path.join(a1111_path, 'models', 'VAE') + + def get(self, key): + return self.paths[key] + diff --git a/requirements.txt b/requirements.txt index bc8b3c558..4590af371 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,3 +9,5 @@ pytorch_lightning aiohttp accelerate pyyaml +singleton-decorator +