diff --git a/comfy/cli_args.py b/comfy/cli_args.py index a2679a517..9e15ced6f 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -2,6 +2,7 @@ from __future__ import annotations import logging import sys +from comfy.cli_args_types import FlattenAndAppendAction from importlib.metadata import entry_points from types import ModuleType from typing import Optional @@ -36,8 +37,8 @@ def _create_parser() -> EnhancedConfigArgParser: help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.") parser.add_argument("--max-upload-size", type=float, default=100, help="Set the maximum upload size in MB.") parser.add_argument("--base-directory", type=str, default=None, help="Set the ComfyUI base directory for models, custom_nodes, input, output, temp, and user directories.") - parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', - action='append', help="Load one or more extra_model_paths.yaml files.") + parser.add_argument("--extra-model-paths-config", type=str, default=[], metavar="PATH", nargs='+', + action=FlattenAndAppendAction, help="Load one or more extra_model_paths.yaml files. Can be specified multiple times or as a comma-separated list.") parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory. Overrides --base-directory.") parser.add_argument("--temp-directory", type=str, default=None, help="Set the ComfyUI temp directory (default is in the ComfyUI directory). Overrides --base-directory.") diff --git a/comfy/cli_args_types.py b/comfy/cli_args_types.py index 8f32defa6..b46baef68 100644 --- a/comfy/cli_args_types.py +++ b/comfy/cli_args_types.py @@ -73,7 +73,6 @@ class Configuration(dict): temp_directory (Optional[str]): Temporary directory for processing. input_directory (Optional[str]): Directory for input files. When this is a relative path, it will be looked up relative to the cwd (current working directory) and all of the base_paths. auto_launch (bool): Auto-launch UI in the default browser. Defaults to False. - disable_auto_launch (bool): Disable auto-launching the browser. cuda_device (Optional[int]): CUDA device ID. None means default device. cuda_malloc (bool): Enable cudaMallocAsync. Defaults to True in applicable setups. disable_cuda_malloc (bool): Disable cudaMallocAsync. @@ -188,7 +187,6 @@ class Configuration(dict): self.temp_directory: Optional[str] = None self.input_directory: Optional[str] = None self.auto_launch: bool = False - self.disable_auto_launch: bool = False self.cuda_device: Optional[int] = None self.cuda_malloc: bool = True self.disable_cuda_malloc: bool = True @@ -352,7 +350,7 @@ class Configuration(dict): class EnumAction(argparse.Action): """ - Argparse action for handling Enums + Argparse action for handling Enums in a case-insensitive manner. """ def __init__(self, **kwargs): @@ -362,23 +360,33 @@ class EnumAction(argparse.Action): # Ensure an Enum subclass is provided if enum_type is None: raise ValueError("type must be assigned an Enum when using EnumAction") - enum_type: Any if not issubclass(enum_type, enum.Enum): raise TypeError("type must be an Enum when using EnumAction") - # Generate choices from the Enum - choices = tuple(e.value for e in enum_type) - kwargs.setdefault("choices", choices) - kwargs.setdefault("metavar", f"[{','.join(list(choices))}]") - - super(EnumAction, self).__init__(**kwargs) - self._enum = enum_type + # Generate choices from the Enum for the help message + choices = tuple(e.value for e in enum_type) + kwargs.setdefault("metavar", f"[{','.join(list(choices))}]") + + # We handle choices ourselves for case-insensitivity, so remove it before calling super. + if "choices" in kwargs: + del kwargs["choices"] + + super(EnumAction, self).__init__(**kwargs) + self._choices = choices + def __call__(self, parser, namespace, values, option_string=None): - # Convert value back into an Enum - value = self._enum(values) - setattr(namespace, self.dest, value) + # Convert value back into an Enum, case-insensitively + value_lower = values.lower() + for member in self._enum: + if member.value.lower() == value_lower: + setattr(namespace, self.dest, member) + return + + # If no match found, raise an error + msg = f"invalid choice: {values!r} (choose from {', '.join(self._choices)})" + raise argparse.ArgumentError(self, msg) class ParsedArgs(NamedTuple): @@ -401,3 +409,25 @@ class EnhancedConfigArgParser(configargparse.ArgParser): namespace, unknown_args = super().parse_known_args(args, namespace, **kwargs) return ParsedArgs(namespace, unknown_args, config_files) + + +class FlattenAndAppendAction(argparse.Action): + """ + Custom action to handle comma-separated values and multiple invocations + of the same argument, flattening them into a single list. + """ + def __call__(self, parser, namespace, values, option_string=None): + items = getattr(namespace, self.dest, None) + if items is None: + items = [] + else: + # Make a copy if it's not the first time, to avoid modifying the default. + items = items[:] + + # 'values' will be a list of strings because of nargs='+' + for value in values: + # Split comma-separated strings and add them to the list + items.extend(item.strip() for item in value.split(',')) + + # Set the flattened list back to the namespace. + setattr(namespace, self.dest, items) diff --git a/tests/unit/test_cli_args.py b/tests/unit/test_cli_args.py new file mode 100644 index 000000000..398de8a54 --- /dev/null +++ b/tests/unit/test_cli_args.py @@ -0,0 +1,126 @@ +import pytest +from unittest.mock import patch +from comfy import cli_args +from comfy.cli_args_types import LatentPreviewMethod + +# Helper function to parse args and return the Configuration object +def _parse_test_args(args_list): + parser = cli_args._create_parser() + # The `args_parsing=True` makes it use the provided list. + with patch.object(parser, 'parse_known_args_with_config_files', return_value=(parser.parse_known_args(args_list)[0], [], [])): + return cli_args._parse_args(parser, args_parsing=True) + +@pytest.mark.parametrize("args, expected", [ + ([], []), + (['--extra-model-paths-config', 'a'], ['a']), + (['--extra-model-paths-config', 'a', '--extra-model-paths-config', 'b'], ['a', 'b']), + (['--extra-model-paths-config', 'a,b'], ['a', 'b']), + (['--extra-model-paths-config', 'a,b', '--extra-model-paths-config', 'c'], ['a', 'b', 'c']), + (['--extra-model-paths-config', ' a , b ', '--extra-model-paths-config', 'c'], ['a', 'b', 'c']), + (['--extra-model-paths-config', 'a,b', 'c'], ['a', 'b', 'c']), +]) +def test_extra_model_paths_config(args, expected): + """Test that extra_model_paths_config is parsed correctly.""" + config = _parse_test_args(args) + assert config.extra_model_paths_config == expected + +def test_default_values(): + """Test that default values are set correctly when no args are provided.""" + config = _parse_test_args([]) + assert config.listen == "127.0.0.1" + assert config.port == 8188 + assert config.auto_launch is False + assert config.extra_model_paths_config == [] + assert config.preview_method == LatentPreviewMethod.Auto + assert config.logging_level == 'INFO' + assert config.multi_user is False + assert config.disable_xformers is False + assert config.gpu_only is False + assert config.highvram is False + assert config.lowvram is False + assert config.normalvram is False + assert config.novram is False + assert config.cpu is False + +def test_listen_and_port(): + """Test --listen and --port arguments.""" + config = _parse_test_args(['--listen', '0.0.0.0', '--port', '8000']) + assert config.listen == '0.0.0.0' + assert config.port == 8000 + +def test_listen_no_arg(): + """Test --listen without an argument.""" + config = _parse_test_args(['--listen']) + assert config.listen == '0.0.0.0,::' + +def test_auto_launch_flags(): + """Test --auto-launch and --disable-auto-launch flags.""" + config_auto = _parse_test_args(['--auto-launch']) + assert config_auto.auto_launch is True + + config_disable = _parse_test_args(['--disable-auto-launch']) + assert config_disable.auto_launch is False + + # Test that --disable-auto-launch overrides --auto-launch if both are present + # The order matters, argparse behavior. Last one wins for store_true/false. + config_both_1 = _parse_test_args(['--auto-launch', '--disable-auto-launch']) + assert config_both_1.auto_launch is False + + config_both_2 = _parse_test_args(['--disable-auto-launch', '--auto-launch']) + assert config_both_2.auto_launch is False + +def test_windows_standalone_build_enables_auto_launch(): + """Test that --windows-standalone-build enables auto-launch.""" + config = _parse_test_args(['--windows-standalone-build']) + assert config.windows_standalone_build is True + assert config.auto_launch is True + +def test_windows_standalone_build_with_disable_auto_launch(): + """Test that --disable-auto-launch overrides --windows-standalone-build's auto-launch.""" + config = _parse_test_args(['--windows-standalone-build', '--disable-auto-launch']) + assert config.windows_standalone_build is True + assert config.auto_launch is False + +def test_force_fp16_enables_fp16_unet(): + """Test that --force-fp16 enables --fp16-unet.""" + config = _parse_test_args(['--force-fp16']) + assert config.force_fp16 is True + assert config.fp16_unet is True + +@pytest.mark.parametrize("vram_arg, expected_true_field", [ + ('--gpu-only', 'gpu_only'), + ('--highvram', 'highvram'), + ('--normalvram', 'normalvram'), + ('--lowvram', 'lowvram'), + ('--novram', 'novram'), + ('--cpu', 'cpu'), +]) +def test_vram_modes(vram_arg, expected_true_field): + """Test mutually exclusive VRAM mode arguments.""" + config = _parse_test_args([vram_arg]) + all_vram_fields = ['gpu_only', 'highvram', 'normalvram', 'lowvram', 'novram', 'cpu'] + for field in all_vram_fields: + if field == expected_true_field: + assert getattr(config, field) is True + else: + assert getattr(config, field) is False + +def test_preview_method(): + """Test the --preview-method argument.""" + config = _parse_test_args(['--preview-method', 'TAESD']) + assert config.preview_method == LatentPreviewMethod.TAESD + +def test_logging_level(): + """Test the --logging-level argument.""" + config = _parse_test_args(['--logging-level', 'debug']) + assert config.logging_level == 'DEBUG' + +def test_multi_user(): + """Test the --multi-user flag.""" + config = _parse_test_args(['--multi-user']) + assert config.multi_user is True + +def test_disable_xformers(): + """Test the --disable-xformers flag.""" + config = _parse_test_args(['--disable-xformers']) + assert config.disable_xformers is True