Test and fix cli args issues

This commit is contained in:
doctorpangloss 2025-10-22 15:03:01 -07:00
parent 6954e3e247
commit 95d8ca6c53
3 changed files with 173 additions and 16 deletions

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import logging
import sys
from comfy.cli_args_types import FlattenAndAppendAction
from importlib.metadata import entry_points
from types import ModuleType
from typing import Optional
@ -36,8 +37,8 @@ def _create_parser() -> EnhancedConfigArgParser:
help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
parser.add_argument("--max-upload-size", type=float, default=100, help="Set the maximum upload size in MB.")
parser.add_argument("--base-directory", type=str, default=None, help="Set the ComfyUI base directory for models, custom_nodes, input, output, temp, and user directories.")
parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+',
action='append', help="Load one or more extra_model_paths.yaml files.")
parser.add_argument("--extra-model-paths-config", type=str, default=[], metavar="PATH", nargs='+',
action=FlattenAndAppendAction, help="Load one or more extra_model_paths.yaml files. Can be specified multiple times or as a comma-separated list.")
parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory. Overrides --base-directory.")
parser.add_argument("--temp-directory", type=str, default=None,
help="Set the ComfyUI temp directory (default is in the ComfyUI directory). Overrides --base-directory.")

View File

@ -73,7 +73,6 @@ class Configuration(dict):
temp_directory (Optional[str]): Temporary directory for processing.
input_directory (Optional[str]): Directory for input files. When this is a relative path, it will be looked up relative to the cwd (current working directory) and all of the base_paths.
auto_launch (bool): Auto-launch UI in the default browser. Defaults to False.
disable_auto_launch (bool): Disable auto-launching the browser.
cuda_device (Optional[int]): CUDA device ID. None means default device.
cuda_malloc (bool): Enable cudaMallocAsync. Defaults to True in applicable setups.
disable_cuda_malloc (bool): Disable cudaMallocAsync.
@ -188,7 +187,6 @@ class Configuration(dict):
self.temp_directory: Optional[str] = None
self.input_directory: Optional[str] = None
self.auto_launch: bool = False
self.disable_auto_launch: bool = False
self.cuda_device: Optional[int] = None
self.cuda_malloc: bool = True
self.disable_cuda_malloc: bool = True
@ -352,7 +350,7 @@ class Configuration(dict):
class EnumAction(argparse.Action):
"""
Argparse action for handling Enums
Argparse action for handling Enums in a case-insensitive manner.
"""
def __init__(self, **kwargs):
@ -362,23 +360,33 @@ class EnumAction(argparse.Action):
# Ensure an Enum subclass is provided
if enum_type is None:
raise ValueError("type must be assigned an Enum when using EnumAction")
enum_type: Any
if not issubclass(enum_type, enum.Enum):
raise TypeError("type must be an Enum when using EnumAction")
# Generate choices from the Enum
choices = tuple(e.value for e in enum_type)
kwargs.setdefault("choices", choices)
kwargs.setdefault("metavar", f"[{','.join(list(choices))}]")
super(EnumAction, self).__init__(**kwargs)
self._enum = enum_type
# Generate choices from the Enum for the help message
choices = tuple(e.value for e in enum_type)
kwargs.setdefault("metavar", f"[{','.join(list(choices))}]")
# We handle choices ourselves for case-insensitivity, so remove it before calling super.
if "choices" in kwargs:
del kwargs["choices"]
super(EnumAction, self).__init__(**kwargs)
self._choices = choices
def __call__(self, parser, namespace, values, option_string=None):
# Convert value back into an Enum
value = self._enum(values)
setattr(namespace, self.dest, value)
# Convert value back into an Enum, case-insensitively
value_lower = values.lower()
for member in self._enum:
if member.value.lower() == value_lower:
setattr(namespace, self.dest, member)
return
# If no match found, raise an error
msg = f"invalid choice: {values!r} (choose from {', '.join(self._choices)})"
raise argparse.ArgumentError(self, msg)
class ParsedArgs(NamedTuple):
@ -401,3 +409,25 @@ class EnhancedConfigArgParser(configargparse.ArgParser):
namespace, unknown_args = super().parse_known_args(args, namespace, **kwargs)
return ParsedArgs(namespace, unknown_args, config_files)
class FlattenAndAppendAction(argparse.Action):
"""
Custom action to handle comma-separated values and multiple invocations
of the same argument, flattening them into a single list.
"""
def __call__(self, parser, namespace, values, option_string=None):
items = getattr(namespace, self.dest, None)
if items is None:
items = []
else:
# Make a copy if it's not the first time, to avoid modifying the default.
items = items[:]
# 'values' will be a list of strings because of nargs='+'
for value in values:
# Split comma-separated strings and add them to the list
items.extend(item.strip() for item in value.split(','))
# Set the flattened list back to the namespace.
setattr(namespace, self.dest, items)

126
tests/unit/test_cli_args.py Normal file
View File

