mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 06:10:50 +08:00
Test and fix cli args issues
This commit is contained in:
parent
6954e3e247
commit
95d8ca6c53
@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from comfy.cli_args_types import FlattenAndAppendAction
|
||||
from importlib.metadata import entry_points
|
||||
from types import ModuleType
|
||||
from typing import Optional
|
||||
@ -36,8 +37,8 @@ def _create_parser() -> EnhancedConfigArgParser:
|
||||
help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
|
||||
parser.add_argument("--max-upload-size", type=float, default=100, help="Set the maximum upload size in MB.")
|
||||
parser.add_argument("--base-directory", type=str, default=None, help="Set the ComfyUI base directory for models, custom_nodes, input, output, temp, and user directories.")
|
||||
parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+',
|
||||
action='append', help="Load one or more extra_model_paths.yaml files.")
|
||||
parser.add_argument("--extra-model-paths-config", type=str, default=[], metavar="PATH", nargs='+',
|
||||
action=FlattenAndAppendAction, help="Load one or more extra_model_paths.yaml files. Can be specified multiple times or as a comma-separated list.")
|
||||
parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory. Overrides --base-directory.")
|
||||
parser.add_argument("--temp-directory", type=str, default=None,
|
||||
help="Set the ComfyUI temp directory (default is in the ComfyUI directory). Overrides --base-directory.")
|
||||
|
||||
@ -73,7 +73,6 @@ class Configuration(dict):
|
||||
temp_directory (Optional[str]): Temporary directory for processing.
|
||||
input_directory (Optional[str]): Directory for input files. When this is a relative path, it will be looked up relative to the cwd (current working directory) and all of the base_paths.
|
||||
auto_launch (bool): Auto-launch UI in the default browser. Defaults to False.
|
||||
disable_auto_launch (bool): Disable auto-launching the browser.
|
||||
cuda_device (Optional[int]): CUDA device ID. None means default device.
|
||||
cuda_malloc (bool): Enable cudaMallocAsync. Defaults to True in applicable setups.
|
||||
disable_cuda_malloc (bool): Disable cudaMallocAsync.
|
||||
@ -188,7 +187,6 @@ class Configuration(dict):
|
||||
self.temp_directory: Optional[str] = None
|
||||
self.input_directory: Optional[str] = None
|
||||
self.auto_launch: bool = False
|
||||
self.disable_auto_launch: bool = False
|
||||
self.cuda_device: Optional[int] = None
|
||||
self.cuda_malloc: bool = True
|
||||
self.disable_cuda_malloc: bool = True
|
||||
@ -352,7 +350,7 @@ class Configuration(dict):
|
||||
|
||||
class EnumAction(argparse.Action):
|
||||
"""
|
||||
Argparse action for handling Enums
|
||||
Argparse action for handling Enums in a case-insensitive manner.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
@ -362,23 +360,33 @@ class EnumAction(argparse.Action):
|
||||
# Ensure an Enum subclass is provided
|
||||
if enum_type is None:
|
||||
raise ValueError("type must be assigned an Enum when using EnumAction")
|
||||
enum_type: Any
|
||||
if not issubclass(enum_type, enum.Enum):
|
||||
raise TypeError("type must be an Enum when using EnumAction")
|
||||
|
||||
# Generate choices from the Enum
|
||||
choices = tuple(e.value for e in enum_type)
|
||||
kwargs.setdefault("choices", choices)
|
||||
kwargs.setdefault("metavar", f"[{','.join(list(choices))}]")
|
||||
|
||||
super(EnumAction, self).__init__(**kwargs)
|
||||
|
||||
self._enum = enum_type
|
||||
|
||||
# Generate choices from the Enum for the help message
|
||||
choices = tuple(e.value for e in enum_type)
|
||||
kwargs.setdefault("metavar", f"[{','.join(list(choices))}]")
|
||||
|
||||
# We handle choices ourselves for case-insensitivity, so remove it before calling super.
|
||||
if "choices" in kwargs:
|
||||
del kwargs["choices"]
|
||||
|
||||
super(EnumAction, self).__init__(**kwargs)
|
||||
self._choices = choices
|
||||
|
||||
def __call__(self, parser, namespace, values, option_string=None):
|
||||
# Convert value back into an Enum
|
||||
value = self._enum(values)
|
||||
setattr(namespace, self.dest, value)
|
||||
# Convert value back into an Enum, case-insensitively
|
||||
value_lower = values.lower()
|
||||
for member in self._enum:
|
||||
if member.value.lower() == value_lower:
|
||||
setattr(namespace, self.dest, member)
|
||||
return
|
||||
|
||||
# If no match found, raise an error
|
||||
msg = f"invalid choice: {values!r} (choose from {', '.join(self._choices)})"
|
||||
raise argparse.ArgumentError(self, msg)
|
||||
|
||||
|
||||
class ParsedArgs(NamedTuple):
|
||||
@ -401,3 +409,25 @@ class EnhancedConfigArgParser(configargparse.ArgParser):
|
||||
|
||||
namespace, unknown_args = super().parse_known_args(args, namespace, **kwargs)
|
||||
return ParsedArgs(namespace, unknown_args, config_files)
|
||||
|
||||
|
||||
class FlattenAndAppendAction(argparse.Action):
|
||||
"""
|
||||
Custom action to handle comma-separated values and multiple invocations
|
||||
of the same argument, flattening them into a single list.
|
||||
"""
|
||||
def __call__(self, parser, namespace, values, option_string=None):
|
||||
items = getattr(namespace, self.dest, None)
|
||||
if items is None:
|
||||
items = []
|
||||
else:
|
||||
# Make a copy if it's not the first time, to avoid modifying the default.
|
||||
items = items[:]
|
||||
|
||||
# 'values' will be a list of strings because of nargs='+'
|
||||
for value in values:
|
||||
# Split comma-separated strings and add them to the list
|
||||
items.extend(item.strip() for item in value.split(','))
|
||||
|
||||
# Set the flattened list back to the namespace.
|
||||
setattr(namespace, self.dest, items)
|
||||
|
||||
126
tests/unit/test_cli_args.py
Normal file
126
tests/unit/test_cli_args.py
Normal file
@ -0,0 +1,126 @@
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
from comfy import cli_args
|
||||
from comfy.cli_args_types import LatentPreviewMethod
|
||||
|
||||
# Helper function to parse args and return the Configuration object
|
||||
def _parse_test_args(args_list):
|
||||
parser = cli_args._create_parser()
|
||||
# The `args_parsing=True` makes it use the provided list.
|
||||
with patch.object(parser, 'parse_known_args_with_config_files', return_value=(parser.parse_known_args(args_list)[0], [], [])):
|
||||
return cli_args._parse_args(parser, args_parsing=True)
|
||||
|
||||
@pytest.mark.parametrize("args, expected", [
|
||||
([], []),
|
||||
(['--extra-model-paths-config', 'a'], ['a']),
|
||||
(['--extra-model-paths-config', 'a', '--extra-model-paths-config', 'b'], ['a', 'b']),
|
||||
(['--extra-model-paths-config', 'a,b'], ['a', 'b']),
|
||||
(['--extra-model-paths-config', 'a,b', '--extra-model-paths-config', 'c'], ['a', 'b', 'c']),
|
||||
(['--extra-model-paths-config', ' a , b ', '--extra-model-paths-config', 'c'], ['a', 'b', 'c']),
|
||||
(['--extra-model-paths-config', 'a,b', 'c'], ['a', 'b', 'c']),
|
||||
])
|
||||
def test_extra_model_paths_config(args, expected):
|
||||
"""Test that extra_model_paths_config is parsed correctly."""
|
||||
config = _parse_test_args(args)
|
||||
assert config.extra_model_paths_config == expected
|
||||
|
||||
def test_default_values():
|
||||
"""Test that default values are set correctly when no args are provided."""
|
||||
config = _parse_test_args([])
|
||||
assert config.listen == "127.0.0.1"
|
||||
assert config.port == 8188
|
||||
assert config.auto_launch is False
|
||||
assert config.extra_model_paths_config == []
|
||||
assert config.preview_method == LatentPreviewMethod.Auto
|
||||
assert config.logging_level == 'INFO'
|
||||
assert config.multi_user is False
|
||||
assert config.disable_xformers is False
|
||||
assert config.gpu_only is False
|
||||
assert config.highvram is False
|
||||
assert config.lowvram is False
|
||||
assert config.normalvram is False
|
||||
assert config.novram is False
|
||||
assert config.cpu is False
|
||||
|
||||
def test_listen_and_port():
|
||||
"""Test --listen and --port arguments."""
|
||||
config = _parse_test_args(['--listen', '0.0.0.0', '--port', '8000'])
|
||||
assert config.listen == '0.0.0.0'
|
||||
assert config.port == 8000
|
||||
|
||||
def test_listen_no_arg():
|
||||
"""Test --listen without an argument."""
|
||||
config = _parse_test_args(['--listen'])
|
||||
assert config.listen == '0.0.0.0,::'
|
||||
|
||||
def test_auto_launch_flags():
|
||||
"""Test --auto-launch and --disable-auto-launch flags."""
|
||||
config_auto = _parse_test_args(['--auto-launch'])
|
||||
assert config_auto.auto_launch is True
|
||||
|
||||
config_disable = _parse_test_args(['--disable-auto-launch'])
|
||||
assert config_disable.auto_launch is False
|
||||
|
||||
# Test that --disable-auto-launch overrides --auto-launch if both are present
|
||||
# The order matters, argparse behavior. Last one wins for store_true/false.
|
||||
config_both_1 = _parse_test_args(['--auto-launch', '--disable-auto-launch'])
|
||||
assert config_both_1.auto_launch is False
|
||||
|
||||
config_both_2 = _parse_test_args(['--disable-auto-launch', '--auto-launch'])
|
||||
assert config_both_2.auto_launch is False
|
||||
|
||||
def test_windows_standalone_build_enables_auto_launch():
|
||||
"""Test that --windows-standalone-build enables auto-launch."""
|
||||
config = _parse_test_args(['--windows-standalone-build'])
|
||||
assert config.windows_standalone_build is True
|
||||
assert config.auto_launch is True
|
||||
|
||||
def test_windows_standalone_build_with_disable_auto_launch():
|
||||
"""Test that --disable-auto-launch overrides --windows-standalone-build's auto-launch."""
|
||||
config = _parse_test_args(['--windows-standalone-build', '--disable-auto-launch'])
|
||||
assert config.windows_standalone_build is True
|
||||
assert config.auto_launch is False
|
||||
|
||||
def test_force_fp16_enables_fp16_unet():
|
||||
"""Test that --force-fp16 enables --fp16-unet."""
|
||||
config = _parse_test_args(['--force-fp16'])
|
||||
assert config.force_fp16 is True
|
||||
assert config.fp16_unet is True
|
||||
|
||||
@pytest.mark.parametrize("vram_arg, expected_true_field", [
|
||||
('--gpu-only', 'gpu_only'),
|
||||
('--highvram', 'highvram'),
|
||||
('--normalvram', 'normalvram'),
|
||||
('--lowvram', 'lowvram'),
|
||||
('--novram', 'novram'),
|
||||
('--cpu', 'cpu'),
|
||||
])
|
||||
def test_vram_modes(vram_arg, expected_true_field):
|
||||
"""Test mutually exclusive VRAM mode arguments."""
|
||||
config = _parse_test_args([vram_arg])
|
||||
all_vram_fields = ['gpu_only', 'highvram', 'normalvram', 'lowvram', 'novram', 'cpu']
|
||||
for field in all_vram_fields:
|
||||
if field == expected_true_field:
|
||||
assert getattr(config, field) is True
|
||||
else:
|
||||
assert getattr(config, field) is False
|
||||
|
||||
def test_preview_method():
|
||||
"""Test the --preview-method argument."""
|
||||
config = _parse_test_args(['--preview-method', 'TAESD'])
|
||||
assert config.preview_method == LatentPreviewMethod.TAESD
|
||||
|
||||
def test_logging_level():
|
||||
"""Test the --logging-level argument."""
|
||||
config = _parse_test_args(['--logging-level', 'debug'])
|
||||
assert config.logging_level == 'DEBUG'
|
||||
|
||||
def test_multi_user():
|
||||
"""Test the --multi-user flag."""
|
||||
config = _parse_test_args(['--multi-user'])
|
||||
assert config.multi_user is True
|
||||
|
||||
def test_disable_xformers():
|
||||
"""Test the --disable-xformers flag."""
|
||||
config = _parse_test_args(['--disable-xformers'])
|
||||
assert config.disable_xformers is True
|
||||
Loading…
Reference in New Issue
Block a user