From 143efb5900a150dbef98ccba7e6b10cce6a7eff7 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 1 Jun 2023 11:44:40 -0500 Subject: [PATCH] Get under test --- comfy/cli_args.py | 80 ++++++++++++++++++++++++---------- config.yaml | 2 +- conftest.py | 9 ++++ extra_model_paths.yaml.example | 25 ----------- main.py | 7 ++- pytest.ini | 3 ++ requirements.txt | 1 + test/test_cli_args.py | 79 +++++++++++++++++++++++++++++++++ 8 files changed, 155 insertions(+), 51 deletions(-) create mode 100644 conftest.py delete mode 100644 extra_model_paths.yaml.example create mode 100644 pytest.ini create mode 100644 test/test_cli_args.py diff --git a/comfy/cli_args.py b/comfy/cli_args.py index d8a3ddce0..40e948fd9 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -59,6 +59,15 @@ class AbstractOptionInfo: """ pass + def validate(self, config_options, cli_args): + """ + Modifies config_options to fix inconsistencies + + Example: config sets an enum value, cli args sets another, the flags + should be removed from config_options + """ + pass + class OptionInfo(AbstractOptionInfo): def __init__(self, name, **parser_args): self.name = name @@ -94,6 +103,9 @@ class OptionInfo(AbstractOptionInfo): def convert_to_file_option(self, parser, args): return args.get(self.name.replace("-", "_"), parser.get_default(self.name)) + def validate(self, config_options, cli_args): + pass + class OptionInfoFlag(AbstractOptionInfo): def __init__(self, name, **parser_args): self.name = name @@ -127,6 +139,9 @@ class OptionInfoFlag(AbstractOptionInfo): def convert_to_file_option(self, parser, args): return args.get(self.name.replace("-", "_"), parser.get_default(self.name) or False) + def validate(self, config_options, cli_args): + pass + class OptionInfoEnum(AbstractOptionInfo): def __init__(self, name, options, help=None, empty_help=None): self.name = name @@ -182,6 +197,12 @@ class OptionInfoEnum(AbstractOptionInfo): return option.name return None + def validate(self, config_options, cli_args): + set_by_cli = any(o for o in self.options if cli_args.get(o.option_name.replace("-", "_")) is not None) + if set_by_cli: + for option in self.options: + config_options[option.option_name.replace("-", "_")] = False + class OptionInfoEnumChoice: name: str option_name: str @@ -225,6 +246,9 @@ class OptionInfoRaw: def convert_to_file_option(self, parser, args): return args.get(self.name.replace("-", "_"), {}) + def validate(self, config_options, cli_args): + pass + # # Config options # @@ -238,7 +262,7 @@ CONFIG_OPTIONS = [ help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'."), ]), ("files", [ - OptionInfoRaw("extra-model-paths-config", help="Extra paths to scan for model files."), + OptionInfoRaw("extra-model-paths", help="Extra paths to scan for model files."), OptionInfo("output-directory", type=str, default=None, help="Set the ComfyUI output directory. Leave empty to use the default."), ]), ("behavior", [ @@ -321,8 +345,7 @@ def recursive_delete_comment_attribs(d): pass class ComfyConfigLoader: - def __init__(self, config_path): - self.config_path = config_path + def __init__(self): self.option_infos = CONFIG_OPTIONS self.parser = make_config_parser(self.option_infos) @@ -338,9 +361,8 @@ class ComfyConfigLoader: return defaults - def load_from_file(self, yaml_path): - with open(yaml_path, 'r') as stream: - raw_config = yaml.load(stream) + def load_from_string(self, raw_config): + raw_config = yaml.load(raw_config) config = {} root = raw_config.get("config", {}) @@ -410,9 +432,9 @@ class ComfyConfigLoader: config[category] = d return { "config": config } - def save_config(self, args): + def save_config(self, config_path, args): options = self.convert_args_to_options(args) - with open(self.config_path, 'w') as f: + with open(config_path, 'w') as f: yaml.dump(options, f) def get_cli_arguments(self, argv): @@ -426,29 +448,41 @@ class ComfyConfigLoader: suppressed_parser = make_config_parser(self.option_infos, suppress=True) return vars(suppressed_parser.parse_args(argv)) - def parse_args(self, argv): + def parse_args_with_file(self, yaml_path, argv): + if not os.path.isfile(yaml_path): + print(f"Warning: no config file at path '{yaml_path}', creating it") + raw_config = "{}" + else: + with open(yaml_path, 'r') as stream: + raw_config = stream.read() + return self.parse_args_with_string(raw_config, argv, save_config_file=yaml_path) + + def parse_args_with_string(self, config_string, argv, save_config_file=None): defaults = self.get_arg_defaults() - if not os.path.isfile(self.config_path): - print(f"Warning: no config file at path '{self.config_path}', creating it") - config_options = {} - else: - config_options = self.load_from_file(self.config_path) - + config_options = self.load_from_string(config_string) config_options = dict(merge_dicts(defaults, config_options)) - self.save_config(config_options) + if save_config_file: + self.save_config(save_config_file, config_options) cli_args = self.get_cli_arguments(argv) + for category, options in self.option_infos: + for option in options: + option.validate(config_options, cli_args) + args = dict(merge_dicts(config_options, cli_args)) return argparse.Namespace(**args) -# -# Load config and CLI args -# +args = {} -config_loader = ComfyConfigLoader(folder_paths.default_config_path) -args = config_loader.parse_args(sys.argv[1:]) +if "pytest" not in sys.modules: + # + # Load config and CLI args + # -if args.windows_standalone_build: - args.auto_launch = True + config_loader = ComfyConfigLoader() + args = config_loader.parse_args_with_file(folder_paths.default_config_path, sys.argv[1:]) + + if args.windows_standalone_build: + args.auto_launch = True diff --git a/config.yaml b/config.yaml index 176d88d7d..fff66f10a 100644 --- a/config.yaml +++ b/config.yaml @@ -15,7 +15,7 @@ config: files: # Extra paths to scan for model files. - extra_model_paths_config: + extra_model_paths: a111: base_path: path/to/stable-diffusion-webui/ checkpoints: models/Stable-diffusion diff --git a/conftest.py b/conftest.py new file mode 100644 index 000000000..bf2f19411 --- /dev/null +++ b/conftest.py @@ -0,0 +1,9 @@ +#!/usr/bin/env python + +import sys, os + +# Make sure that the application source directory (this directory's parent) is +# on sys.path. + +here = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.insert(0, here) diff --git a/extra_model_paths.yaml.example b/extra_model_paths.yaml.example deleted file mode 100644 index fa5418a68..000000000 --- a/extra_model_paths.yaml.example +++ /dev/null @@ -1,25 +0,0 @@ -#Rename this to extra_model_paths.yaml and ComfyUI will load it - -#config for a1111 ui -#all you have to do is change the base_path to where yours is installed -a111: - base_path: path/to/stable-diffusion-webui/ - - checkpoints: models/Stable-diffusion - configs: models/Stable-diffusion - vae: models/VAE - loras: models/Lora - upscale_models: | - models/ESRGAN - models/SwinIR - embeddings: embeddings - hypernetworks: models/hypernetworks - controlnet: models/ControlNet - -#other_ui: -# base_path: path/to/ui -# checkpoints: models/checkpoints -# gligen: models/gligen -# custom_nodes: path/custom_nodes - - diff --git a/main.py b/main.py index d5bdf1cae..9f9c0ad16 100644 --- a/main.py +++ b/main.py @@ -3,10 +3,13 @@ import itertools import os import shutil import threading +import pprint from comfy.cli_args import args import comfy.utils +print("Configuration: " + str(vars(args))) + if os.name == "nt": import logging logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) @@ -82,8 +85,8 @@ if __name__ == "__main__": if os.path.isfile(extra_model_paths_config_path): load_extra_path_config_file(extra_model_paths_config_path) - if args.extra_model_paths_config: - load_extra_path_config(args.extra_model_paths_config) + if args.extra_model_paths: + load_extra_path_config(args.extra_model_paths) init_custom_nodes() server.add_routes() diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 000000000..2a3fef11c --- /dev/null +++ b/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +addopts = -ra -q +testpaths = test diff --git a/requirements.txt b/requirements.txt index c2e32b978..9b78aad27 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,3 +9,4 @@ pytorch_lightning aiohttp accelerate ruamel.yaml +pytest diff --git a/test/test_cli_args.py b/test/test_cli_args.py new file mode 100644 index 000000000..20ceddd6e --- /dev/null +++ b/test/test_cli_args.py @@ -0,0 +1,79 @@ +from comfy.cli_args import ComfyConfigLoader + + +def test_defaults(): + config = "{}" + argv = [] + + args = ComfyConfigLoader().parse_args_with_string(config, argv) + + assert args.listen == "127.0.0.1" + assert args.novram == False + + +def test_config(): + config = """ +config: + network: + listen: 0.0.0.0 +""" + argv = [] + + args = ComfyConfigLoader().parse_args_with_string(config, argv) + + assert args.listen == "0.0.0.0" + + +def test_cli_args_overrides_config(): + config = """ +config: + network: + listen: 0.0.0.0 +""" + argv = ["--listen", "192.168.1.100"] + + args = ComfyConfigLoader().parse_args_with_string(config, argv) + + assert args.listen == "192.168.1.100" + + +def test_config_enum(): + config = """ +config: + pytorch: + cross_attention: split +""" + argv = [] + + args = ComfyConfigLoader().parse_args_with_string(config, argv) + + assert args.use_split_cross_attention is True + assert args.use_pytorch_cross_attention is False + + +def test_config_enum_default(): + config = """ +config: + pytorch: + cross_attention: +""" + argv = [] + + args = ComfyConfigLoader().parse_args_with_string(config, argv) + + assert args.use_split_cross_attention is False + assert args.use_pytorch_cross_attention is False + + +def test_config_enum_exclusive(): + config = """ +config: + pytorch: + cross_attention: split +""" + argv = ["--use-pytorch-cross-attention"] + + args = ComfyConfigLoader().parse_args_with_string(config, argv) + + assert args.use_split_cross_attention is False + assert args.use_pytorch_cross_attention is True