Improve configuration via files, including automatically updating configuration when configuration files change.

This commit is contained in:
doctorpangloss 2024-07-08 10:01:08 -07:00
parent e88f458a70
commit a1fee05e60
3 changed files with 155 additions and 49 deletions

View File

@ -1,52 +1,23 @@
from __future__ import annotations
import enum
import logging
import os
import sys
from importlib.metadata import entry_points
from types import ModuleType
from typing import Optional, Any
from typing import Optional, List
import configargparse as argparse
from watchdog.observers import Observer
from . import __version__
from . import options
from .cli_args_types import LatentPreviewMethod, Configuration, ConfigurationExtender
from .cli_args_types import LatentPreviewMethod, Configuration, ConfigurationExtender, ConfigChangeHandler, EnumAction, \
EnhancedConfigArgParser
class EnumAction(argparse.Action):
"""
Argparse action for handling Enums
"""
def __init__(self, **kwargs):
# Pop off the type value
enum_type = kwargs.pop("type", None)
# Ensure an Enum subclass is provided
if enum_type is None:
raise ValueError("type must be assigned an Enum when using EnumAction")
enum_type: Any
if not issubclass(enum_type, enum.Enum):
raise TypeError("type must be an Enum when using EnumAction")
# Generate choices from the Enum
choices = tuple(e.value for e in enum_type)
kwargs.setdefault("choices", choices)
kwargs.setdefault("metavar", f"[{','.join(list(choices))}]")
super(EnumAction, self).__init__(**kwargs)
self._enum = enum_type
def __call__(self, parser, namespace, values, option_string=None):
# Convert value back into an Enum
value = self._enum(values)
setattr(namespace, self.dest, value)
def create_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(default_config_files=['config.yaml', 'config.json'],
def _create_parser() -> EnhancedConfigArgParser:
parser = EnhancedConfigArgParser(default_config_files=['config.yaml', 'config.json'],
auto_env_var_prefix='COMFYUI_',
args_for_setting_config_path=["-c", "--config"],
add_env_var_help=True, add_config_file_help=True, add_help=True,
@ -205,14 +176,14 @@ def create_parser() -> argparse.ArgumentParser:
return parser
def parse_args(parser: Optional[argparse.ArgumentParser] = None) -> Configuration:
def _parse_args(parser: Optional[argparse.ArgumentParser] = None) -> Configuration:
if parser is None:
parser = create_parser()
parser = _create_parser()
if options.args_parsing:
args, _ = parser.parse_known_args()
args, _, config_files = parser.parse_known_args_with_config_files()
else:
args, _ = parser.parse_known_args([])
args, _, config_files = parser.parse_known_args_with_config_files([])
if args.windows_standalone_build:
args.auto_launch = True
@ -225,8 +196,34 @@ def parse_args(parser: Optional[argparse.ArgumentParser] = None) -> Configuratio
logging_level = logging.DEBUG
logging.basicConfig(format="%(message)s", level=logging_level)
return Configuration(**vars(args))
configuration_obj = Configuration(**vars(args))
configuration_obj.config_files = config_files
assert all(isinstance(config_file, str) for config_file in config_files)
# we always have to set up a watcher, even when there are no existing files
if len(config_files) > 0:
_setup_config_file_watcher(configuration_obj, parser, config_files)
return configuration_obj
args = parse_args()
def _setup_config_file_watcher(config: Configuration, parser: EnhancedConfigArgParser, config_files: List[str]):
def update_config():
new_args, _, _ = parser.parse_known_args()
new_config = vars(new_args)
config.update(new_config)
handler = ConfigChangeHandler(config_files, update_config)
observer = Observer()
for config_file in config_files:
config_dir = os.path.dirname(config_file) or '.'
observer.schedule(handler, path=config_dir, recursive=False)
observer.start()
# Ensure the observer is stopped when the program exits
import atexit
atexit.register(observer.stop)
atexit.register(observer.join)
args = _parse_args()

View File

@ -1,10 +1,13 @@
# Define a class for your command-line arguments
from __future__ import annotations
import copy
import enum
from typing import Optional, List, Callable
from typing import Optional, List, Callable, Any, Union, Mapping, NamedTuple
import configargparse
import configargparse as argparse
from . import __version__
from watchdog.events import FileSystemEventHandler
ConfigurationExtender = Callable[[argparse.ArgParser], Optional[argparse.ArgParser]]
@ -16,12 +19,25 @@ class LatentPreviewMethod(enum.Enum):
TAESD = "taesd"
class ConfigChangeHandler(FileSystemEventHandler):
def __init__(self, config_file_paths: List[str], update_callback: Callable[[], None]):
self.config_file_paths = config_file_paths
self.update_callback = update_callback
def on_modified(self, event):
if not event.is_directory and event.src_path in self.config_file_paths:
self.update_callback()
ConfigObserver = Callable[[str, Any], None]
class Configuration(dict):
"""
Configuration options parsed from command-line arguments or config files.
Attributes:
config (Optional[str]): Path to the configuration file.
config_files (Optional[List[str]]): Path to the configuration file(s) that were set in the arguments.
cwd (Optional[str]): Working directory. Defaults to the current directory.
listen (str): IP address to listen on. Defaults to "127.0.0.1".
port (int): Port number for the server to listen on. Defaults to 8188.
@ -95,6 +111,8 @@ class Configuration(dict):
def __init__(self, **kwargs):
super().__init__()
self._observers: List[ConfigObserver] = []
self.config_files = []
self.cwd: Optional[str] = None
self.listen: str = "127.0.0.1"
self.port: int = 8188
@ -174,4 +192,94 @@ class Configuration(dict):
return self[item]
def __setattr__(self, key, value):
self[key] = value
if key != "_observers":
old_value = self.get(key)
self[key] = value
if old_value != value:
self._notify_observers(key, value)
else:
super().__setattr__(key, value)
def update(self, __m: Union[Mapping[str, Any], None] = None, **kwargs):
if __m is None:
__m = {}
changes = {}
for k, v in dict(__m, **kwargs).items():
if k not in self or self[k] != v:
changes[k] = v
super().update(__m, **kwargs)
for k, v in changes.items():
self._notify_observers(k, v)
def register_observer(self, observer: ConfigObserver):
self._observers.append(observer)
def unregister_observer(self, observer: ConfigObserver):
self._observers.remove(observer)
def _notify_observers(self, key, value):
for observer in self._observers:
observer(key, value)
def __getstate__(self):
state = self.copy()
if "_observers" in state:
state.pop("_observers")
return state
def __setstate__(self, state):
self.update(state)
self._observers = []
class EnumAction(argparse.Action):
"""
Argparse action for handling Enums
"""
def __init__(self, **kwargs):
# Pop off the type value
enum_type = kwargs.pop("type", None)
# Ensure an Enum subclass is provided
if enum_type is None:
raise ValueError("type must be assigned an Enum when using EnumAction")
enum_type: Any
if not issubclass(enum_type, enum.Enum):
raise TypeError("type must be an Enum when using EnumAction")
# Generate choices from the Enum
choices = tuple(e.value for e in enum_type)
kwargs.setdefault("choices", choices)
kwargs.setdefault("metavar", f"[{','.join(list(choices))}]")
super(EnumAction, self).__init__(**kwargs)
self._enum = enum_type
def __call__(self, parser, namespace, values, option_string=None):
# Convert value back into an Enum
value = self._enum(values)
setattr(namespace, self.dest, value)
class ParsedArgs(NamedTuple):
namespace: configargparse.Namespace
unknown_args: list[str]
config_file_paths: list[str]
class EnhancedConfigArgParser(configargparse.ArgParser):
def parse_known_args_with_config_files(self, args=None, namespace=None, **kwargs) -> ParsedArgs:
# usually the single method open
prev_open_func = self._config_file_open_func
config_files: List[str] = []
try:
self._config_file_open_func = lambda path: config_files.append(path)
self._open_config_files(args)
finally:
self._config_file_open_func = prev_open_func
namespace, unknown_args = super().parse_known_args(args, namespace, **kwargs)
return ParsedArgs(namespace, unknown_args, config_files)

View File

@ -53,4 +53,5 @@ wrapt>=1.16.0
certifi
spandrel
numpy>=1.26.3,<2.0.0
soundfile
soundfile
watchdog