mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-12 07:10:52 +08:00
Get under test
This commit is contained in:
parent
321a0e2fae
commit
143efb5900
@ -59,6 +59,15 @@ class AbstractOptionInfo:
|
||||
"""
|
||||
pass
|
||||
|
||||
def validate(self, config_options, cli_args):
|
||||
"""
|
||||
Modifies config_options to fix inconsistencies
|
||||
|
||||
Example: config sets an enum value, cli args sets another, the flags
|
||||
should be removed from config_options
|
||||
"""
|
||||
pass
|
||||
|
||||
class OptionInfo(AbstractOptionInfo):
|
||||
def __init__(self, name, **parser_args):
|
||||
self.name = name
|
||||
@ -94,6 +103,9 @@ class OptionInfo(AbstractOptionInfo):
|
||||
def convert_to_file_option(self, parser, args):
|
||||
return args.get(self.name.replace("-", "_"), parser.get_default(self.name))
|
||||
|
||||
def validate(self, config_options, cli_args):
|
||||
pass
|
||||
|
||||
class OptionInfoFlag(AbstractOptionInfo):
|
||||
def __init__(self, name, **parser_args):
|
||||
self.name = name
|
||||
@ -127,6 +139,9 @@ class OptionInfoFlag(AbstractOptionInfo):
|
||||
def convert_to_file_option(self, parser, args):
|
||||
return args.get(self.name.replace("-", "_"), parser.get_default(self.name) or False)
|
||||
|
||||
def validate(self, config_options, cli_args):
|
||||
pass
|
||||
|
||||
class OptionInfoEnum(AbstractOptionInfo):
|
||||
def __init__(self, name, options, help=None, empty_help=None):
|
||||
self.name = name
|
||||
@ -182,6 +197,12 @@ class OptionInfoEnum(AbstractOptionInfo):
|
||||
return option.name
|
||||
return None
|
||||
|
||||
def validate(self, config_options, cli_args):
|
||||
set_by_cli = any(o for o in self.options if cli_args.get(o.option_name.replace("-", "_")) is not None)
|
||||
if set_by_cli:
|
||||
for option in self.options:
|
||||
config_options[option.option_name.replace("-", "_")] = False
|
||||
|
||||
class OptionInfoEnumChoice:
|
||||
name: str
|
||||
option_name: str
|
||||
@ -225,6 +246,9 @@ class OptionInfoRaw:
|
||||
def convert_to_file_option(self, parser, args):
|
||||
return args.get(self.name.replace("-", "_"), {})
|
||||
|
||||
def validate(self, config_options, cli_args):
|
||||
pass
|
||||
|
||||
#
|
||||
# Config options
|
||||
#
|
||||
@ -238,7 +262,7 @@ CONFIG_OPTIONS = [
|
||||
help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'."),
|
||||
]),
|
||||
("files", [
|
||||
OptionInfoRaw("extra-model-paths-config", help="Extra paths to scan for model files."),
|
||||
OptionInfoRaw("extra-model-paths", help="Extra paths to scan for model files."),
|
||||
OptionInfo("output-directory", type=str, default=None, help="Set the ComfyUI output directory. Leave empty to use the default."),
|
||||
]),
|
||||
("behavior", [
|
||||
@ -321,8 +345,7 @@ def recursive_delete_comment_attribs(d):
|
||||
pass
|
||||
|
||||
class ComfyConfigLoader:
|
||||
def __init__(self, config_path):
|
||||
self.config_path = config_path
|
||||
def __init__(self):
|
||||
self.option_infos = CONFIG_OPTIONS
|
||||
self.parser = make_config_parser(self.option_infos)
|
||||
|
||||
@ -338,9 +361,8 @@ class ComfyConfigLoader:
|
||||
|
||||
return defaults
|
||||
|
||||
def load_from_file(self, yaml_path):
|
||||
with open(yaml_path, 'r') as stream:
|
||||
raw_config = yaml.load(stream)
|
||||
def load_from_string(self, raw_config):
|
||||
raw_config = yaml.load(raw_config)
|
||||
|
||||
config = {}
|
||||
root = raw_config.get("config", {})
|
||||
@ -410,9 +432,9 @@ class ComfyConfigLoader:
|
||||
config[category] = d
|
||||
return { "config": config }
|
||||
|
||||
def save_config(self, args):
|
||||
def save_config(self, config_path, args):
|
||||
options = self.convert_args_to_options(args)
|
||||
with open(self.config_path, 'w') as f:
|
||||
with open(config_path, 'w') as f:
|
||||
yaml.dump(options, f)
|
||||
|
||||
def get_cli_arguments(self, argv):
|
||||
@ -426,29 +448,41 @@ class ComfyConfigLoader:
|
||||
suppressed_parser = make_config_parser(self.option_infos, suppress=True)
|
||||
return vars(suppressed_parser.parse_args(argv))
|
||||
|
||||
def parse_args(self, argv):
|
||||
def parse_args_with_file(self, yaml_path, argv):
|
||||
if not os.path.isfile(yaml_path):
|
||||
print(f"Warning: no config file at path '{yaml_path}', creating it")
|
||||
raw_config = "{}"
|
||||
else:
|
||||
with open(yaml_path, 'r') as stream:
|
||||
raw_config = stream.read()
|
||||
return self.parse_args_with_string(raw_config, argv, save_config_file=yaml_path)
|
||||
|
||||
def parse_args_with_string(self, config_string, argv, save_config_file=None):
|
||||
defaults = self.get_arg_defaults()
|
||||
|
||||
if not os.path.isfile(self.config_path):
|
||||
print(f"Warning: no config file at path '{self.config_path}', creating it")
|
||||
config_options = {}
|
||||
else:
|
||||
config_options = self.load_from_file(self.config_path)
|
||||
|
||||
config_options = self.load_from_string(config_string)
|
||||
config_options = dict(merge_dicts(defaults, config_options))
|
||||
self.save_config(config_options)
|
||||
if save_config_file:
|
||||
self.save_config(save_config_file, config_options)
|
||||
|
||||
cli_args = self.get_cli_arguments(argv)
|
||||
|
||||
for category, options in self.option_infos:
|
||||
for option in options:
|
||||
option.validate(config_options, cli_args)
|
||||
|
||||
args = dict(merge_dicts(config_options, cli_args))
|
||||
return argparse.Namespace(**args)
|
||||
|
||||
#
|
||||
# Load config and CLI args
|
||||
#
|
||||
args = {}
|
||||
|
||||
config_loader = ComfyConfigLoader(folder_paths.default_config_path)
|
||||
args = config_loader.parse_args(sys.argv[1:])
|
||||
if "pytest" not in sys.modules:
|
||||
#
|
||||
# Load config and CLI args
|
||||
#
|
||||
|
||||
if args.windows_standalone_build:
|
||||
args.auto_launch = True
|
||||
config_loader = ComfyConfigLoader()
|
||||
args = config_loader.parse_args_with_file(folder_paths.default_config_path, sys.argv[1:])
|
||||
|
||||
if args.windows_standalone_build:
|
||||
args.auto_launch = True
|
||||
|
||||
@ -15,7 +15,7 @@ config:
|
||||
files:
|
||||
|
||||
# Extra paths to scan for model files.
|
||||
extra_model_paths_config:
|
||||
extra_model_paths:
|
||||
a111:
|
||||
base_path: path/to/stable-diffusion-webui/
|
||||
checkpoints: models/Stable-diffusion
|
||||
|
||||
9
conftest.py
Normal file
9
conftest.py
Normal file
@ -0,0 +1,9 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import sys, os
|
||||
|
||||
# Make sure that the application source directory (this directory's parent) is
|
||||
# on sys.path.
|
||||
|
||||
here = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.insert(0, here)
|
||||
@ -1,25 +0,0 @@
|
||||
#Rename this to extra_model_paths.yaml and ComfyUI will load it
|
||||
|
||||
#config for a1111 ui
|
||||
#all you have to do is change the base_path to where yours is installed
|
||||
a111:
|
||||
base_path: path/to/stable-diffusion-webui/
|
||||
|
||||
checkpoints: models/Stable-diffusion
|
||||
configs: models/Stable-diffusion
|
||||
vae: models/VAE
|
||||
loras: models/Lora
|
||||
upscale_models: |
|
||||
models/ESRGAN
|
||||
models/SwinIR
|
||||
embeddings: embeddings
|
||||
hypernetworks: models/hypernetworks
|
||||
controlnet: models/ControlNet
|
||||
|
||||
#other_ui:
|
||||
# base_path: path/to/ui
|
||||
# checkpoints: models/checkpoints
|
||||
# gligen: models/gligen
|
||||
# custom_nodes: path/custom_nodes
|
||||
|
||||
|
||||
7
main.py
7
main.py
@ -3,10 +3,13 @@ import itertools
|
||||
import os
|
||||
import shutil
|
||||
import threading
|
||||
import pprint
|
||||
|
||||
from comfy.cli_args import args
|
||||
import comfy.utils
|
||||
|
||||
print("Configuration: " + str(vars(args)))
|
||||
|
||||
if os.name == "nt":
|
||||
import logging
|
||||
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
||||
@ -82,8 +85,8 @@ if __name__ == "__main__":
|
||||
if os.path.isfile(extra_model_paths_config_path):
|
||||
load_extra_path_config_file(extra_model_paths_config_path)
|
||||
|
||||
if args.extra_model_paths_config:
|
||||
load_extra_path_config(args.extra_model_paths_config)
|
||||
if args.extra_model_paths:
|
||||
load_extra_path_config(args.extra_model_paths)
|
||||
|
||||
init_custom_nodes()
|
||||
server.add_routes()
|
||||
|
||||
3
pytest.ini
Normal file
3
pytest.ini
Normal file
@ -0,0 +1,3 @@
|
||||
[pytest]
|
||||
addopts = -ra -q
|
||||
testpaths = test
|
||||
@ -9,3 +9,4 @@ pytorch_lightning
|
||||
aiohttp
|
||||
accelerate
|
||||
ruamel.yaml
|
||||
pytest
|
||||
|
||||
79
test/test_cli_args.py
Normal file
79
test/test_cli_args.py
Normal file
@ -0,0 +1,79 @@
|
||||
from comfy.cli_args import ComfyConfigLoader
|
||||
|
||||
|
||||
def test_defaults():
|
||||
config = "{}"
|
||||
argv = []
|
||||
|
||||
args = ComfyConfigLoader().parse_args_with_string(config, argv)
|
||||
|
||||
assert args.listen == "127.0.0.1"
|
||||
assert args.novram == False
|
||||
|
||||
|
||||
def test_config():
|
||||
config = """
|
||||
config:
|
||||
network:
|
||||
listen: 0.0.0.0
|
||||
"""
|
||||
argv = []
|
||||
|
||||
args = ComfyConfigLoader().parse_args_with_string(config, argv)
|
||||
|
||||
assert args.listen == "0.0.0.0"
|
||||
|
||||
|
||||
def test_cli_args_overrides_config():
|
||||
config = """
|
||||
config:
|
||||
network:
|
||||
listen: 0.0.0.0
|
||||
"""
|
||||
argv = ["--listen", "192.168.1.100"]
|
||||
|
||||
args = ComfyConfigLoader().parse_args_with_string(config, argv)
|
||||
|
||||
assert args.listen == "192.168.1.100"
|
||||
|
||||
|
||||
def test_config_enum():
|
||||
config = """
|
||||
config:
|
||||
pytorch:
|
||||
cross_attention: split
|
||||
"""
|
||||
argv = []
|
||||
|
||||
args = ComfyConfigLoader().parse_args_with_string(config, argv)
|
||||
|
||||
assert args.use_split_cross_attention is True
|
||||
assert args.use_pytorch_cross_attention is False
|
||||
|
||||
|
||||
def test_config_enum_default():
|
||||
config = """
|
||||
config:
|
||||
pytorch:
|
||||
cross_attention:
|
||||
"""
|
||||
argv = []
|
||||
|
||||
args = ComfyConfigLoader().parse_args_with_string(config, argv)
|
||||
|
||||
assert args.use_split_cross_attention is False
|
||||
assert args.use_pytorch_cross_attention is False
|
||||
|
||||
|
||||
def test_config_enum_exclusive():
|
||||
config = """
|
||||
config:
|
||||
pytorch:
|
||||
cross_attention: split
|
||||
"""
|
||||
argv = ["--use-pytorch-cross-attention"]
|
||||
|
||||
args = ComfyConfigLoader().parse_args_with_string(config, argv)
|
||||
|
||||
assert args.use_split_cross_attention is False
|
||||
assert args.use_pytorch_cross_attention is True
|
||||
Loading…
Reference in New Issue
Block a user