Get under test

This commit is contained in:
space-nuko 2023-06-01 11:44:40 -05:00
parent 321a0e2fae
commit 143efb5900
8 changed files with 155 additions and 51 deletions

View File

@ -59,6 +59,15 @@ class AbstractOptionInfo:
"""
pass
def validate(self, config_options, cli_args):
"""
Modifies config_options to fix inconsistencies
Example: config sets an enum value, cli args sets another, the flags
should be removed from config_options
"""
pass
class OptionInfo(AbstractOptionInfo):
def __init__(self, name, **parser_args):
self.name = name
@ -94,6 +103,9 @@ class OptionInfo(AbstractOptionInfo):
def convert_to_file_option(self, parser, args):
return args.get(self.name.replace("-", "_"), parser.get_default(self.name))
def validate(self, config_options, cli_args):
pass
class OptionInfoFlag(AbstractOptionInfo):
def __init__(self, name, **parser_args):
self.name = name
@ -127,6 +139,9 @@ class OptionInfoFlag(AbstractOptionInfo):
def convert_to_file_option(self, parser, args):
return args.get(self.name.replace("-", "_"), parser.get_default(self.name) or False)
def validate(self, config_options, cli_args):
pass
class OptionInfoEnum(AbstractOptionInfo):
def __init__(self, name, options, help=None, empty_help=None):
self.name = name
@ -182,6 +197,12 @@ class OptionInfoEnum(AbstractOptionInfo):
return option.name
return None
def validate(self, config_options, cli_args):
set_by_cli = any(o for o in self.options if cli_args.get(o.option_name.replace("-", "_")) is not None)
if set_by_cli:
for option in self.options:
config_options[option.option_name.replace("-", "_")] = False
class OptionInfoEnumChoice:
name: str
option_name: str
@ -225,6 +246,9 @@ class OptionInfoRaw:
def convert_to_file_option(self, parser, args):
return args.get(self.name.replace("-", "_"), {})
def validate(self, config_options, cli_args):
pass
#
# Config options
#
@ -238,7 +262,7 @@ CONFIG_OPTIONS = [
help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'."),
]),
("files", [
OptionInfoRaw("extra-model-paths-config", help="Extra paths to scan for model files."),
OptionInfoRaw("extra-model-paths", help="Extra paths to scan for model files."),
OptionInfo("output-directory", type=str, default=None, help="Set the ComfyUI output directory. Leave empty to use the default."),
]),
("behavior", [
@ -321,8 +345,7 @@ def recursive_delete_comment_attribs(d):
pass
class ComfyConfigLoader:
def __init__(self, config_path):
self.config_path = config_path
def __init__(self):
self.option_infos = CONFIG_OPTIONS
self.parser = make_config_parser(self.option_infos)
@ -338,9 +361,8 @@ class ComfyConfigLoader:
return defaults
def load_from_file(self, yaml_path):
with open(yaml_path, 'r') as stream:
raw_config = yaml.load(stream)
def load_from_string(self, raw_config):
raw_config = yaml.load(raw_config)
config = {}
root = raw_config.get("config", {})
@ -410,9 +432,9 @@ class ComfyConfigLoader:
config[category] = d
return { "config": config }
def save_config(self, args):
def save_config(self, config_path, args):
options = self.convert_args_to_options(args)
with open(self.config_path, 'w') as f:
with open(config_path, 'w') as f:
yaml.dump(options, f)
def get_cli_arguments(self, argv):
@ -426,29 +448,41 @@ class ComfyConfigLoader:
suppressed_parser = make_config_parser(self.option_infos, suppress=True)
return vars(suppressed_parser.parse_args(argv))
def parse_args(self, argv):
def parse_args_with_file(self, yaml_path, argv):
if not os.path.isfile(yaml_path):
print(f"Warning: no config file at path '{yaml_path}', creating it")
raw_config = "{}"
else:
with open(yaml_path, 'r') as stream:
raw_config = stream.read()
return self.parse_args_with_string(raw_config, argv, save_config_file=yaml_path)
def parse_args_with_string(self, config_string, argv, save_config_file=None):
defaults = self.get_arg_defaults()
if not os.path.isfile(self.config_path):
print(f"Warning: no config file at path '{self.config_path}', creating it")
config_options = {}
else:
config_options = self.load_from_file(self.config_path)
config_options = self.load_from_string(config_string)
config_options = dict(merge_dicts(defaults, config_options))
self.save_config(config_options)
if save_config_file:
self.save_config(save_config_file, config_options)
cli_args = self.get_cli_arguments(argv)
for category, options in self.option_infos:
for option in options:
option.validate(config_options, cli_args)
args = dict(merge_dicts(config_options, cli_args))
return argparse.Namespace(**args)
#
# Load config and CLI args
#
args = {}
config_loader = ComfyConfigLoader(folder_paths.default_config_path)
args = config_loader.parse_args(sys.argv[1:])
if "pytest" not in sys.modules:
#
# Load config and CLI args
#
if args.windows_standalone_build:
args.auto_launch = True
config_loader = ComfyConfigLoader()
args = config_loader.parse_args_with_file(folder_paths.default_config_path, sys.argv[1:])
if args.windows_standalone_build:
args.auto_launch = True

View File

@ -15,7 +15,7 @@ config:
files:
# Extra paths to scan for model files.
extra_model_paths_config:
extra_model_paths:
a111:
base_path: path/to/stable-diffusion-webui/
checkpoints: models/Stable-diffusion

9
conftest.py Normal file
View File

@ -0,0 +1,9 @@
#!/usr/bin/env python
import sys, os
# Make sure that the application source directory (this directory's parent) is
# on sys.path.
here = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, here)

View File

@ -1,25 +0,0 @@
#Rename this to extra_model_paths.yaml and ComfyUI will load it
#config for a1111 ui
#all you have to do is change the base_path to where yours is installed
a111:
base_path: path/to/stable-diffusion-webui/
checkpoints: models/Stable-diffusion
configs: models/Stable-diffusion
vae: models/VAE
loras: models/Lora
upscale_models: |
models/ESRGAN
models/SwinIR
embeddings: embeddings
hypernetworks: models/hypernetworks
controlnet: models/ControlNet
#other_ui:
# base_path: path/to/ui
# checkpoints: models/checkpoints
# gligen: models/gligen
# custom_nodes: path/custom_nodes

View File

@ -3,10 +3,13 @@ import itertools
import os
import shutil
import threading
import pprint
from comfy.cli_args import args
import comfy.utils
print("Configuration: " + str(vars(args)))
if os.name == "nt":
import logging
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
@ -82,8 +85,8 @@ if __name__ == "__main__":
if os.path.isfile(extra_model_paths_config_path):
load_extra_path_config_file(extra_model_paths_config_path)
if args.extra_model_paths_config:
load_extra_path_config(args.extra_model_paths_config)
if args.extra_model_paths:
load_extra_path_config(args.extra_model_paths)
init_custom_nodes()
server.add_routes()

3
pytest.ini Normal file
View File

@ -0,0 +1,3 @@
[pytest]
addopts = -ra -q
testpaths = test

View File

@ -9,3 +9,4 @@ pytorch_lightning
aiohttp
accelerate
ruamel.yaml
pytest

79
test/test_cli_args.py Normal file
View File

@ -0,0 +1,79 @@
from comfy.cli_args import ComfyConfigLoader
def test_defaults():
config = "{}"
argv = []
args = ComfyConfigLoader().parse_args_with_string(config, argv)
assert args.listen == "127.0.0.1"
assert args.novram == False
def test_config():
config = """
config:
network:
listen: 0.0.0.0
"""
argv = []
args = ComfyConfigLoader().parse_args_with_string(config, argv)
assert args.listen == "0.0.0.0"
def test_cli_args_overrides_config():
config = """
config:
network:
listen: 0.0.0.0
"""
argv = ["--listen", "192.168.1.100"]
args = ComfyConfigLoader().parse_args_with_string(config, argv)
assert args.listen == "192.168.1.100"
def test_config_enum():
config = """
config:
pytorch:
cross_attention: split
"""
argv = []
args = ComfyConfigLoader().parse_args_with_string(config, argv)
assert args.use_split_cross_attention is True
assert args.use_pytorch_cross_attention is False
def test_config_enum_default():
config = """
config:
pytorch:
cross_attention:
"""
argv = []
args = ComfyConfigLoader().parse_args_with_string(config, argv)
assert args.use_split_cross_attention is False
assert args.use_pytorch_cross_attention is False
def test_config_enum_exclusive():
config = """
config:
pytorch:
cross_attention: split
"""
argv = ["--use-pytorch-cross-attention"]
args = ComfyConfigLoader().parse_args_with_string(config, argv)
assert args.use_split_cross_attention is False
assert args.use_pytorch_cross_attention is True