Adds basic config file support

This commit is contained in:
ExoPhantasm 2023-02-22 10:17:52 +00:00
parent 2976c1ad28
commit 9a2add62ea
2 changed files with 15 additions and 7 deletions

View File

@ -43,6 +43,12 @@ if '--dont-upcast-attention' in sys.argv:
import torch
import nodes
def config(key, default=None):
with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), "config.json"), "r", encoding="utf8") as file:
config = json.load(file)
config_item = config.get(key) if config.get(key) != "" else default
return config_item
def get_input_data(inputs, class_def, outputs={}, prompt={}, extra_data={}):
valid_inputs = class_def.INPUT_TYPES()
input_data_all = {}

View File

@ -21,6 +21,8 @@ import comfy.utils
import model_management
import importlib
from main import config
supported_ckpt_extensions = ['.ckpt', '.pth']
supported_pt_extensions = ['.ckpt', '.pt', '.bin', '.pth']
try:
@ -159,9 +161,9 @@ class VAEEncodeForInpaint:
class CheckpointLoader:
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
config_dir = os.path.join(models_dir, "configs")
ckpt_dir = os.path.join(models_dir, "checkpoints")
embedding_directory = os.path.join(models_dir, "embeddings")
config_dir = config("ckpt_cfg_path", os.path.join(models_dir, "configs"))
ckpt_dir = config("ckpt_path", os.path.join(models_dir, "checkpoints"))
embedding_directory = config("ti_path", os.path.join(models_dir, "embeddings"))
@classmethod
def INPUT_TYPES(s):
@ -179,7 +181,7 @@ class CheckpointLoader:
class LoraLoader:
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
lora_dir = os.path.join(models_dir, "loras")
lora_dir = config("lora_path", os.path.join(models_dir, "loras"))
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
@ -200,7 +202,7 @@ class LoraLoader:
class VAELoader:
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
vae_dir = os.path.join(models_dir, "vae")
vae_dir = config("vae_path", os.path.join(models_dir, "vae"))
@classmethod
def INPUT_TYPES(s):
return {"required": { "vae_name": (filter_files_extensions(recursive_search(s.vae_dir), supported_pt_extensions), )}}
@ -217,7 +219,7 @@ class VAELoader:
class ControlNetLoader:
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
controlnet_dir = os.path.join(models_dir, "controlnet")
controlnet_dir = config("controlnet_path", os.path.join(models_dir, "controlnet"))
@classmethod
def INPUT_TYPES(s):
return {"required": { "control_net_name": (filter_files_extensions(recursive_search(s.controlnet_dir), supported_pt_extensions), )}}
@ -262,7 +264,7 @@ class ControlNetApply:
class CLIPLoader:
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
clip_dir = os.path.join(models_dir, "clip")
clip_dir = config("clip_path", os.path.join(models_dir, "clip"))
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (filter_files_extensions(recursive_search(s.clip_dir), supported_pt_extensions), ),