fix: convert tests to unittest, remove unused import

Amp-Thread-ID: https://ampcode.com/threads/T-019c17ed-fd96-71ed-8055-83a8cd6f8f2b
This commit is contained in:
bymyself 2026-01-31 22:41:33 -08:00
parent 0f259cabdd
commit 6c2223ade9

View File

@ -1,8 +1,8 @@
"""Tests for folder_paths.format_output_filename and get_timestamp functions.""" """Tests for folder_paths.format_output_filename and get_timestamp functions."""
import re
import sys import sys
import os import os
import unittest
# Add the ComfyUI root to the path for imports # Add the ComfyUI root to the path for imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
@ -10,30 +10,30 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspa
import folder_paths import folder_paths
class TestGetTimestamp: class TestGetTimestamp(unittest.TestCase):
"""Tests for get_timestamp function.""" """Tests for get_timestamp function."""
def test_returns_string(self): def test_returns_string(self):
"""Should return a string.""" """Should return a string."""
result = folder_paths.get_timestamp() result = folder_paths.get_timestamp()
assert isinstance(result, str) self.assertIsInstance(result, str)
def test_format_matches_expected_pattern(self): def test_format_matches_expected_pattern(self):
"""Should return format YYYYMMDD-HHMMSS-ffffff.""" """Should return format YYYYMMDD-HHMMSS-ffffff."""
result = folder_paths.get_timestamp() result = folder_paths.get_timestamp()
# Pattern: 8 digits, hyphen, 6 digits, hyphen, 6 digits # Pattern: 8 digits, hyphen, 6 digits, hyphen, 6 digits
pattern = r"^\d{8}-\d{6}-\d{6}$" pattern = r"^\d{8}-\d{6}-\d{6}$"
assert re.match(pattern, result), f"Timestamp '{result}' does not match expected pattern" self.assertRegex(result, pattern)
def test_is_filesystem_safe(self): def test_is_filesystem_safe(self):
"""Should not contain characters that are unsafe for filenames.""" """Should not contain characters that are unsafe for filenames."""
result = folder_paths.get_timestamp() result = folder_paths.get_timestamp()
unsafe_chars = ['/', '\\', ':', '*', '?', '"', '<', '>', '|', ' '] unsafe_chars = ['/', '\\', ':', '*', '?', '"', '<', '>', '|', ' ']
for char in unsafe_chars: for char in unsafe_chars:
assert char not in result, f"Timestamp contains unsafe character: {char}" self.assertNotIn(char, result)
class TestFormatOutputFilename: class TestFormatOutputFilename(unittest.TestCase):
"""Tests for format_output_filename function.""" """Tests for format_output_filename function."""
def test_basic_format(self): def test_basic_format(self):
@ -41,66 +41,43 @@ class TestFormatOutputFilename:
result = folder_paths.format_output_filename("test", 1, "png") result = folder_paths.format_output_filename("test", 1, "png")
# Pattern: test_00001_YYYYMMDD-HHMMSS-ffffff_.png # Pattern: test_00001_YYYYMMDD-HHMMSS-ffffff_.png
pattern = r"^test_00001_\d{8}-\d{6}-\d{6}_\.png$" pattern = r"^test_00001_\d{8}-\d{6}-\d{6}_\.png$"
assert re.match(pattern, result), f"Filename '{result}' does not match expected pattern" self.assertRegex(result, pattern)
def test_counter_padding(self): def test_counter_padding(self):
"""Should pad counter to 5 digits.""" """Should pad counter to 5 digits."""
result = folder_paths.format_output_filename("test", 42, "png") result = folder_paths.format_output_filename("test", 42, "png")
assert "_00042_" in result self.assertIn("_00042_", result)
def test_extension_with_leading_dot(self): def test_extension_with_leading_dot(self):
"""Should handle extension with leading dot.""" """Should handle extension with leading dot."""
result = folder_paths.format_output_filename("test", 1, ".png") result = folder_paths.format_output_filename("test", 1, ".png")
assert result.endswith("_.png") self.assertTrue(result.endswith("_.png"))
assert "..png" not in result self.assertNotIn("..png", result)
def test_extension_without_leading_dot(self): def test_extension_without_leading_dot(self):
"""Should handle extension without leading dot.""" """Should handle extension without leading dot."""
result = folder_paths.format_output_filename("test", 1, "webm") result = folder_paths.format_output_filename("test", 1, "webm")
assert result.endswith("_.webm") self.assertTrue(result.endswith("_.webm"))
def test_batch_num_replacement(self): def test_batch_num_replacement(self):
"""Should replace %batch_num% placeholder.""" """Should replace %batch_num% placeholder."""
result = folder_paths.format_output_filename("test_%batch_num%", 1, "png", batch_num="3") result = folder_paths.format_output_filename("test_%batch_num%", 1, "png", batch_num="3")
assert "test_3_" in result self.assertIn("test_3_", result)
assert "%batch_num%" not in result self.assertNotIn("%batch_num%", result)
def test_custom_timestamp(self): def test_custom_timestamp(self):
"""Should use provided timestamp instead of generating one.""" """Should use provided timestamp instead of generating one."""
custom_ts = "20260101-120000-000000" custom_ts = "20260101-120000-000000"
result = folder_paths.format_output_filename("test", 1, "png", timestamp=custom_ts) result = folder_paths.format_output_filename("test", 1, "png", timestamp=custom_ts)
assert custom_ts in result self.assertIn(custom_ts, result)
def test_different_extensions(self): def test_different_extensions(self):
"""Should work with various extensions.""" """Should work with various extensions."""
extensions = ["png", "webp", "webm", "svg", "glb", "safetensors", "latent"] extensions = ["png", "webp", "webm", "svg", "glb", "safetensors", "latent"]
for ext in extensions: for ext in extensions:
result = folder_paths.format_output_filename("test", 1, ext) result = folder_paths.format_output_filename("test", 1, ext)
assert result.endswith(f"_.{ext}") self.assertTrue(result.endswith(f"_.{ext}"))
if __name__ == "__main__": if __name__ == "__main__":
# Simple test runner unittest.main()
import traceback
test_classes = [TestGetTimestamp, TestFormatOutputFilename]
passed = 0
failed = 0
for test_class in test_classes:
instance = test_class()
for method_name in dir(instance):
if method_name.startswith("test_"):
try:
getattr(instance, method_name)()
print(f"{test_class.__name__}.{method_name}")
passed += 1
except AssertionError as e:
print(f"{test_class.__name__}.{method_name}: {e}")
failed += 1
except Exception as e:
print(f"{test_class.__name__}.{method_name}: {traceback.format_exc()}")
failed += 1
print(f"\n{passed} passed, {failed} failed")
sys.exit(0 if failed == 0 else 1)