From 9a2add62ea6e9e9c9acf45dd05f63d92f058862a Mon Sep 17 00:00:00 2001 From: ExoPhantasm Date: Wed, 22 Feb 2023 10:17:52 +0000 Subject: [PATCH] Adds basic config file support --- main.py | 6 ++++++ nodes.py | 16 +++++++++------- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/main.py b/main.py index 7c72bc4e0..cb7a9a22e 100644 --- a/main.py +++ b/main.py @@ -43,6 +43,12 @@ if '--dont-upcast-attention' in sys.argv: import torch import nodes +def config(key, default=None): + with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), "config.json"), "r", encoding="utf8") as file: + config = json.load(file) + config_item = config.get(key) if config.get(key) != "" else default + return config_item + def get_input_data(inputs, class_def, outputs={}, prompt={}, extra_data={}): valid_inputs = class_def.INPUT_TYPES() input_data_all = {} diff --git a/nodes.py b/nodes.py index 3bdad71be..46ce34020 100644 --- a/nodes.py +++ b/nodes.py @@ -21,6 +21,8 @@ import comfy.utils import model_management import importlib +from main import config + supported_ckpt_extensions = ['.ckpt', '.pth'] supported_pt_extensions = ['.ckpt', '.pt', '.bin', '.pth'] try: @@ -159,9 +161,9 @@ class VAEEncodeForInpaint: class CheckpointLoader: models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") - config_dir = os.path.join(models_dir, "configs") - ckpt_dir = os.path.join(models_dir, "checkpoints") - embedding_directory = os.path.join(models_dir, "embeddings") + config_dir = config("ckpt_cfg_path", os.path.join(models_dir, "configs")) + ckpt_dir = config("ckpt_path", os.path.join(models_dir, "checkpoints")) + embedding_directory = config("ti_path", os.path.join(models_dir, "embeddings")) @classmethod def INPUT_TYPES(s): @@ -179,7 +181,7 @@ class CheckpointLoader: class LoraLoader: models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") - lora_dir = os.path.join(models_dir, "loras") + lora_dir = config("lora_path", os.path.join(models_dir, "loras")) @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), @@ -200,7 +202,7 @@ class LoraLoader: class VAELoader: models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") - vae_dir = os.path.join(models_dir, "vae") + vae_dir = config("vae_path", os.path.join(models_dir, "vae")) @classmethod def INPUT_TYPES(s): return {"required": { "vae_name": (filter_files_extensions(recursive_search(s.vae_dir), supported_pt_extensions), )}} @@ -217,7 +219,7 @@ class VAELoader: class ControlNetLoader: models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") - controlnet_dir = os.path.join(models_dir, "controlnet") + controlnet_dir = config("controlnet_path", os.path.join(models_dir, "controlnet")) @classmethod def INPUT_TYPES(s): return {"required": { "control_net_name": (filter_files_extensions(recursive_search(s.controlnet_dir), supported_pt_extensions), )}} @@ -262,7 +264,7 @@ class ControlNetApply: class CLIPLoader: models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") - clip_dir = os.path.join(models_dir, "clip") + clip_dir = config("clip_path", os.path.join(models_dir, "clip")) @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (filter_files_extensions(recursive_search(s.clip_dir), supported_pt_extensions), ),