From 3d1c6b4db1473404cf076f3dd4b838ffffdbd86b Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 1 Jun 2023 10:02:15 -0500 Subject: [PATCH] Merge model config into config.yaml --- comfy/cli_args.py | 155 ++++++++++++++++++++++++++++++++++------------ main.py | 14 ++--- requirements.txt | 2 +- 3 files changed, 123 insertions(+), 48 deletions(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 0ce159112..07c413de6 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -1,34 +1,68 @@ import os.path import pprint import argparse -import yaml +import ruamel.yaml import folder_paths +yaml = ruamel.yaml.YAML() +yaml.default_flow_style = False +yaml.sort_keys = False +CM = ruamel.yaml.comments.CommentedMap + class AbstractOptionInfo: + """ + A single option that can be saved to the config file YAML. Can potentially + comprise more than one command line arg/flag like mutually exclusive flag + groups being condensed into an enum in the config file + """ + name: str + raw_output: bool def add_argument(self, parser): + """ + Adds an argument to the argparse parser + """ pass def get_arg_defaults(self, parser): + """ + Returns the expected argparse namespaced options as a dictionary from + querying the parser's default value + """ + pass + + def get_help(self): pass def convert_to_args_array(self, value): + """ + Interprets this option and a value as a string array of argstrings + + If it's a flag returns ["--flag"] for True and [] for False, otherwise + can return ["--option", str(value)] or similar + + Alternatively if raw_output is True the parser will skip the parsing + step and use the value as if it were returned from parse_known_args() + """ pass def convert_to_file_option(self, parser, args): - pass + """ + Converts a portion of the args to a value to be serialized to the config + file - def save(self, yaml): - pass - - def load(self, yaml): + As an example vram options are mutually exclusive, so it's made to look + like an enum in the config file. So having a --lowvram flag in the args + gets translated to "vram: 'lowvram'" in YAML + """ pass class OptionInfo(AbstractOptionInfo): def __init__(self, name, **parser_args): self.name = name self.parser_args = parser_args + self.raw_output = False def __repr__(self): return f'OptionInfo(\'{self.name}\', {pprint.pformat(self.parser_args)})' @@ -39,24 +73,22 @@ class OptionInfo(AbstractOptionInfo): def get_arg_defaults(self, parser): return { self.name: parser.get_default(self.name) } + def get_help(self): + return self.parser_args.get("help") + def convert_to_args_array(self, value): if value is not None: return [f"--{self.name}", str(value)] return [] def convert_to_file_option(self, parser, args): - return args.get(self.name, parser.get_default(self.name)) - - def save(self, yaml): - pass - - def load(self, yaml): - pass + return args.get(self.name.replace("-", "_"), parser.get_default(self.name)) class OptionInfoFlag(AbstractOptionInfo): def __init__(self, name, **parser_args): self.name = name self.parser_args = parser_args + self.raw_output = False def __repr__(self): return f'OptionInfoFlag(\'{self.name}\', {pprint.pformat(self.parser_args)})' @@ -67,24 +99,24 @@ class OptionInfoFlag(AbstractOptionInfo): def get_arg_defaults(self, parser): return { self.name: parser.get_default(self.name) or False } + def get_help(self): + return self.parser_args.get("help") + def convert_to_args_array(self, value): if value: return [f"--{self.name}"] return [] def convert_to_file_option(self, parser, args): - return args.get(self.name, parser.get_default(self.name) or False) - - def save(self, yaml): - pass - - def load(self, yaml): - pass + return args.get(self.name.replace("-", "_"), parser.get_default(self.name) or False) class OptionInfoEnum(AbstractOptionInfo): - def __init__(self, name, options): + def __init__(self, name, options, help=None): self.name = name self.options = options + self.help = help + self.parser_args = {} + self.raw_output = False def __repr__(self): return f'OptionInfoEnum(\'{self.name}\', {pprint.pformat(self.options)})' @@ -97,6 +129,9 @@ class OptionInfoEnum(AbstractOptionInfo): def get_arg_defaults(self, parser): return {} # treat as no flag in the group being passed + def get_help(self): + return self.help + def convert_to_args_array(self, file_value): for option in self.options: if option.name == file_value: @@ -105,13 +140,10 @@ class OptionInfoEnum(AbstractOptionInfo): def convert_to_file_option(self, parser, args): for option in self.options: - if args.get(option.option_name) is True: + if args.get(option.option_name.replace("-", "_")) is True: return option.name return None - def load_from_yaml(self, yaml): - pass - class OptionInfoEnumChoice: name: str option_name: str @@ -129,19 +161,43 @@ class OptionInfoEnumChoice: def __repr__(self): return f'OptionInfoEnumChoice(\'{self.name}\', \'{self.option_name}\', \'{self.help}\')' +class OptionInfoRaw: + """ + Raw YAML input and output, ignores argparse entirely + """ + + def __init__(self, name, help=None): + self.name = name + self.help = help + self.parser_args = {} + self.raw_output = True + + def add_argument(self, parser): + pass + + def get_help(self): + return self.help + + def get_arg_defaults(self, parser): + return { self.name: {} } + + def convert_to_args_array(self, value): + return value + + def convert_to_file_option(self, parser, args): + return args.get(self.name.replace("-", "_"), {}) + CONFIG_OPTIONS = [ ("network", [ OptionInfo("listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", - help="Specify the IP address to listen on (default: 127.0.0.1). \ - If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)"), + help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)"), OptionInfo("port", type=int, default=8188, help="Set the listen port."), OptionInfo("enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'."), ]), ("files", [ - OptionInfo("extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', - help="Load one or more extra_model_paths.yaml files."), + OptionInfoRaw("extra-model-paths-config", help="Extra paths to scan for model files."), OptionInfo("output-directory", type=str, default=None, help="Set the ComfyUI output directory."), ]), ("behavior", [ @@ -166,7 +222,7 @@ CONFIG_OPTIONS = [ OptionInfoEnum("cross-attention", [ OptionInfoEnumChoice("split", option_name="use-split-cross-attention", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory."), OptionInfoEnumChoice("pytorch", option_name="use-pytorch-cross-attention", help="Used to force normal vram use if lowvram gets automatically enabled."), - ]), + ], help="Type of cross attention to use"), OptionInfoFlag("disable-xformers", help="Disable xformers."), OptionInfoEnum("vram", [ @@ -175,7 +231,7 @@ CONFIG_OPTIONS = [ OptionInfoEnumChoice("lowvram", help="Split the unet in parts to use less vram."), OptionInfoEnumChoice("novram", help="When lowvram isn't enough."), OptionInfoEnumChoice("cpu", help="To use the CPU for everything (slow).") - ]) + ], help="Determines how VRAM is used.") ]) ] @@ -224,7 +280,7 @@ class ComfyConfigLoader: def load_from_file(self, yaml_path): with open(yaml_path, 'r') as stream: - raw_config = yaml.safe_load(stream) + raw_config = yaml.load(stream) config = {} root = raw_config.get("config", {}) @@ -235,12 +291,20 @@ class ComfyConfigLoader: known_args = [] for k, v in from_file.items(): - option_info = next((o for o in options if o.name == k), None) + kebab_k = k.replace("_", "-") + option_info = next((o for o in options if o.name == kebab_k), None) if option_info is not None: known_args = option_info.convert_to_args_array(v) - parsed, rest = self.parser.parse_known_args(known_args) - for k, v in vars(parsed).items(): - k = k.replace("-", "_") + if option_info.raw_output: + parsed = argparse.Namespace(**{ k: known_args }) + rest = None + else: + parsed, rest = self.parser.parse_known_args(known_args) + print("---------------------") + print(option_info.name) + print(known_args) + item = vars(parsed).get(k) + if item is not None: config[k] = v if rest: @@ -249,20 +313,24 @@ class ComfyConfigLoader: return config def convert_args_to_options(self, args): + if isinstance(args, argparse.Namespace): + args = vars(args) config = {} - import pprint; pprint.pp(args) for category, options in self.option_infos: - d = {} + d = CM() for option in options: k = option.name.replace('-', '_') d[k] = option.convert_to_file_option(self.parser, args) + help = option.get_help() + if help is not None: + d.yaml_set_comment_before_after_key(k, "\n" + help, indent=4) config[category] = d return { "config": config } def save_config(self, args): options = self.convert_args_to_options(args) with open(self.config_path, 'w') as f: - yaml.dump(options, f, default_flow_style=False, sort_keys=False) + yaml.dump(options, f) def get_cli_arguments(self): return vars(self.parser.parse_args()) @@ -277,9 +345,10 @@ class ComfyConfigLoader: config_options = self.load_from_file(self.config_path) config_options = dict(merge_dicts(defaults, config_options)) - self.save_config(defaults) + self.save_config(config_options) cli_args = self.get_cli_arguments() + print(cli_args) args = dict(merge_dicts(config_options, cli_args)) return argparse.Namespace(**args) @@ -290,3 +359,9 @@ args = config_loader.parse_args() if args.windows_standalone_build: args.auto_launch = True + + +import pprint; pprint.pp(args) +import pprint; pprint.pp(config_loader.convert_args_to_options(args)) + +exit(1) diff --git a/main.py b/main.py index 50d3b9a62..d5bdf1cae 100644 --- a/main.py +++ b/main.py @@ -49,16 +49,17 @@ def cleanup_temp(): if os.path.exists(temp_dir): shutil.rmtree(temp_dir, ignore_errors=True) -def load_extra_path_config(yaml_path): +def load_extra_path_config_file(yaml_path): with open(yaml_path, 'r') as stream: config = yaml.safe_load(stream) + load_extra_path_config(config) + +def load_extra_path_config(config): for c in config: conf = config[c] if conf is None: continue - base_path = None - if "base_path" in conf: - base_path = conf.pop("base_path") + base_path = conf.get("base_path", None) for x in conf: for y in conf[x].split("\n"): if len(y) == 0: @@ -79,11 +80,10 @@ if __name__ == "__main__": extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml") if os.path.isfile(extra_model_paths_config_path): - load_extra_path_config(extra_model_paths_config_path) + load_extra_path_config_file(extra_model_paths_config_path) if args.extra_model_paths_config: - for config_path in itertools.chain(*args.extra_model_paths_config): - load_extra_path_config(config_path) + load_extra_path_config(args.extra_model_paths_config) init_custom_nodes() server.add_routes() diff --git a/requirements.txt b/requirements.txt index 0527b31df..c2e32b978 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,4 +8,4 @@ safetensors>=0.3.0 pytorch_lightning aiohttp accelerate -pyyaml +ruamel.yaml