Merge model config into config.yaml

This commit is contained in:
space-nuko 2023-06-01 10:02:15 -05:00
parent ae091ccf82
commit 3d1c6b4db1
3 changed files with 123 additions and 48 deletions

View File

@ -1,34 +1,68 @@
import os.path
import pprint
import argparse
import yaml
import ruamel.yaml
import folder_paths
yaml = ruamel.yaml.YAML()
yaml.default_flow_style = False
yaml.sort_keys = False
CM = ruamel.yaml.comments.CommentedMap
class AbstractOptionInfo:
"""
A single option that can be saved to the config file YAML. Can potentially
comprise more than one command line arg/flag like mutually exclusive flag
groups being condensed into an enum in the config file
"""
name: str
raw_output: bool
def add_argument(self, parser):
"""
Adds an argument to the argparse parser
"""
pass
def get_arg_defaults(self, parser):
"""
Returns the expected argparse namespaced options as a dictionary from
querying the parser's default value
"""
pass
def get_help(self):
pass
def convert_to_args_array(self, value):
"""
Interprets this option and a value as a string array of argstrings
If it's a flag returns ["--flag"] for True and [] for False, otherwise
can return ["--option", str(value)] or similar
Alternatively if raw_output is True the parser will skip the parsing
step and use the value as if it were returned from parse_known_args()
"""
pass
def convert_to_file_option(self, parser, args):
pass
"""
Converts a portion of the args to a value to be serialized to the config
file
def save(self, yaml):
pass
def load(self, yaml):
As an example vram options are mutually exclusive, so it's made to look
like an enum in the config file. So having a --lowvram flag in the args
gets translated to "vram: 'lowvram'" in YAML
"""
pass
class OptionInfo(AbstractOptionInfo):
def __init__(self, name, **parser_args):
self.name = name
self.parser_args = parser_args
self.raw_output = False
def __repr__(self):
return f'OptionInfo(\'{self.name}\', {pprint.pformat(self.parser_args)})'
@ -39,24 +73,22 @@ class OptionInfo(AbstractOptionInfo):
def get_arg_defaults(self, parser):
return { self.name: parser.get_default(self.name) }
def get_help(self):
return self.parser_args.get("help")
def convert_to_args_array(self, value):
if value is not None:
return [f"--{self.name}", str(value)]
return []
def convert_to_file_option(self, parser, args):
return args.get(self.name, parser.get_default(self.name))
def save(self, yaml):
pass
def load(self, yaml):
pass
return args.get(self.name.replace("-", "_"), parser.get_default(self.name))
class OptionInfoFlag(AbstractOptionInfo):
def __init__(self, name, **parser_args):
self.name = name
self.parser_args = parser_args
self.raw_output = False
def __repr__(self):
return f'OptionInfoFlag(\'{self.name}\', {pprint.pformat(self.parser_args)})'
@ -67,24 +99,24 @@ class OptionInfoFlag(AbstractOptionInfo):
def get_arg_defaults(self, parser):
return { self.name: parser.get_default(self.name) or False }
def get_help(self):
return self.parser_args.get("help")
def convert_to_args_array(self, value):
if value:
return [f"--{self.name}"]
return []
def convert_to_file_option(self, parser, args):
return args.get(self.name, parser.get_default(self.name) or False)
def save(self, yaml):
pass
def load(self, yaml):
pass
return args.get(self.name.replace("-", "_"), parser.get_default(self.name) or False)
class OptionInfoEnum(AbstractOptionInfo):
def __init__(self, name, options):
def __init__(self, name, options, help=None):
self.name = name
self.options = options
self.help = help
self.parser_args = {}
self.raw_output = False
def __repr__(self):
return f'OptionInfoEnum(\'{self.name}\', {pprint.pformat(self.options)})'
@ -97,6 +129,9 @@ class OptionInfoEnum(AbstractOptionInfo):
def get_arg_defaults(self, parser):
return {} # treat as no flag in the group being passed
def get_help(self):
return self.help
def convert_to_args_array(self, file_value):
for option in self.options:
if option.name == file_value:
@ -105,13 +140,10 @@ class OptionInfoEnum(AbstractOptionInfo):
def convert_to_file_option(self, parser, args):
for option in self.options:
if args.get(option.option_name) is True:
if args.get(option.option_name.replace("-", "_")) is True:
return option.name
return None
def load_from_yaml(self, yaml):
pass
class OptionInfoEnumChoice:
name: str
option_name: str
@ -129,19 +161,43 @@ class OptionInfoEnumChoice:
def __repr__(self):
return f'OptionInfoEnumChoice(\'{self.name}\', \'{self.option_name}\', \'{self.help}\')'
class OptionInfoRaw:
"""
Raw YAML input and output, ignores argparse entirely
"""
def __init__(self, name, help=None):
self.name = name
self.help = help
self.parser_args = {}
self.raw_output = True
def add_argument(self, parser):
pass
def get_help(self):
return self.help
def get_arg_defaults(self, parser):
return { self.name: {} }
def convert_to_args_array(self, value):
return value
def convert_to_file_option(self, parser, args):
return args.get(self.name.replace("-", "_"), {})
CONFIG_OPTIONS = [
("network", [
OptionInfo("listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0",
help="Specify the IP address to listen on (default: 127.0.0.1). \
If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)"),
help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)"),
OptionInfo("port", type=int, default=8188, help="Set the listen port."),
OptionInfo("enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*",
help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'."),
]),
("files", [
OptionInfo("extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append',
help="Load one or more extra_model_paths.yaml files."),
OptionInfoRaw("extra-model-paths-config", help="Extra paths to scan for model files."),
OptionInfo("output-directory", type=str, default=None, help="Set the ComfyUI output directory."),
]),
("behavior", [
@ -166,7 +222,7 @@ CONFIG_OPTIONS = [
OptionInfoEnum("cross-attention", [
OptionInfoEnumChoice("split", option_name="use-split-cross-attention", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory."),
OptionInfoEnumChoice("pytorch", option_name="use-pytorch-cross-attention", help="Used to force normal vram use if lowvram gets automatically enabled."),
]),
], help="Type of cross attention to use"),
OptionInfoFlag("disable-xformers",
help="Disable xformers."),
OptionInfoEnum("vram", [
@ -175,7 +231,7 @@ CONFIG_OPTIONS = [
OptionInfoEnumChoice("lowvram", help="Split the unet in parts to use less vram."),
OptionInfoEnumChoice("novram", help="When lowvram isn't enough."),
OptionInfoEnumChoice("cpu", help="To use the CPU for everything (slow).")
])
], help="Determines how VRAM is used.")
])
]
@ -224,7 +280,7 @@ class ComfyConfigLoader:
def load_from_file(self, yaml_path):
with open(yaml_path, 'r') as stream:
raw_config = yaml.safe_load(stream)
raw_config = yaml.load(stream)
config = {}
root = raw_config.get("config", {})
@ -235,12 +291,20 @@ class ComfyConfigLoader:
known_args = []
for k, v in from_file.items():
option_info = next((o for o in options if o.name == k), None)
kebab_k = k.replace("_", "-")
option_info = next((o for o in options if o.name == kebab_k), None)
if option_info is not None:
known_args = option_info.convert_to_args_array(v)
parsed, rest = self.parser.parse_known_args(known_args)
for k, v in vars(parsed).items():
k = k.replace("-", "_")
if option_info.raw_output:
parsed = argparse.Namespace(**{ k: known_args })
rest = None
else:
parsed, rest = self.parser.parse_known_args(known_args)
print("---------------------")
print(option_info.name)
print(known_args)
item = vars(parsed).get(k)
if item is not None:
config[k] = v
if rest:
@ -249,20 +313,24 @@ class ComfyConfigLoader:
return config
def convert_args_to_options(self, args):
if isinstance(args, argparse.Namespace):
args = vars(args)
config = {}
import pprint; pprint.pp(args)
for category, options in self.option_infos:
d = {}
d = CM()
for option in options:
k = option.name.replace('-', '_')
d[k] = option.convert_to_file_option(self.parser, args)
help = option.get_help()
if help is not None:
d.yaml_set_comment_before_after_key(k, "\n" + help, indent=4)
config[category] = d
return { "config": config }
def save_config(self, args):
options = self.convert_args_to_options(args)
with open(self.config_path, 'w') as f:
yaml.dump(options, f, default_flow_style=False, sort_keys=False)
yaml.dump(options, f)
def get_cli_arguments(self):
return vars(self.parser.parse_args())
@ -277,9 +345,10 @@ class ComfyConfigLoader:
config_options = self.load_from_file(self.config_path)
config_options = dict(merge_dicts(defaults, config_options))
self.save_config(defaults)
self.save_config(config_options)
cli_args = self.get_cli_arguments()
print(cli_args)
args = dict(merge_dicts(config_options, cli_args))
return argparse.Namespace(**args)
@ -290,3 +359,9 @@ args = config_loader.parse_args()
if args.windows_standalone_build:
args.auto_launch = True
import pprint; pprint.pp(args)
import pprint; pprint.pp(config_loader.convert_args_to_options(args))
exit(1)

