From f1a096b3e1c7199285a0b5e8f15336189f94ef3e Mon Sep 17 00:00:00 2001 From: doctorpangloss <@hiddenswitch.com> Date: Fri, 16 Aug 2024 15:46:11 -0700 Subject: [PATCH] Merges new frontend - fixes bfloat16 on cpu to numpy issues - extensions should go into comfy/web/extensions/javascript --- comfy/cmd/server.py | 11 ++- comfy/model_downloader.py | 2 + comfy/model_filemanager/download_models.py | 66 +++++++++-------- comfy/nodes/base_nodes.py | 8 ++- comfy/web/extensions/javascript/language.js | 62 ++++++++++++++++ .../{core => javascript}/textExtraOutput.js | 0 comfy_extras/nodes/nodes_apply_color_map.py | 2 +- comfy_extras/nodes/nodes_images.py | 4 +- comfy_extras/nodes/nodes_open_api.py | 2 +- comfy_extras/nodes/nodes_skeletonize.py | 2 +- requirements.txt | 1 + .../download_models_test.py | 71 ++++++++++++------- 12 files changed, 162 insertions(+), 69 deletions(-) create mode 100644 comfy/web/extensions/javascript/language.js rename comfy/web/extensions/{core => javascript}/textExtraOutput.js (100%) diff --git a/comfy/cmd/server.py b/comfy/cmd/server.py index fe972a044..c0f352701 100644 --- a/comfy/cmd/server.py +++ b/comfy/cmd/server.py @@ -172,14 +172,12 @@ class PromptServer(ExecutorToClientProgress): @routes.get("/extensions") async def get_extensions(request): - files = glob.glob(os.path.join( - glob.escape(self.web_root), 'extensions/**/*.js'), recursive=True) + files = glob.glob(os.path.join(glob.escape(self.web_root), 'extensions/**/*.js'), recursive=True) extensions = list(map(lambda f: "/" + os.path.relpath(f, self.web_root).replace("\\", "/"), files)) for name, dir in self.nodes.EXTENSION_WEB_DIRS.items(): files = glob.glob(os.path.join(glob.escape(dir), '**/*.js'), recursive=True) - extensions.extend(list(map(lambda f: "/extensions/" + quote( - name) + "/" + os.path.relpath(f, dir).replace("\\", "/"), files))) + extensions.extend(list(map(lambda f: "/extensions/" + quote(name) + "/" + os.path.relpath(f, dir).replace("\\", "/"), files))) return web.json_response(extensions) @@ -444,8 +442,9 @@ class PromptServer(ExecutorToClientProgress): info = {} info['input'] = obj_class.INPUT_TYPES() info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()} - info['output'] = obj_class.RETURN_TYPES - info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES) + _return_types = ["STRING" if isinstance(rt, list) and rt == [] else rt for rt in obj_class.RETURN_TYPES] + info['output'] = _return_types + info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(_return_types) info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output'] info['name'] = node_class info['display_name'] = self.nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in self.nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class diff --git a/comfy/model_downloader.py b/comfy/model_downloader.py index 4f3bb5e07..c20c83246 100644 --- a/comfy/model_downloader.py +++ b/comfy/model_downloader.py @@ -397,6 +397,8 @@ KNOWN_DIFF_CONTROLNETS: Final[KnownDownloadables] = KnownDownloadables([ KNOWN_APPROX_VAES: Final[KnownDownloadables] = KnownDownloadables([ HuggingFile("madebyollin/taesd", "taesd_decoder.safetensors"), HuggingFile("madebyollin/taesdxl", "taesdxl_decoder.safetensors"), + HuggingFile("madebyollin/taef1", "diffusion_pytorch_model.safetensors", save_with_filename="taef1_decoder.safetensors"), + HuggingFile("madebyollin/taesd3", "diffusion_pytorch_model.safetensors", save_with_filename="taesd3_decoder.safetensors"), ], folder_name="vae_approx") KNOWN_VAES: Final[KnownDownloadables] = KnownDownloadables([ diff --git a/comfy/model_filemanager/download_models.py b/comfy/model_filemanager/download_models.py index 712d59328..9be342908 100644 --- a/comfy/model_filemanager/download_models.py +++ b/comfy/model_filemanager/download_models.py @@ -1,14 +1,17 @@ from __future__ import annotations -import aiohttp -import os -import traceback + import logging -from folder_paths import models_dir +import os import re -from typing import Callable, Any, Optional, Awaitable, Dict -from enum import Enum import time +import traceback from dataclasses import dataclass +from enum import Enum +from typing import Callable, Any, Optional, Awaitable, Dict + +import aiohttp + +from ..cmd.folder_paths import models_dir class DownloadStatusType(Enum): @@ -17,6 +20,7 @@ class DownloadStatusType(Enum): COMPLETED = "completed" ERROR = "error" + @dataclass class DownloadModelStatus(): status: str @@ -29,7 +33,7 @@ class DownloadModelStatus(): self.progress_percentage = progress_percentage self.message = message self.already_existed = already_existed - + def to_dict(self) -> Dict[str, Any]: return { "status": self.status, @@ -38,9 +42,10 @@ class DownloadModelStatus(): "already_existed": self.already_existed } + async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]], - model_name: str, - model_url: str, + model_name: str, + model_url: str, model_sub_directory: str, progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]], progress_interval: float = 1.0) -> DownloadModelStatus: @@ -65,17 +70,17 @@ async def download_model(model_download_request: Callable[[str], Awaitable[aioht """ if not validate_model_subdirectory(model_sub_directory): return DownloadModelStatus( - DownloadStatusType.ERROR, + DownloadStatusType.ERROR, 0, - "Invalid model subdirectory", + "Invalid model subdirectory", False ) if not validate_filename(model_name): return DownloadModelStatus( - DownloadStatusType.ERROR, + DownloadStatusType.ERROR, 0, - "Invalid model name", + "Invalid model name", False ) @@ -101,7 +106,7 @@ async def download_model(model_download_request: Callable[[str], Awaitable[aioht except Exception as e: logging.error(f"Error in downloading model: {e}") return await handle_download_error(e, model_name, progress_callback, relative_path) - + def create_model_path(model_name: str, model_directory: str, models_base_dir: str) -> tuple[str, str]: full_model_dir = os.path.join(models_base_dir, model_directory) @@ -114,13 +119,13 @@ def create_model_path(model_name: str, model_directory: str, models_base_dir: st if os.path.commonprefix([abs_file_path, abs_base_dir]) != abs_base_dir: raise Exception(f"Invalid model directory: {model_directory}/{model_name}") - relative_path = '/'.join([model_directory, model_name]) return file_path, relative_path -async def check_file_exists(file_path: str, - model_name: str, - progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]], + +async def check_file_exists(file_path: str, + model_name: str, + progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]], relative_path: str) -> Optional[DownloadModelStatus]: if os.path.exists(file_path): status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True) @@ -129,11 +134,11 @@ async def check_file_exists(file_path: str, return None -async def track_download_progress(response: aiohttp.ClientResponse, - file_path: str, - model_name: str, - progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]], - relative_path: str, +async def track_download_progress(response: aiohttp.ClientResponse, + file_path: str, + model_name: str, + progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]], + relative_path: str, interval: float = 1.0) -> DownloadModelStatus: try: total_size = int(response.headers.get('Content-Length', 0)) @@ -156,12 +161,12 @@ async def track_download_progress(response: aiohttp.ClientResponse, break f.write(chunk) downloaded += len(chunk) - + if time.time() - last_update_time >= interval: await update_progress() await update_progress() - + logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}") status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False) await progress_callback(relative_path, status) @@ -172,15 +177,17 @@ async def track_download_progress(response: aiohttp.ClientResponse, logging.error(traceback.format_exc()) return await handle_download_error(e, model_name, progress_callback, relative_path) -async def handle_download_error(e: Exception, - model_name: str, - progress_callback: Callable[[str, DownloadModelStatus], Any], + +async def handle_download_error(e: Exception, + model_name: str, + progress_callback: Callable[[str, DownloadModelStatus], Any], relative_path: str) -> DownloadModelStatus: error_message = f"Error downloading {model_name}: {str(e)}" status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False) await progress_callback(relative_path, status) return status + def validate_model_subdirectory(model_subdirectory: str) -> bool: """ Validate that the model subdirectory is safe to install into. @@ -204,7 +211,8 @@ def validate_model_subdirectory(model_subdirectory: str) -> bool: return True -def validate_filename(filename: str)-> bool: + +def validate_filename(filename: str) -> bool: """ Validate a filename to ensure it's safe and doesn't contain any path traversal attempts. diff --git a/comfy/nodes/base_nodes.py b/comfy/nodes/base_nodes.py index a6ec70136..0e7a67feb 100644 --- a/comfy/nodes/base_nodes.py +++ b/comfy/nodes/base_nodes.py @@ -24,6 +24,7 @@ from .. import model_management from ..cli_args import args from ..cmd import folder_paths, latent_preview +from ..component_model.tensor_types import RGBImage from ..execution_context import current_execution_context from ..images import open_image from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS, KNOWN_CONTROLNETS, KNOWN_DIFF_CONTROLNETS, KNOWN_VAES, KNOWN_APPROX_VAES, get_huggingface_repo_list, KNOWN_CLIP_MODELS, KNOWN_UNET_MODELS @@ -718,8 +719,8 @@ class VAELoader: sd_["vae_scale"] = torch.tensor(1.5305) sd_["vae_shift"] = torch.tensor(0.0609) elif name == "taef1": - sd["vae_scale"] = torch.tensor(0.3611) - sd["vae_shift"] = torch.tensor(0.1159) + sd_["vae_scale"] = torch.tensor(0.3611) + sd_["vae_shift"] = torch.tensor(0.1159) return sd_ @classmethod @@ -1492,8 +1493,9 @@ class SaveImage: filename_prefix += self.prefix_append full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) results = list() + image: RGBImage for (batch_number, image) in enumerate(images): - i = 255. * image.cpu().numpy() + i = 255. * image.float().cpu().numpy() img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) metadata = None if not args.disable_metadata: diff --git a/comfy/web/extensions/javascript/language.js b/comfy/web/extensions/javascript/language.js new file mode 100644 index 000000000..2848302b6 --- /dev/null +++ b/comfy/web/extensions/javascript/language.js @@ -0,0 +1,62 @@ +import { app } from "../../scripts/app.js"; +import { api } from "../../scripts/api.js"; +import { ComfyWidgets } from "../../scripts/widgets.js"; + +const tokenPreviewWidgetName = "__tokens"; + +class TokenProgressHandler { + constructor() { + this.nodeOutputs = {}; + this.initEventListeners(); + } + + initEventListeners() { + api.addEventListener("executing", ({ detail }) => { + if (!detail) { + return; + } + const nodeId = detail; + if (!this.nodeOutputs[nodeId]) { + this.nodeOutputs[nodeId] = {}; + } + this.nodeOutputs[nodeId].tokens = null; + }); + + api.addEventListener("progress", ({ detail }) => { + const nodeId = detail.node; + if (!this.nodeOutputs[nodeId]) { + this.nodeOutputs[nodeId] = {}; + } + if (detail.output && detail.output.next_token) { + if (!this.nodeOutputs[nodeId].tokens) { + this.nodeOutputs[nodeId].tokens = ""; + } + this.nodeOutputs[nodeId].tokens += detail.output.next_token; + this.updateTokenWidget(nodeId, this.nodeOutputs[nodeId].tokens); + } + app.graph.setDirtyCanvas(true, false); + }); + } + + updateTokenWidget(nodeId, tokens) { + const node = app.graph.getNodeById(nodeId); + if (node && node.widgets) { + let widget = node.widgets.find((w) => w.name === tokenPreviewWidgetName); + + if (!widget) { + widget = ComfyWidgets["STRING"](node, tokenPreviewWidgetName, ["STRING", { multiline: true }], app).widget; + widget.inputEl.readOnly = true; + widget.inputEl.style.opacity = 0.7; + } + widget.value = tokens; + app.graph.setDirtyCanvas(true, false); + } + } +} + +app.registerExtension({ + name: "Comfy.TokenProgress", + setup() { + this.tokenProgressHandler = new TokenProgressHandler(); + }, +}); diff --git a/comfy/web/extensions/core/textExtraOutput.js b/comfy/web/extensions/javascript/textExtraOutput.js similarity index 100% rename from comfy/web/extensions/core/textExtraOutput.js rename to comfy/web/extensions/javascript/textExtraOutput.js diff --git a/comfy_extras/nodes/nodes_apply_color_map.py b/comfy_extras/nodes/nodes_apply_color_map.py index 36627b8cf..7daa519a8 100644 --- a/comfy_extras/nodes/nodes_apply_color_map.py +++ b/comfy_extras/nodes/nodes_apply_color_map.py @@ -70,7 +70,7 @@ class ImageApplyColorMap(CustomNode): cv2_colormap = getattr(cv2, colormap) colored_image = cv2.applyColorMap(normalized_depth_uint8, cv2_colormap) # pylint: disable=no-member colored_image_rgb = cv2.cvtColor(colored_image, cv2.COLOR_BGR2RGB) # pylint: disable=no-member - rgb_tensor = torch.tensor(colored_image_rgb) * 1.0 / 255.0 + rgb_tensor = torch.tensor(colored_image_rgb).float() * 1.0 / 255.0 colored_images.append(rgb_tensor) return torch.stack(colored_images), diff --git a/comfy_extras/nodes/nodes_images.py b/comfy_extras/nodes/nodes_images.py index eade3605d..02b5d5e9c 100644 --- a/comfy_extras/nodes/nodes_images.py +++ b/comfy_extras/nodes/nodes_images.py @@ -101,7 +101,7 @@ class SaveAnimatedWEBP: results = list() pil_images = [] for image in images: - i = 255. * image.cpu().numpy() + i = 255. * image.float().cpu().numpy() img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) pil_images.append(img) @@ -162,7 +162,7 @@ class SaveAnimatedPNG: results = list() pil_images = [] for image in images: - i = 255. * image.cpu().numpy() + i = 255. * image.float().cpu().numpy() img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) pil_images.append(img) diff --git a/comfy_extras/nodes/nodes_open_api.py b/comfy_extras/nodes/nodes_open_api.py index 9b848c253..7d71b8648 100644 --- a/comfy_extras/nodes/nodes_open_api.py +++ b/comfy_extras/nodes/nodes_open_api.py @@ -197,7 +197,7 @@ class HashImage(CustomNode): def execute(self, images: Sequence[Tensor]) -> ValidatedNodeResult: def process_image(image: Tensor) -> str: - image_as_numpy_array: np.ndarray = 255. * image.cpu().numpy() + image_as_numpy_array: np.ndarray = 255. * image.float().cpu().numpy() image_as_numpy_array = np.ascontiguousarray(np.clip(image_as_numpy_array, 0, 255).astype(np.uint8)) data = image_as_numpy_array.data try: diff --git a/comfy_extras/nodes/nodes_skeletonize.py b/comfy_extras/nodes/nodes_skeletonize.py index a99bdff18..785523932 100644 --- a/comfy_extras/nodes/nodes_skeletonize.py +++ b/comfy_extras/nodes/nodes_skeletonize.py @@ -31,7 +31,7 @@ class SkeletonizeThin: results = [] for img in binary: - img_np = img.squeeze().cpu().numpy() + img_np = img.squeeze().float().cpu().numpy() if use_skeletonize: result = skeletonize(img_np) diff --git a/requirements.txt b/requirements.txt index 1e83cd36e..a85bf8912 100644 --- a/requirements.txt +++ b/requirements.txt @@ -62,3 +62,4 @@ networkx>=2.6.3 joblib jaxtyping spandrel_extra_arches +ml_dtypes \ No newline at end of file diff --git a/tests/unit/prompt_server_test/download_models_test.py b/tests/unit/prompt_server_test/download_models_test.py index 66150a468..ee6221fd8 100644 --- a/tests/unit/prompt_server_test/download_models_test.py +++ b/tests/unit/prompt_server_test/download_models_test.py @@ -1,16 +1,21 @@ -import pytest -import aiohttp -from aiohttp import ClientResponse import itertools -import os +import os from unittest.mock import AsyncMock, patch, MagicMock -from model_filemanager import download_model, validate_model_subdirectory, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename + +import aiohttp +import pytest +from aiohttp import ClientResponse + +from comfy.model_filemanager import download_model, validate_model_subdirectory, track_download_progress, \ + create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename + class AsyncIteratorMock: """ A mock class that simulates an asynchronous iterator. This is used to mimic the behavior of aiohttp's content iterator. """ + def __init__(self, seq): # Convert the input sequence into an iterator self.iter = iter(seq) @@ -27,11 +32,13 @@ class AsyncIteratorMock: # This is the asynchronous equivalent of StopIteration raise StopAsyncIteration + class ContentMock: """ A mock class that simulates the content attribute of an aiohttp ClientResponse. This class provides the iter_chunked method which returns an async iterator of chunks. """ + def __init__(self, chunks): # Store the chunks that will be returned by the iterator self.chunks = chunks @@ -41,6 +48,7 @@ class ContentMock: # For simplicity in testing, we ignore chunk_size and just return our predefined chunks return AsyncIteratorMock(self.chunks) + @pytest.mark.asyncio async def test_download_model_success(): mock_response = AsyncMock(spec=aiohttp.ClientResponse) @@ -59,10 +67,10 @@ async def test_download_model_success(): mock_open.return_value.__enter__.return_value = mock_file time_values = itertools.count(0, 0.1) - with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'checkpoints/model.sft')), \ - patch('model_filemanager.check_file_exists', return_value=None), \ - patch('builtins.open', mock_open), \ - patch('time.time', side_effect=time_values): # Simulate time passing + with patch('comfy.model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'checkpoints/model.sft')), \ + patch('comfy.model_filemanager.check_file_exists', return_value=None), \ + patch('builtins.open', mock_open), \ + patch('time.time', side_effect=time_values): # Simulate time passing result = await download_model( mock_make_request, @@ -80,7 +88,7 @@ async def test_download_model_success(): # Check progress callback calls assert mock_progress_callback.call_count >= 3 # At least start, one progress update, and completion - + # Check initial call mock_progress_callback.assert_any_call( 'checkpoints/model.sft', @@ -101,6 +109,7 @@ async def test_download_model_success(): # Verify request was made mock_make_request.assert_called_once_with('http://example.com/model.sft') + @pytest.mark.asyncio async def test_download_model_url_request_failure(): # Mock dependencies @@ -110,9 +119,9 @@ async def test_download_model_url_request_failure(): mock_progress_callback = AsyncMock() # Mock the create_model_path function - with patch('model_filemanager.create_model_path', return_value=('/mock/path/model.safetensors', 'mock/path/model.safetensors')): + with patch('comfy.model_filemanager.create_model_path', return_value=('/mock/path/model.safetensors', 'mock/path/model.safetensors')): # Mock the check_file_exists function to return None (file doesn't exist) - with patch('model_filemanager.check_file_exists', return_value=None): + with patch('comfy.model_filemanager.check_file_exists', return_value=None): # Call the function result = await download_model( mock_get, @@ -151,13 +160,12 @@ async def test_download_model_url_request_failure(): # Verify that the get method was called with the correct URL mock_get.assert_called_once_with('http://example.com/model.safetensors') + @pytest.mark.asyncio async def test_download_model_invalid_model_subdirectory(): - mock_make_request = AsyncMock() mock_progress_callback = AsyncMock() - result = await download_model( mock_make_request, 'model.sft', @@ -177,12 +185,12 @@ async def test_download_model_invalid_model_subdirectory(): def test_create_model_path(tmp_path, monkeypatch): mock_models_dir = tmp_path / "models" monkeypatch.setattr('folder_paths.models_dir', str(mock_models_dir)) - + model_name = "test_model.sft" model_directory = "test_dir" - + file_path, relative_path = create_model_path(model_name, model_directory, mock_models_dir) - + assert file_path == str(mock_models_dir / model_directory / model_name) assert relative_path == f"{model_directory}/{model_name}" assert os.path.exists(os.path.dirname(file_path)) @@ -192,32 +200,34 @@ def test_create_model_path(tmp_path, monkeypatch): async def test_check_file_exists_when_file_exists(tmp_path): file_path = tmp_path / "existing_model.sft" file_path.touch() # Create an empty file - + mock_callback = AsyncMock() - + result = await check_file_exists(str(file_path), "existing_model.sft", mock_callback, "test/existing_model.sft") - + assert result is not None assert result.status == "completed" assert result.message == "existing_model.sft already exists" assert result.already_existed is True - + mock_callback.assert_called_once_with( "test/existing_model.sft", DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "existing_model.sft already exists", already_existed=True) ) + @pytest.mark.asyncio async def test_check_file_exists_when_file_does_not_exist(tmp_path): file_path = tmp_path / "non_existing_model.sft" - + mock_callback = AsyncMock() - + result = await check_file_exists(str(file_path), "non_existing_model.sft", mock_callback, "test/non_existing_model.sft") - + assert result is None mock_callback.assert_not_called() + @pytest.mark.asyncio async def test_track_download_progress_no_content_length(): mock_response = AsyncMock(spec=aiohttp.ClientResponse) @@ -240,6 +250,7 @@ async def test_track_download_progress_no_content_length(): DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.sft", already_existed=False) ) + @pytest.mark.asyncio async def test_track_download_progress_interval(): mock_response = AsyncMock(spec=aiohttp.ClientResponse) @@ -254,7 +265,7 @@ async def test_track_download_progress_interval(): mock_time.side_effect = [i * 0.5 for i in range(30)] # This should be enough for 10 chunks with patch('builtins.open', mock_open), \ - patch('time.time', mock_time): + patch('time.time', mock_time): await track_download_progress( mock_response, '/mock/path/model.sft', 'model.sft', mock_callback, 'models/model.sft', interval=1.0 @@ -279,31 +290,39 @@ async def test_track_download_progress_interval(): assert last_call[0][1].status == "completed" assert last_call[0][1].progress_percentage == 100 + def test_valid_subdirectory(): assert validate_model_subdirectory("valid-model123") is True + def test_subdirectory_too_long(): assert validate_model_subdirectory("a" * 51) is False + def test_subdirectory_with_double_dots(): assert validate_model_subdirectory("model/../unsafe") is False + def test_subdirectory_with_slash(): assert validate_model_subdirectory("model/unsafe") is False + def test_subdirectory_with_special_characters(): assert validate_model_subdirectory("model@unsafe") is False + def test_subdirectory_with_underscore_and_dash(): assert validate_model_subdirectory("valid_model-name") is True + def test_empty_subdirectory(): assert validate_model_subdirectory("") is False + @pytest.mark.parametrize("filename, expected", [ ("valid_model.safetensors", True), ("valid_model.sft", True), - ("valid model.safetensors", True), # Test with space + ("valid model.safetensors", True), # Test with space ("UPPERCASE_MODEL.SAFETENSORS", True), ("model_with.multiple.dots.pt", False), ("", False), # Empty string