@ -0,0 +1,126 @@
import pytest
from unittest.mock import patch
from comfy import cli_args
from comfy.cli_args_types import LatentPreviewMethod
# Helper function to parse args and return the Configuration object
def _parse_test_args(args_list):
parser = cli_args._create_parser()
# The `args_parsing=True` makes it use the provided list.
with patch.object(parser, 'parse_known_args_with_config_files', return_value=(parser.parse_known_args(args_list)[0], [], [])):
return cli_args._parse_args(parser, args_parsing=True)
@pytest.mark.parametrize("args, expected", [
([], []),
(['--extra-model-paths-config', 'a'], ['a']),
(['--extra-model-paths-config', 'a', '--extra-model-paths-config', 'b'], ['a', 'b']),
(['--extra-model-paths-config', 'a,b'], ['a', 'b']),
(['--extra-model-paths-config', 'a,b', '--extra-model-paths-config', 'c'], ['a', 'b', 'c']),
(['--extra-model-paths-config', ' a , b ', '--extra-model-paths-config', 'c'], ['a', 'b', 'c']),
(['--extra-model-paths-config', 'a,b', 'c'], ['a', 'b', 'c']),
])
def test_extra_model_paths_config(args, expected):
"""Test that extra_model_paths_config is parsed correctly."""
config = _parse_test_args(args)
assert config.extra_model_paths_config == expected
def test_default_values():
"""Test that default values are set correctly when no args are provided."""
config = _parse_test_args([])
assert config.listen == "127.0.0.1"
assert config.port == 8188
assert config.auto_launch is False
assert config.extra_model_paths_config == []
assert config.preview_method == LatentPreviewMethod.Auto
assert config.logging_level == 'INFO'
assert config.multi_user is False
assert config.disable_xformers is False
assert config.gpu_only is False
assert config.highvram is False
assert config.lowvram is False
assert config.normalvram is False
assert config.novram is False
assert config.cpu is False
def test_listen_and_port():
"""Test --listen and --port arguments."""
config = _parse_test_args(['--listen', '0.0.0.0', '--port', '8000'])
assert config.listen == '0.0.0.0'
assert config.port == 8000
def test_listen_no_arg():
"""Test --listen without an argument."""
config = _parse_test_args(['--listen'])
assert config.listen == '0.0.0.0,::'
def test_auto_launch_flags():
"""Test --auto-launch and --disable-auto-launch flags."""
config_auto = _parse_test_args(['--auto-launch'])
assert config_auto.auto_launch is True
config_disable = _parse_test_args(['--disable-auto-launch'])
assert config_disable.auto_launch is False
# Test that --disable-auto-launch overrides --auto-launch if both are present
# The order matters, argparse behavior. Last one wins for store_true/false.
config_both_1 = _parse_test_args(['--auto-launch', '--disable-auto-launch'])
assert config_both_1.auto_launch is False
config_both_2 = _parse_test_args(['--disable-auto-launch', '--auto-launch'])
assert config_both_2.auto_launch is False
def test_windows_standalone_build_enables_auto_launch():
"""Test that --windows-standalone-build enables auto-launch."""
config = _parse_test_args(['--windows-standalone-build'])
assert config.windows_standalone_build is True
assert config.auto_launch is True
def test_windows_standalone_build_with_disable_auto_launch():
"""Test that --disable-auto-launch overrides --windows-standalone-build's auto-launch."""
config = _parse_test_args(['--windows-standalone-build', '--disable-auto-launch'])
assert config.windows_standalone_build is True
assert config.auto_launch is False
def test_force_fp16_enables_fp16_unet():
"""Test that --force-fp16 enables --fp16-unet."""
config = _parse_test_args(['--force-fp16'])
assert config.force_fp16 is True
assert config.fp16_unet is True
@pytest.mark.parametrize("vram_arg, expected_true_field", [
('--gpu-only', 'gpu_only'),
('--highvram', 'highvram'),
('--normalvram', 'normalvram'),
('--lowvram', 'lowvram'),
('--novram', 'novram'),
('--cpu', 'cpu'),
])
def test_vram_modes(vram_arg, expected_true_field):
"""Test mutually exclusive VRAM mode arguments."""
config = _parse_test_args([vram_arg])
all_vram_fields = ['gpu_only', 'highvram', 'normalvram', 'lowvram', 'novram', 'cpu']
for field in all_vram_fields:
if field == expected_true_field:
assert getattr(config, field) is True
else:
assert getattr(config, field) is False
def test_preview_method():
"""Test the --preview-method argument."""
config = _parse_test_args(['--preview-method', 'TAESD'])
assert config.preview_method == LatentPreviewMethod.TAESD
def test_logging_level():
"""Test the --logging-level argument."""
config = _parse_test_args(['--logging-level', 'debug'])
assert config.logging_level == 'DEBUG'
def test_multi_user():
"""Test the --multi-user flag."""
config = _parse_test_args(['--multi-user'])
assert config.multi_user is True
def test_disable_xformers():
"""Test the --disable-xformers flag."""
config = _parse_test_args(['--disable-xformers'])
assert config.disable_xformers is True