diff --git a/comfy/cli_args.py b/comfy/cli_args.py index cc4709f70..271c43431 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -1,36 +1,505 @@ +import sys +import os.path +import pprint import argparse +import ruamel.yaml +import folder_paths +import copy -parser = argparse.ArgumentParser() +yaml = ruamel.yaml.YAML() +yaml.default_flow_style = False +yaml.sort_keys = False +CM = ruamel.yaml.comments.CommentedMap -parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)") -parser.add_argument("--port", type=int, default=8188, help="Set the listen port.") -parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.") -parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.") -parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.") -parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.") -parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.") -parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.") -parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).") -parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.") +class AbstractOptionInfo: + """ + A single option that can be saved to the config file YAML. Can potentially + comprise more than one command line arg/flag like mutually exclusive flag + groups being condensed into an enum in the config file + """ -attn_group = parser.add_mutually_exclusive_group() -attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.") -attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.") + name: str + raw_output: bool -parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.") + def add_argument(self, parser, suppress): + """ + Adds an argument to the argparse parser + """ + pass -vram_group = parser.add_mutually_exclusive_group() -vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.") -vram_group.add_argument("--normalvram", action="store_true", help="Used to force normal vram use if lowvram gets automatically enabled.") -vram_group.add_argument("--lowvram", action="store_true", help="Split the unet in parts to use less vram.") -vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.") -vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).") + def get_arg_defaults(self, parser): + """ + Returns the expected argparse namespaced options as a dictionary from + querying the parser's default value + """ + pass -parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.") -parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.") -parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).") + def get_help(self): + pass -args = parser.parse_args() + def convert_to_args_array(self, value): + """ + Interprets this option and a value as a string array of argstrings -if args.windows_standalone_build: - args.auto_launch = True + If it's a flag returns ["--flag"] for True and [] for False, otherwise + can return ["--option", str(value)] or similar + + Alternatively if raw_output is True the parser will skip the parsing + step and use the value as if it were returned from parse_known_args() + """ + pass + + def convert_to_file_option(self, parser, args): + """ + Converts a portion of the args to a value to be serialized to the config + file + + As an example vram options are mutually exclusive, so it's made to look + like an enum in the config file. So having a --lowvram flag in the args + gets translated to "vram: 'lowvram'" in YAML + """ + pass + + def validate(self, config_options, cli_args): + """ + Modifies config_options to fix inconsistencies + + Example: config sets an enum value, cli args sets another, the flags + should be removed from config_options + """ + pass + +class OptionInfo(AbstractOptionInfo): + def __init__(self, name, **parser_args): + self.name = name + self.parser_args = parser_args + self.raw_output = False + + def __repr__(self): + return f'OptionInfo(\'{self.name}\', {pprint.pformat(self.parser_args)})' + + def add_argument(self, parser, suppress=False): + parser_args = dict(self.parser_args) + if suppress: + parser_args["default"] = argparse.SUPPRESS + parser.add_argument(f"--{self.name}", **parser_args) + + def get_arg_defaults(self, parser): + return { self.name: parser.get_default(self.name) } + + def get_help(self): + help = self.parser_args.get("help") + if help is None: + return None + type = self.parser_args.get("type") + if type: + help += f"\nType: {type.__name__}" + return help + + def convert_to_args_array(self, value): + if value is not None: + return [f"--{self.name}", str(value)] + return [] + + def convert_to_file_option(self, parser, args): + return args.get(self.name.replace("-", "_"), parser.get_default(self.name)) + + def validate(self, config_options, cli_args): + pass + +class OptionInfoFlag(AbstractOptionInfo): + def __init__(self, name, **parser_args): + self.name = name + self.parser_args = parser_args + self.raw_output = False + + def __repr__(self): + return f'OptionInfoFlag(\'{self.name}\', {pprint.pformat(self.parser_args)})' + + def add_argument(self, parser, suppress): + parser_args = dict(self.parser_args) + if suppress: + parser_args["default"] = argparse.SUPPRESS + parser.add_argument(f"--{self.name}", action="store_true", **parser_args) + + def get_arg_defaults(self, parser): + return { self.name.replace("-", "_"): parser.get_default(self.name) or False } + + def get_help(self): + help = self.parser_args.get("help") + if help is None: + return None + help += "\nType: bool" + return help + + def convert_to_args_array(self, value): + if value: + return [f"--{self.name}"] + return [] + + def convert_to_file_option(self, parser, args): + return args.get(self.name.replace("-", "_"), parser.get_default(self.name) or False) + + def validate(self, config_options, cli_args): + pass + +class OptionInfoEnum(AbstractOptionInfo): + def __init__(self, name, options, help=None, empty_help=None): + self.name = name + self.options = options + self.help = help + self.empty_help = empty_help + self.parser_args = {} + self.raw_output = True + + def __repr__(self): + return f'OptionInfoEnum(\'{self.name}\', {pprint.pformat(self.options)})' + + def add_argument(self, parser, suppress): + group = parser.add_mutually_exclusive_group() + default = None + if suppress: + default = argparse.SUPPRESS + for option in self.options: + group.add_argument(f"--{option.option_name}", action="store_true", help=option.help, default=default) + + def get_arg_defaults(self, parser): + result = {} + for option in self.options: + result[option.option_name.replace("-", "_")] = False + return result + + def get_help(self): + if self.help is None: + return None + help = self.help + "\nChoices:" + + help += "\n - (empty)" + if self.empty_help is not None: + help += f": {self.empty_help}" + + for option in self.options: + help += f"\n - {option.name}" + if option.help: + help += f": {option.help}" + + return help + + def convert_to_args_array(self, file_value): + affected_options = [o.option_name.replace("-", "_") for o in self.options] + for option in self.options: + if option.name == file_value: + return ({ option.option_name: True }, affected_options) + return ({}, affected_options) + + def convert_to_file_option(self, parser, args): + for option in self.options: + if args.get(option.option_name.replace("-", "_")) is True: + return option.name + return None + + def validate(self, config_options, cli_args): + set_by_cli = any(o for o in self.options if cli_args.get(o.option_name.replace("-", "_")) is not None) + if set_by_cli: + for option in self.options: + config_options[option.option_name.replace("-", "_")] = False + +class OptionInfoEnumChoice: + name: str + option_name: str + help: str + + def __init__(self, name, option_name=None, help=None): + self.name = name + if option_name is None: + option_name = self.name + self.option_name = option_name + if help is None: + help = "" + self.help = help + + def __repr__(self): + return f'OptionInfoEnumChoice(\'{self.name}\', \'{self.option_name}\', \'{self.help}\')' + +class OptionInfoRaw: + """ + Raw YAML input and output, ignores argparse entirely + """ + + def __init__(self, name, help=None, default=None): + self.name = name + self.help = help + self.default = default or {} + self.parser_args = {} + self.raw_output = True + + def add_argument(self, parser, suppress): + pass + + def get_help(self): + return self.help + + def get_arg_defaults(self, parser): + return { self.name: copy.copy(self.default) } + + def convert_to_args_array(self, value): + return { self.name: value } + + def convert_to_file_option(self, parser, args): + return args.get(self.name.replace("-", "_"), copy.copy(self.default)) + + def validate(self, config_options, cli_args): + pass + +# +# Config options +# + +DEFAULT_EXTRA_MODEL_PATHS_CONFIG = yaml.load(""" +a1111: + base_path: path/to/stable-diffusion-webui/ + checkpoints: models/Stable-diffusion + configs: models/Stable-diffusion + vae: models/VAE + loras: models/Lora + upscale_models: | + models/ESRGAN + models/SwinIR + embeddings: embeddings + hypernetworks: models/hypernetworks + controlnet: models/ControlNet +""") + +CONFIG_OPTIONS = [ + ("network", [ + OptionInfo("listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", + help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)"), + OptionInfo("port", type=int, default=8188, help="Set the listen port."), + OptionInfo("enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", + help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'."), + ]), + ("files", [ + OptionInfoRaw("extra-model-paths", help="Extra paths to scan for model files.", default=DEFAULT_EXTRA_MODEL_PATHS_CONFIG), + OptionInfo("output-directory", type=str, default=None, help="Set the ComfyUI output directory. Leave empty to use the default."), + ]), + ("behavior", [ + OptionInfoFlag("auto-launch", + help="Automatically launch ComfyUI in the default browser."), + OptionInfoFlag("dont-print-server", + help="Don't print server output."), + OptionInfoFlag("quick-test-for-ci", + help="Quick test for CI."), + OptionInfoFlag("windows-standalone-build", + help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup)."), + ]), + ("pytorch", [ + OptionInfo("cuda-device", type=int, default=None, metavar="DEVICE_ID", + help="Set the id of the cuda device this instance will use, or leave empty to autodetect."), + OptionInfoFlag("dont-upcast-attention", + help="Disable upcasting of attention. Can boost speed but increase the chances of black images."), + OptionInfoFlag("force-fp32", + help="Force fp32 (If this makes your GPU work better please report it)."), + OptionInfo("directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, + help="Use torch-directml."), + OptionInfoEnum("cross-attention", [ + OptionInfoEnumChoice("split", option_name="use-split-cross-attention", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used."), + OptionInfoEnumChoice("pytorch", option_name="use-pytorch-cross-attention", help="Use the new pytorch 2.0 cross attention function."), + ], help="Type of cross attention to use", empty_help="Don't use cross-attention."), + OptionInfoFlag("disable-xformers", + help="Disable xformers."), + OptionInfoEnum("vram", [ + OptionInfoEnumChoice("highvram", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory."), + OptionInfoEnumChoice("normalvram", help="Used to force normal vram use if lowvram gets automatically enabled."), + OptionInfoEnumChoice("lowvram", help="Split the unet in parts to use less vram."), + OptionInfoEnumChoice("novram", help="When lowvram isn't enough."), + OptionInfoEnumChoice("cpu", help="To use the CPU for everything (slow).") + ], help="Determines how VRAM is used.", empty_help="Autodetect the optional VRAM settings based on hardware.") + ]) +] + +# +# Config parser +# + +def make_config_parser(option_infos, suppress=False): + parser = argparse.ArgumentParser() + + for category, options in option_infos: + for option in options: + option.add_argument(parser, suppress) + + return parser + +def merge_dicts(dict1, dict2): + for k in set(dict1.keys()).union(dict2.keys()): + if k in dict1 and k in dict2: + if isinstance(dict1[k], dict) and isinstance(dict2[k], dict): + yield (k, dict(merge_dicts(dict1[k], dict2[k]))) + else: + # If one of the values is not a dict, you can't continue merging it. + # Value from second dict overrides one in first and we move on. + yield (k, dict2[k]) + # Alternatively, replace this with exception raiser to alert you of value conflicts + elif k in dict1: + yield (k, dict1[k]) + else: + yield (k, dict2[k]) + +def recursive_delete_comment_attribs(d): + if isinstance(d, dict): + for k, v in d.items(): + recursive_delete_comment_attribs(k) + recursive_delete_comment_attribs(v) + elif isinstance(d, list): + for elem in d: + recursive_delete_comment_attribs(elem) + try: + # literal scalarstring might have comment associated with them + attr = 'comment' if isinstance(d, ruamel.yaml.scalarstring.ScalarString) \ + else ruamel.yaml.comments.Comment.attrib + delattr(d, attr) + except AttributeError: + pass + +class ComfyConfigLoader: + def __init__(self): + self.option_infos = CONFIG_OPTIONS + self.parser = make_config_parser(self.option_infos) + + def get_arg_defaults(self): + defaults = {} + + for category, options in self.option_infos: + for option in options: + arg_defaults = option.get_arg_defaults(self.parser) + for k, v in arg_defaults.items(): + k = k.replace('-', '_') + defaults[k] = v + + return defaults + + def load_from_string(self, raw_config): + raw_config = yaml.load(raw_config) or {} + + config = {} + root = raw_config.get("config", {}) + + for category, options in self.option_infos: + if category in root: + from_file = root[category] + + known_args = [] + for k, v in from_file.items(): + kebab_k = k.replace("_", "-") + option_info = next((o for o in options if o.name == kebab_k), None) + if option_info is not None: + known_args = option_info.convert_to_args_array(v) + affected_options = [k] + if isinstance(known_args, tuple): + # Enum options can affect more than one flag in the + # CLI args, so have to check multiple items in the + # namespace argparse returns + affected_options = known_args[1] + known_args = known_args[0] + + if option_info.raw_output: + converted = {} + for k, v in known_args.items(): + converted[k.replace("-", "_")] = v + parsed = argparse.Namespace(**converted) + rest = None + else: + parsed, rest = self.parser.parse_known_args(known_args) + + parsed_vars = vars(parsed) + + # parse_known_args returns *all* options configured even + # if they're not found in the argstring. So have to pick + # out only the args affected by this option. + for ka in affected_options: + underscore_ka = ka.replace("-", "_") + item = parsed_vars.get(underscore_ka) + if item is not None: + config[ka] = item + + if rest: + print(f"Warning: unparsed args - {pprint.pformat(rest)}") + + return config + + def convert_args_to_options(self, args): + if isinstance(args, argparse.Namespace): + args = vars(args) + + # strip previous YAML comments + recursive_delete_comment_attribs(args) + + config = {} + for category, options in self.option_infos: + d = CM() + first = True + for option in options: + k = option.name.replace('-', '_') + d[k] = option.convert_to_file_option(self.parser, args) + help = option.get_help() + if help is not None: + help_string = "\n" + help + d.yaml_set_comment_before_after_key(k, help_string, indent=4) + first = False + config[category] = d + return { "config": config } + + def save_config(self, config_path, args): + options = self.convert_args_to_options(args) + with open(config_path, 'w') as f: + yaml.dump(options, f) + + def get_cli_arguments(self, argv): + # first parse regularly and exit if an error is found + self.parser.parse_args(argv) + + # now create another parser that suppresses missing arguments (not + # user-specified) such that only the arguments passed will be put in the + # namespace. Without this every argument set in the config will be + # overridden because they're all present in the argparse.Namespace + suppressed_parser = make_config_parser(self.option_infos, suppress=True) + return vars(suppressed_parser.parse_args(argv)) + + def parse_args_with_file(self, yaml_path, argv): + if not os.path.isfile(yaml_path): + print(f"Warning: no config file at path '{yaml_path}', creating it") + raw_config = "{}" + else: + with open(yaml_path, 'r') as stream: + raw_config = stream.read() + return self.parse_args_with_string(raw_config, argv, save_config_file=yaml_path) + + def parse_args_with_string(self, config_string, argv, save_config_file=None): + defaults = self.get_arg_defaults() + + config_options = self.load_from_string(config_string) + config_options = dict(merge_dicts(defaults, config_options)) + if save_config_file: + self.save_config(save_config_file, config_options) + + cli_args = self.get_cli_arguments(argv) + + for category, options in self.option_infos: + for option in options: + option.validate(config_options, cli_args) + + args = dict(merge_dicts(config_options, cli_args)) + return argparse.Namespace(**args) + +args = {} + +if "pytest" not in sys.modules: + # + # Load config and CLI args + # + + config_loader = ComfyConfigLoader() + args = config_loader.parse_args_with_file(folder_paths.default_config_path, sys.argv[1:]) + + if args.windows_standalone_build: + args.auto_launch = True diff --git a/config.yaml b/config.yaml new file mode 100644 index 000000000..6c771d99f --- /dev/null +++ b/config.yaml @@ -0,0 +1,89 @@ +config: + network: + + # Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all) + # Type: str + listen: 127.0.0.1 + + # Set the listen port. + # Type: int + port: 8188 + + # Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'. + # Type: str + enable_cors_header: + files: + + # Extra paths to scan for model files. + extra_model_paths: + a1111: + hypernetworks: models/hypernetworks + base_path: path/to/stable-diffusion-webui/ + embeddings: embeddings + controlnet: models/ControlNet + configs: models/Stable-diffusion + loras: models/Lora + vae: models/VAE + checkpoints: models/Stable-diffusion + upscale_models: | + models/ESRGAN + models/SwinIR + + # Set the ComfyUI output directory. Leave empty to use the default. + # Type: str + output_directory: + behavior: + + # Automatically launch ComfyUI in the default browser. + # Type: bool + auto_launch: false + + # Don't print server output. + # Type: bool + dont_print_server: false + + # Quick test for CI. + # Type: bool + quick_test_for_ci: false + + # Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup). + # Type: bool + windows_standalone_build: false + pytorch: + + # Set the id of the cuda device this instance will use, or leave empty to autodetect. + # Type: int + cuda_device: + + # Disable upcasting of attention. Can boost speed but increase the chances of black images. + # Type: bool + dont_upcast_attention: false + + # Force fp32 (If this makes your GPU work better please report it). + # Type: bool + force_fp32: false + + # Use torch-directml. + # Type: int + directml: + + # Type of cross attention to use + # Choices: + # - (empty): Don't use cross-attention. + # - split: Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used. + # - pytorch: Use the new pytorch 2.0 cross attention function. + cross_attention: + + # Disable xformers. + # Type: bool + disable_xformers: false + + # Determines how VRAM is used. + # Choices: + # - (empty): Autodetect the optional VRAM settings based on hardware. + # - highvram: By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory. + # - normalvram: Used to force normal vram use if lowvram gets automatically enabled. + # - lowvram: Split the unet in parts to use less vram. + # - novram: When lowvram isn't enough. + # - cpu: To use the CPU for everything (slow). + vram: diff --git a/conftest.py b/conftest.py new file mode 100644 index 000000000..bf2f19411 --- /dev/null +++ b/conftest.py @@ -0,0 +1,9 @@ +#!/usr/bin/env python + +import sys, os + +# Make sure that the application source directory (this directory's parent) is +# on sys.path. + +here = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.insert(0, here) diff --git a/extra_model_paths.yaml.example b/extra_model_paths.yaml.example deleted file mode 100644 index fa5418a68..000000000 --- a/extra_model_paths.yaml.example +++ /dev/null @@ -1,25 +0,0 @@ -#Rename this to extra_model_paths.yaml and ComfyUI will load it - -#config for a1111 ui -#all you have to do is change the base_path to where yours is installed -a111: - base_path: path/to/stable-diffusion-webui/ - - checkpoints: models/Stable-diffusion - configs: models/Stable-diffusion - vae: models/VAE - loras: models/Lora - upscale_models: | - models/ESRGAN - models/SwinIR - embeddings: embeddings - hypernetworks: models/hypernetworks - controlnet: models/ControlNet - -#other_ui: -# base_path: path/to/ui -# checkpoints: models/checkpoints -# gligen: models/gligen -# custom_nodes: path/custom_nodes - - diff --git a/folder_paths.py b/folder_paths.py index a1bf1444d..f42a040ee 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -7,6 +7,7 @@ supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors']) folder_names_and_paths = {} base_path = os.path.dirname(os.path.realpath(__file__)) +default_config_path = os.path.join(base_path, "config.yaml") models_dir = os.path.join(base_path, "models") folder_names_and_paths["checkpoints"] = ([os.path.join(models_dir, "checkpoints")], supported_ckpt_extensions) folder_names_and_paths["configs"] = ([os.path.join(models_dir, "configs")], [".yaml"]) diff --git a/main.py b/main.py index 50d3b9a62..9f9c0ad16 100644 --- a/main.py +++ b/main.py @@ -3,10 +3,13 @@ import itertools import os import shutil import threading +import pprint from comfy.cli_args import args import comfy.utils +print("Configuration: " + str(vars(args))) + if os.name == "nt": import logging logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) @@ -49,16 +52,17 @@ def cleanup_temp(): if os.path.exists(temp_dir): shutil.rmtree(temp_dir, ignore_errors=True) -def load_extra_path_config(yaml_path): +def load_extra_path_config_file(yaml_path): with open(yaml_path, 'r') as stream: config = yaml.safe_load(stream) + load_extra_path_config(config) + +def load_extra_path_config(config): for c in config: conf = config[c] if conf is None: continue - base_path = None - if "base_path" in conf: - base_path = conf.pop("base_path") + base_path = conf.get("base_path", None) for x in conf: for y in conf[x].split("\n"): if len(y) == 0: @@ -79,11 +83,10 @@ if __name__ == "__main__": extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml") if os.path.isfile(extra_model_paths_config_path): - load_extra_path_config(extra_model_paths_config_path) + load_extra_path_config_file(extra_model_paths_config_path) - if args.extra_model_paths_config: - for config_path in itertools.chain(*args.extra_model_paths_config): - load_extra_path_config(config_path) + if args.extra_model_paths: + load_extra_path_config(args.extra_model_paths) init_custom_nodes() server.add_routes() diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 000000000..2a3fef11c --- /dev/null +++ b/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +addopts = -ra -q +testpaths = test diff --git a/requirements.txt b/requirements.txt index 0527b31df..f464c9379 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,3 +9,5 @@ pytorch_lightning aiohttp accelerate pyyaml +ruamel.yaml +pytest diff --git a/test/test_cli_args.py b/test/test_cli_args.py new file mode 100644 index 000000000..1e3ac26fc --- /dev/null +++ b/test/test_cli_args.py @@ -0,0 +1,83 @@ +from comfy.cli_args import ComfyConfigLoader + + +def test_defaults(): + config = "{}" + argv = [] + + args = ComfyConfigLoader().parse_args_with_string(config, argv) + + assert args.listen == "127.0.0.1" + assert args.novram == False + + extra_paths = args.extra_model_paths.get("a1111") + assert extra_paths is not None + assert extra_paths.get("base_path") == "path/to/stable-diffusion-webui/" + + +def test_config(): + config = """ +config: + network: + listen: 0.0.0.0 +""" + argv = [] + + args = ComfyConfigLoader().parse_args_with_string(config, argv) + + assert args.listen == "0.0.0.0" + + +def test_cli_args_overrides_config(): + config = """ +config: + network: + listen: 0.0.0.0 +""" + argv = ["--listen", "192.168.1.100"] + + args = ComfyConfigLoader().parse_args_with_string(config, argv) + + assert args.listen == "192.168.1.100" + + +def test_config_enum(): + config = """ +config: + pytorch: + cross_attention: split +""" + argv = [] + + args = ComfyConfigLoader().parse_args_with_string(config, argv) + + assert args.use_split_cross_attention is True + assert args.use_pytorch_cross_attention is False + + +def test_config_enum_default(): + config = """ +config: + pytorch: + cross_attention: +""" + argv = [] + + args = ComfyConfigLoader().parse_args_with_string(config, argv) + + assert args.use_split_cross_attention is False + assert args.use_pytorch_cross_attention is False + + +def test_config_enum_exclusive(): + config = """ +config: + pytorch: + cross_attention: split +""" + argv = ["--use-pytorch-cross-attention"] + + args = ComfyConfigLoader().parse_args_with_string(config, argv) + + assert args.use_split_cross_attention is False + assert args.use_pytorch_cross_attention is True