mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-15 16:50:57 +08:00
Get under test
This commit is contained in:
parent
321a0e2fae
commit
143efb5900
@ -59,6 +59,15 @@ class AbstractOptionInfo:
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def validate(self, config_options, cli_args):
|
||||||
|
"""
|
||||||
|
Modifies config_options to fix inconsistencies
|
||||||
|
|
||||||
|
Example: config sets an enum value, cli args sets another, the flags
|
||||||
|
should be removed from config_options
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
class OptionInfo(AbstractOptionInfo):
|
class OptionInfo(AbstractOptionInfo):
|
||||||
def __init__(self, name, **parser_args):
|
def __init__(self, name, **parser_args):
|
||||||
self.name = name
|
self.name = name
|
||||||
@ -94,6 +103,9 @@ class OptionInfo(AbstractOptionInfo):
|
|||||||
def convert_to_file_option(self, parser, args):
|
def convert_to_file_option(self, parser, args):
|
||||||
return args.get(self.name.replace("-", "_"), parser.get_default(self.name))
|
return args.get(self.name.replace("-", "_"), parser.get_default(self.name))
|
||||||
|
|
||||||
|
def validate(self, config_options, cli_args):
|
||||||
|
pass
|
||||||
|
|
||||||
class OptionInfoFlag(AbstractOptionInfo):
|
class OptionInfoFlag(AbstractOptionInfo):
|
||||||
def __init__(self, name, **parser_args):
|
def __init__(self, name, **parser_args):
|
||||||
self.name = name
|
self.name = name
|
||||||
@ -127,6 +139,9 @@ class OptionInfoFlag(AbstractOptionInfo):
|
|||||||
def convert_to_file_option(self, parser, args):
|
def convert_to_file_option(self, parser, args):
|
||||||
return args.get(self.name.replace("-", "_"), parser.get_default(self.name) or False)
|
return args.get(self.name.replace("-", "_"), parser.get_default(self.name) or False)
|
||||||
|
|
||||||
|
def validate(self, config_options, cli_args):
|
||||||
|
pass
|
||||||
|
|
||||||
class OptionInfoEnum(AbstractOptionInfo):
|
class OptionInfoEnum(AbstractOptionInfo):
|
||||||
def __init__(self, name, options, help=None, empty_help=None):
|
def __init__(self, name, options, help=None, empty_help=None):
|
||||||
self.name = name
|
self.name = name
|
||||||
@ -182,6 +197,12 @@ class OptionInfoEnum(AbstractOptionInfo):
|
|||||||
return option.name
|
return option.name
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def validate(self, config_options, cli_args):
|
||||||
|
set_by_cli = any(o for o in self.options if cli_args.get(o.option_name.replace("-", "_")) is not None)
|
||||||
|
if set_by_cli:
|
||||||
|
for option in self.options:
|
||||||
|
config_options[option.option_name.replace("-", "_")] = False
|
||||||
|
|
||||||
class OptionInfoEnumChoice:
|
class OptionInfoEnumChoice:
|
||||||
name: str
|
name: str
|
||||||
option_name: str
|
option_name: str
|
||||||
@ -225,6 +246,9 @@ class OptionInfoRaw:
|
|||||||
def convert_to_file_option(self, parser, args):
|
def convert_to_file_option(self, parser, args):
|
||||||
return args.get(self.name.replace("-", "_"), {})
|
return args.get(self.name.replace("-", "_"), {})
|
||||||
|
|
||||||
|
def validate(self, config_options, cli_args):
|
||||||
|
pass
|
||||||
|
|
||||||
#
|
#
|
||||||
# Config options
|
# Config options
|
||||||
#
|
#
|
||||||
@ -238,7 +262,7 @@ CONFIG_OPTIONS = [
|
|||||||
help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'."),
|
help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'."),
|
||||||
]),
|
]),
|
||||||
("files", [
|
("files", [
|
||||||
OptionInfoRaw("extra-model-paths-config", help="Extra paths to scan for model files."),
|
OptionInfoRaw("extra-model-paths", help="Extra paths to scan for model files."),
|
||||||
OptionInfo("output-directory", type=str, default=None, help="Set the ComfyUI output directory. Leave empty to use the default."),
|
OptionInfo("output-directory", type=str, default=None, help="Set the ComfyUI output directory. Leave empty to use the default."),
|
||||||
]),
|
]),
|
||||||
("behavior", [
|
("behavior", [
|
||||||
@ -321,8 +345,7 @@ def recursive_delete_comment_attribs(d):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
class ComfyConfigLoader:
|
class ComfyConfigLoader:
|
||||||
def __init__(self, config_path):
|
def __init__(self):
|
||||||
self.config_path = config_path
|
|
||||||
self.option_infos = CONFIG_OPTIONS
|
self.option_infos = CONFIG_OPTIONS
|
||||||
self.parser = make_config_parser(self.option_infos)
|
self.parser = make_config_parser(self.option_infos)
|
||||||
|
|
||||||
@ -338,9 +361,8 @@ class ComfyConfigLoader:
|
|||||||
|
|
||||||
return defaults
|
return defaults
|
||||||
|
|
||||||
def load_from_file(self, yaml_path):
|
def load_from_string(self, raw_config):
|
||||||
with open(yaml_path, 'r') as stream:
|
raw_config = yaml.load(raw_config)
|
||||||
raw_config = yaml.load(stream)
|
|
||||||
|
|
||||||
config = {}
|
config = {}
|
||||||
root = raw_config.get("config", {})
|
root = raw_config.get("config", {})
|
||||||
@ -410,9 +432,9 @@ class ComfyConfigLoader:
|
|||||||
config[category] = d
|
config[category] = d
|
||||||
return { "config": config }
|
return { "config": config }
|
||||||
|
|
||||||
def save_config(self, args):
|
def save_config(self, config_path, args):
|
||||||
options = self.convert_args_to_options(args)
|
options = self.convert_args_to_options(args)
|
||||||
with open(self.config_path, 'w') as f:
|
with open(config_path, 'w') as f:
|
||||||
yaml.dump(options, f)
|
yaml.dump(options, f)
|
||||||
|
|
||||||
def get_cli_arguments(self, argv):
|
def get_cli_arguments(self, argv):
|
||||||
@ -426,29 +448,41 @@ class ComfyConfigLoader:
|
|||||||
suppressed_parser = make_config_parser(self.option_infos, suppress=True)
|
suppressed_parser = make_config_parser(self.option_infos, suppress=True)
|
||||||
return vars(suppressed_parser.parse_args(argv))
|
return vars(suppressed_parser.parse_args(argv))
|
||||||
|
|
||||||
def parse_args(self, argv):
|
def parse_args_with_file(self, yaml_path, argv):
|
||||||
|
if not os.path.isfile(yaml_path):
|
||||||
|
print(f"Warning: no config file at path '{yaml_path}', creating it")
|
||||||
|
raw_config = "{}"
|
||||||
|
else:
|
||||||
|
with open(yaml_path, 'r') as stream:
|
||||||
|
raw_config = stream.read()
|
||||||
|
return self.parse_args_with_string(raw_config, argv, save_config_file=yaml_path)
|
||||||
|
|
||||||
|
def parse_args_with_string(self, config_string, argv, save_config_file=None):
|
||||||
defaults = self.get_arg_defaults()
|
defaults = self.get_arg_defaults()
|
||||||
|
|
||||||
if not os.path.isfile(self.config_path):
|
config_options = self.load_from_string(config_string)
|
||||||
print(f"Warning: no config file at path '{self.config_path}', creating it")
|
|
||||||
config_options = {}
|
|
||||||
else:
|
|
||||||
config_options = self.load_from_file(self.config_path)
|
|
||||||
|
|
||||||
config_options = dict(merge_dicts(defaults, config_options))
|
config_options = dict(merge_dicts(defaults, config_options))
|
||||||
self.save_config(config_options)
|
if save_config_file:
|
||||||
|
self.save_config(save_config_file, config_options)
|
||||||
|
|
||||||
cli_args = self.get_cli_arguments(argv)
|
cli_args = self.get_cli_arguments(argv)
|
||||||
|
|
||||||
|
for category, options in self.option_infos:
|
||||||
|
for option in options:
|
||||||
|
option.validate(config_options, cli_args)
|
||||||
|
|
||||||
args = dict(merge_dicts(config_options, cli_args))
|
args = dict(merge_dicts(config_options, cli_args))
|
||||||
return argparse.Namespace(**args)
|
return argparse.Namespace(**args)
|
||||||
|
|
||||||
#
|
args = {}
|
||||||
# Load config and CLI args
|
|
||||||
#
|
|
||||||
|
|
||||||
config_loader = ComfyConfigLoader(folder_paths.default_config_path)
|
if "pytest" not in sys.modules:
|
||||||
args = config_loader.parse_args(sys.argv[1:])
|
#
|
||||||
|
# Load config and CLI args
|
||||||
|
#
|
||||||
|
|
||||||
if args.windows_standalone_build:
|
config_loader = ComfyConfigLoader()
|
||||||
args.auto_launch = True
|
args = config_loader.parse_args_with_file(folder_paths.default_config_path, sys.argv[1:])
|
||||||
|
|
||||||
|
if args.windows_standalone_build:
|
||||||
|
args.auto_launch = True
|
||||||
|
|||||||
@ -15,7 +15,7 @@ config:
|
|||||||
files:
|
files:
|
||||||
|
|
||||||
# Extra paths to scan for model files.
|
# Extra paths to scan for model files.
|
||||||
extra_model_paths_config:
|
extra_model_paths:
|
||||||
a111:
|
a111:
|
||||||
base_path: path/to/stable-diffusion-webui/
|
base_path: path/to/stable-diffusion-webui/
|
||||||
checkpoints: models/Stable-diffusion
|
checkpoints: models/Stable-diffusion
|
||||||
|
|||||||
9
conftest.py
Normal file
9
conftest.py
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
import sys, os
|
||||||
|
|
||||||
|
# Make sure that the application source directory (this directory's parent) is
|
||||||
|
# on sys.path.
|
||||||
|
|
||||||
|
here = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
sys.path.insert(0, here)
|
||||||
@ -1,25 +0,0 @@
|
|||||||
#Rename this to extra_model_paths.yaml and ComfyUI will load it
|
|
||||||
|
|
||||||
#config for a1111 ui
|
|
||||||
#all you have to do is change the base_path to where yours is installed
|
|
||||||
a111:
|
|
||||||
base_path: path/to/stable-diffusion-webui/
|
|
||||||
|
|
||||||
checkpoints: models/Stable-diffusion
|
|
||||||
configs: models/Stable-diffusion
|
|
||||||
vae: models/VAE
|
|
||||||
loras: models/Lora
|
|
||||||
upscale_models: |
|
|
||||||
models/ESRGAN
|
|
||||||
models/SwinIR
|
|
||||||
embeddings: embeddings
|
|
||||||
hypernetworks: models/hypernetworks
|
|
||||||
controlnet: models/ControlNet
|
|
||||||
|
|
||||||
#other_ui:
|
|
||||||
# base_path: path/to/ui
|
|
||||||
# checkpoints: models/checkpoints
|
|
||||||
# gligen: models/gligen
|
|
||||||
# custom_nodes: path/custom_nodes
|
|
||||||
|
|
||||||
|
|
||||||
7
main.py
7
main.py
@ -3,10 +3,13 @@ import itertools
|
|||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import threading
|
import threading
|
||||||
|
import pprint
|
||||||
|
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
|
|
||||||
|
print("Configuration: " + str(vars(args)))
|
||||||
|
|
||||||
if os.name == "nt":
|
if os.name == "nt":
|
||||||
import logging
|
import logging
|
||||||
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
||||||
@ -82,8 +85,8 @@ if __name__ == "__main__":
|
|||||||
if os.path.isfile(extra_model_paths_config_path):
|
if os.path.isfile(extra_model_paths_config_path):
|
||||||
load_extra_path_config_file(extra_model_paths_config_path)
|
load_extra_path_config_file(extra_model_paths_config_path)
|
||||||
|
|
||||||
if args.extra_model_paths_config:
|
if args.extra_model_paths:
|
||||||
load_extra_path_config(args.extra_model_paths_config)
|
load_extra_path_config(args.extra_model_paths)
|
||||||
|
|
||||||
init_custom_nodes()
|
init_custom_nodes()
|
||||||
server.add_routes()
|
server.add_routes()
|
||||||
|
|||||||
3
pytest.ini
Normal file
3
pytest.ini
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
[pytest]
|
||||||
|
addopts = -ra -q
|
||||||
|
testpaths = test
|
||||||
@ -9,3 +9,4 @@ pytorch_lightning
|
|||||||
aiohttp
|
aiohttp
|
||||||
accelerate
|
accelerate
|
||||||
ruamel.yaml
|
ruamel.yaml
|
||||||
|
pytest
|
||||||
|
|||||||
79
test/test_cli_args.py
Normal file
79
test/test_cli_args.py
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
from comfy.cli_args import ComfyConfigLoader
|
||||||
|
|
||||||
|
|
||||||
|
def test_defaults():
|
||||||
|
config = "{}"
|
||||||
|
argv = []
|
||||||
|
|
||||||
|
args = ComfyConfigLoader().parse_args_with_string(config, argv)
|
||||||
|
|
||||||
|
assert args.listen == "127.0.0.1"
|
||||||
|
assert args.novram == False
|
||||||
|
|
||||||
|
|
||||||
|
def test_config():
|
||||||
|
config = """
|
||||||
|
config:
|
||||||
|
network:
|
||||||
|
listen: 0.0.0.0
|
||||||
|
"""
|
||||||
|
argv = []
|
||||||
|
|
||||||
|
args = ComfyConfigLoader().parse_args_with_string(config, argv)
|
||||||
|
|
||||||
|
assert args.listen == "0.0.0.0"
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_args_overrides_config():
|
||||||
|
config = """
|
||||||
|
config:
|
||||||
|
network:
|
||||||
|
listen: 0.0.0.0
|
||||||
|
"""
|
||||||
|
argv = ["--listen", "192.168.1.100"]
|
||||||
|
|
||||||
|
args = ComfyConfigLoader().parse_args_with_string(config, argv)
|
||||||
|
|
||||||
|
assert args.listen == "192.168.1.100"
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_enum():
|
||||||
|
config = """
|
||||||
|
config:
|
||||||
|
pytorch:
|
||||||
|
cross_attention: split
|
||||||
|
"""
|
||||||
|
argv = []
|
||||||
|
|
||||||
|
args = ComfyConfigLoader().parse_args_with_string(config, argv)
|
||||||
|
|
||||||
|
assert args.use_split_cross_attention is True
|
||||||
|
assert args.use_pytorch_cross_attention is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_enum_default():
|
||||||
|
config = """
|
||||||
|
config:
|
||||||
|
pytorch:
|
||||||
|
cross_attention:
|
||||||
|
"""
|
||||||
|
argv = []
|
||||||
|
|
||||||
|
args = ComfyConfigLoader().parse_args_with_string(config, argv)
|
||||||
|
|
||||||
|
assert args.use_split_cross_attention is False
|
||||||
|
assert args.use_pytorch_cross_attention is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_enum_exclusive():
|
||||||
|
config = """
|
||||||
|
config:
|
||||||
|
pytorch:
|
||||||
|
cross_attention: split
|
||||||
|
"""
|
||||||
|
argv = ["--use-pytorch-cross-attention"]
|
||||||
|
|
||||||
|
args = ComfyConfigLoader().parse_args_with_string(config, argv)
|
||||||
|
|
||||||
|
assert args.use_split_cross_attention is False
|
||||||
|
assert args.use_pytorch_cross_attention is True
|
||||||
Loading…
Reference in New Issue
Block a user