Improve configuration via files, including automatically updating configuration when configuration files change.

This commit is contained in:
doctorpangloss 2024-07-08 10:01:08 -07:00
parent e88f458a70
commit a1fee05e60
3 changed files with 155 additions and 49 deletions

View File

@ -1,52 +1,23 @@
from __future__ import annotations from __future__ import annotations
import enum
import logging import logging
import os
import sys import sys
from importlib.metadata import entry_points from importlib.metadata import entry_points
from types import ModuleType from types import ModuleType
from typing import Optional, Any from typing import Optional, List
import configargparse as argparse import configargparse as argparse
from watchdog.observers import Observer
from . import __version__ from . import __version__
from . import options from . import options
from .cli_args_types import LatentPreviewMethod, Configuration, ConfigurationExtender from .cli_args_types import LatentPreviewMethod, Configuration, ConfigurationExtender, ConfigChangeHandler, EnumAction, \
EnhancedConfigArgParser
class EnumAction(argparse.Action): def _create_parser() -> EnhancedConfigArgParser:
""" parser = EnhancedConfigArgParser(default_config_files=['config.yaml', 'config.json'],
Argparse action for handling Enums
"""
def __init__(self, **kwargs):
# Pop off the type value
enum_type = kwargs.pop("type", None)
# Ensure an Enum subclass is provided
if enum_type is None:
raise ValueError("type must be assigned an Enum when using EnumAction")
enum_type: Any
if not issubclass(enum_type, enum.Enum):
raise TypeError("type must be an Enum when using EnumAction")
# Generate choices from the Enum
choices = tuple(e.value for e in enum_type)
kwargs.setdefault("choices", choices)
kwargs.setdefault("metavar", f"[{','.join(list(choices))}]")
super(EnumAction, self).__init__(**kwargs)
self._enum = enum_type
def __call__(self, parser, namespace, values, option_string=None):
# Convert value back into an Enum
value = self._enum(values)
setattr(namespace, self.dest, value)
def create_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(default_config_files=['config.yaml', 'config.json'],
auto_env_var_prefix='COMFYUI_', auto_env_var_prefix='COMFYUI_',
args_for_setting_config_path=["-c", "--config"], args_for_setting_config_path=["-c", "--config"],
add_env_var_help=True, add_config_file_help=True, add_help=True, add_env_var_help=True, add_config_file_help=True, add_help=True,
@ -205,14 +176,14 @@ def create_parser() -> argparse.ArgumentParser:
return parser return parser
def parse_args(parser: Optional[argparse.ArgumentParser] = None) -> Configuration: def _parse_args(parser: Optional[argparse.ArgumentParser] = None) -> Configuration:
if parser is None: if parser is None:
parser = create_parser() parser = _create_parser()
if options.args_parsing: if options.args_parsing:
args, _ = parser.parse_known_args() args, _, config_files = parser.parse_known_args_with_config_files()
else: else:
args, _ = parser.parse_known_args([]) args, _, config_files = parser.parse_known_args_with_config_files([])
if args.windows_standalone_build: if args.windows_standalone_build:
args.auto_launch = True args.auto_launch = True
@ -225,8 +196,34 @@ def parse_args(parser: Optional[argparse.ArgumentParser] = None) -> Configuratio
logging_level = logging.DEBUG logging_level = logging.DEBUG
logging.basicConfig(format="%(message)s", level=logging_level) logging.basicConfig(format="%(message)s", level=logging_level)
configuration_obj = Configuration(**vars(args))
return Configuration(**vars(args)) configuration_obj.config_files = config_files
assert all(isinstance(config_file, str) for config_file in config_files)
# we always have to set up a watcher, even when there are no existing files
if len(config_files) > 0:
_setup_config_file_watcher(configuration_obj, parser, config_files)
return configuration_obj
args = parse_args() def _setup_config_file_watcher(config: Configuration, parser: EnhancedConfigArgParser, config_files: List[str]):
def update_config():
new_args, _, _ = parser.parse_known_args()
new_config = vars(new_args)
config.update(new_config)
handler = ConfigChangeHandler(config_files, update_config)
observer = Observer()
for config_file in config_files:
config_dir = os.path.dirname(config_file) or '.'
observer.schedule(handler, path=config_dir, recursive=False)
observer.start()
# Ensure the observer is stopped when the program exits
import atexit
atexit.register(observer.stop)
atexit.register(observer.join)
args = _parse_args()

View File

@ -1,10 +1,13 @@
# Define a class for your command-line arguments # Define a class for your command-line arguments
from __future__ import annotations
import copy
import enum import enum
from typing import Optional, List, Callable from typing import Optional, List, Callable, Any, Union, Mapping, NamedTuple
import configargparse
import configargparse as argparse import configargparse as argparse
from watchdog.events import FileSystemEventHandler
from . import __version__
ConfigurationExtender = Callable[[argparse.ArgParser], Optional[argparse.ArgParser]] ConfigurationExtender = Callable[[argparse.ArgParser], Optional[argparse.ArgParser]]
@ -16,12 +19,25 @@ class LatentPreviewMethod(enum.Enum):
TAESD = "taesd" TAESD = "taesd"
class ConfigChangeHandler(FileSystemEventHandler):
def __init__(self, config_file_paths: List[str], update_callback: Callable[[], None]):
self.config_file_paths = config_file_paths
self.update_callback = update_callback
def on_modified(self, event):
if not event.is_directory and event.src_path in self.config_file_paths:
self.update_callback()
ConfigObserver = Callable[[str, Any], None]
class Configuration(dict): class Configuration(dict):
""" """
Configuration options parsed from command-line arguments or config files. Configuration options parsed from command-line arguments or config files.
Attributes: Attributes:
config (Optional[str]): Path to the configuration file. config_files (Optional[List[str]]): Path to the configuration file(s) that were set in the arguments.
cwd (Optional[str]): Working directory. Defaults to the current directory. cwd (Optional[str]): Working directory. Defaults to the current directory.
listen (str): IP address to listen on. Defaults to "127.0.0.1". listen (str): IP address to listen on. Defaults to "127.0.0.1".
port (int): Port number for the server to listen on. Defaults to 8188. port (int): Port number for the server to listen on. Defaults to 8188.
@ -95,6 +111,8 @@ class Configuration(dict):
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__() super().__init__()
self._observers: List[ConfigObserver] = []
self.config_files = []
self.cwd: Optional[str] = None self.cwd: Optional[str] = None
self.listen: str = "127.0.0.1" self.listen: str = "127.0.0.1"
self.port: int = 8188 self.port: int = 8188
@ -174,4 +192,94 @@ class Configuration(dict):
return self[item] return self[item]
def __setattr__(self, key, value): def __setattr__(self, key, value):
self[key] = value if key != "_observers":
old_value = self.get(key)
self[key] = value
if old_value != value:
self._notify_observers(key, value)
else:
super().__setattr__(key, value)
def update(self, __m: Union[Mapping[str, Any], None] = None, **kwargs):
if __m is None:
__m = {}
changes = {}
for k, v in dict(__m, **kwargs).items():
if k not in self or self[k] != v:
changes[k] = v
super().update(__m, **kwargs)
for k, v in changes.items():
self._notify_observers(k, v)
def register_observer(self, observer: ConfigObserver):
self._observers.append(observer)
def unregister_observer(self, observer: ConfigObserver):
self._observers.remove(observer)
def _notify_observers(self, key, value):
for observer in self._observers:
observer(key, value)
def __getstate__(self):
state = self.copy()
if "_observers" in state:
state.pop("_observers")
return state
def __setstate__(self, state):
self.update(state)
self._observers = []
class EnumAction(argparse.Action):
"""
Argparse action for handling Enums
"""
def __init__(self, **kwargs):
# Pop off the type value
enum_type = kwargs.pop("type", None)
# Ensure an Enum subclass is provided
if enum_type is None:
raise ValueError("type must be assigned an Enum when using EnumAction")
enum_type: Any
if not issubclass(enum_type, enum.Enum):
raise TypeError("type must be an Enum when using EnumAction")
# Generate choices from the Enum
choices = tuple(e.value for e in enum_type)
kwargs.setdefault("choices", choices)
kwargs.setdefault("metavar", f"[{','.join(list(choices))}]")
super(EnumAction, self).__init__(**kwargs)
self._enum = enum_type
def __call__(self, parser, namespace, values, option_string=None):
# Convert value back into an Enum
value = self._enum(values)
setattr(namespace, self.dest, value)
class ParsedArgs(NamedTuple):
namespace: configargparse.Namespace
unknown_args: list[str]
config_file_paths: list[str]
class EnhancedConfigArgParser(configargparse.ArgParser):
def parse_known_args_with_config_files(self, args=None, namespace=None, **kwargs) -> ParsedArgs:
# usually the single method open
prev_open_func = self._config_file_open_func
config_files: List[str] = []
try:
self._config_file_open_func = lambda path: config_files.append(path)
self._open_config_files(args)
finally:
self._config_file_open_func = prev_open_func
namespace, unknown_args = super().parse_known_args(args, namespace, **kwargs)
return ParsedArgs(namespace, unknown_args, config_files)

View File

@ -53,4 +53,5 @@ wrapt>=1.16.0
certifi certifi
spandrel spandrel
numpy>=1.26.3,<2.0.0 numpy>=1.26.3,<2.0.0
soundfile soundfile
watchdog