14
main.py
View File

@ -49,16 +49,17 @@ def cleanup_temp():
if os.path.exists(temp_dir):
shutil.rmtree(temp_dir, ignore_errors=True)
def load_extra_path_config(yaml_path):
def load_extra_path_config_file(yaml_path):
with open(yaml_path, 'r') as stream:
config = yaml.safe_load(stream)
load_extra_path_config(config)
def load_extra_path_config(config):
for c in config:
conf = config[c]
if conf is None:
continue
base_path = None
if "base_path" in conf:
base_path = conf.pop("base_path")
base_path = conf.get("base_path", None)
for x in conf:
for y in conf[x].split("\n"):
if len(y) == 0:
@ -79,11 +80,10 @@ if __name__ == "__main__":
extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
if os.path.isfile(extra_model_paths_config_path):
load_extra_path_config(extra_model_paths_config_path)
load_extra_path_config_file(extra_model_paths_config_path)
if args.extra_model_paths_config:
for config_path in itertools.chain(*args.extra_model_paths_config):
load_extra_path_config(config_path)
load_extra_path_config(args.extra_model_paths_config)
init_custom_nodes()
server.add_routes()

View File

@ -8,4 +8,4 @@ safetensors>=0.3.0
pytorch_lightning
aiohttp
accelerate
pyyaml
ruamel.yaml