mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 06:10:50 +08:00
Merges new frontend
- fixes bfloat16 on cpu to numpy issues - extensions should go into comfy/web/extensions/javascript
This commit is contained in:
parent
940fe6e3af
commit
f1a096b3e1
@ -172,14 +172,12 @@ class PromptServer(ExecutorToClientProgress):
|
||||
|
||||
@routes.get("/extensions")
|
||||
async def get_extensions(request):
|
||||
files = glob.glob(os.path.join(
|
||||
glob.escape(self.web_root), 'extensions/**/*.js'), recursive=True)
|
||||
files = glob.glob(os.path.join(glob.escape(self.web_root), 'extensions/**/*.js'), recursive=True)
|
||||
extensions = list(map(lambda f: "/" + os.path.relpath(f, self.web_root).replace("\\", "/"), files))
|
||||
|
||||
for name, dir in self.nodes.EXTENSION_WEB_DIRS.items():
|
||||
files = glob.glob(os.path.join(glob.escape(dir), '**/*.js'), recursive=True)
|
||||
extensions.extend(list(map(lambda f: "/extensions/" + quote(
|
||||
name) + "/" + os.path.relpath(f, dir).replace("\\", "/"), files)))
|
||||
extensions.extend(list(map(lambda f: "/extensions/" + quote(name) + "/" + os.path.relpath(f, dir).replace("\\", "/"), files)))
|
||||
|
||||
return web.json_response(extensions)
|
||||
|
||||
@ -444,8 +442,9 @@ class PromptServer(ExecutorToClientProgress):
|
||||
info = {}
|
||||
info['input'] = obj_class.INPUT_TYPES()
|
||||
info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()}
|
||||
info['output'] = obj_class.RETURN_TYPES
|
||||
info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
|
||||
_return_types = ["STRING" if isinstance(rt, list) and rt == [] else rt for rt in obj_class.RETURN_TYPES]
|
||||
info['output'] = _return_types
|
||||
info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(_return_types)
|
||||
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
|
||||
info['name'] = node_class
|
||||
info['display_name'] = self.nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in self.nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class
|
||||
|
||||
@ -397,6 +397,8 @@ KNOWN_DIFF_CONTROLNETS: Final[KnownDownloadables] = KnownDownloadables([
|
||||
KNOWN_APPROX_VAES: Final[KnownDownloadables] = KnownDownloadables([
|
||||
HuggingFile("madebyollin/taesd", "taesd_decoder.safetensors"),
|
||||
HuggingFile("madebyollin/taesdxl", "taesdxl_decoder.safetensors"),
|
||||
HuggingFile("madebyollin/taef1", "diffusion_pytorch_model.safetensors", save_with_filename="taef1_decoder.safetensors"),
|
||||
HuggingFile("madebyollin/taesd3", "diffusion_pytorch_model.safetensors", save_with_filename="taesd3_decoder.safetensors"),
|
||||
], folder_name="vae_approx")
|
||||
|
||||
KNOWN_VAES: Final[KnownDownloadables] = KnownDownloadables([
|
||||
|
||||
@ -1,14 +1,17 @@
|
||||
from __future__ import annotations
|
||||
import aiohttp
|
||||
import os
|
||||
import traceback
|
||||
|
||||
import logging
|
||||
from folder_paths import models_dir
|
||||
import os
|
||||
import re
|
||||
from typing import Callable, Any, Optional, Awaitable, Dict
|
||||
from enum import Enum
|
||||
import time
|
||||
import traceback
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Callable, Any, Optional, Awaitable, Dict
|
||||
|
||||
import aiohttp
|
||||
|
||||
from ..cmd.folder_paths import models_dir
|
||||
|
||||
|
||||
class DownloadStatusType(Enum):
|
||||
@ -17,6 +20,7 @@ class DownloadStatusType(Enum):
|
||||
COMPLETED = "completed"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DownloadModelStatus():
|
||||
status: str
|
||||
@ -29,7 +33,7 @@ class DownloadModelStatus():
|
||||
self.progress_percentage = progress_percentage
|
||||
self.message = message
|
||||
self.already_existed = already_existed
|
||||
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"status": self.status,
|
||||
@ -38,9 +42,10 @@ class DownloadModelStatus():
|
||||
"already_existed": self.already_existed
|
||||
}
|
||||
|
||||
|
||||
async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
|
||||
model_name: str,
|
||||
model_url: str,
|
||||
model_name: str,
|
||||
model_url: str,
|
||||
model_sub_directory: str,
|
||||
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
||||
progress_interval: float = 1.0) -> DownloadModelStatus:
|
||||
@ -65,17 +70,17 @@ async def download_model(model_download_request: Callable[[str], Awaitable[aioht
|
||||
"""
|
||||
if not validate_model_subdirectory(model_sub_directory):
|
||||
return DownloadModelStatus(
|
||||
DownloadStatusType.ERROR,
|
||||
DownloadStatusType.ERROR,
|
||||
0,
|
||||
"Invalid model subdirectory",
|
||||
"Invalid model subdirectory",
|
||||
False
|
||||
)
|
||||
|
||||
if not validate_filename(model_name):
|
||||
return DownloadModelStatus(
|
||||
DownloadStatusType.ERROR,
|
||||
DownloadStatusType.ERROR,
|
||||
0,
|
||||
"Invalid model name",
|
||||
"Invalid model name",
|
||||
False
|
||||
)
|
||||
|
||||
@ -101,7 +106,7 @@ async def download_model(model_download_request: Callable[[str], Awaitable[aioht
|
||||
except Exception as e:
|
||||
logging.error(f"Error in downloading model: {e}")
|
||||
return await handle_download_error(e, model_name, progress_callback, relative_path)
|
||||
|
||||
|
||||
|
||||
def create_model_path(model_name: str, model_directory: str, models_base_dir: str) -> tuple[str, str]:
|
||||
full_model_dir = os.path.join(models_base_dir, model_directory)
|
||||
@ -114,13 +119,13 @@ def create_model_path(model_name: str, model_directory: str, models_base_dir: st
|
||||
if os.path.commonprefix([abs_file_path, abs_base_dir]) != abs_base_dir:
|
||||
raise Exception(f"Invalid model directory: {model_directory}/{model_name}")
|
||||
|
||||
|
||||
relative_path = '/'.join([model_directory, model_name])
|
||||
return file_path, relative_path
|
||||
|
||||
async def check_file_exists(file_path: str,
|
||||
model_name: str,
|
||||
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
||||
|
||||
async def check_file_exists(file_path: str,
|
||||
model_name: str,
|
||||
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
||||
relative_path: str) -> Optional[DownloadModelStatus]:
|
||||
if os.path.exists(file_path):
|
||||
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True)
|
||||
@ -129,11 +134,11 @@ async def check_file_exists(file_path: str,
|
||||
return None
|
||||
|
||||
|
||||
async def track_download_progress(response: aiohttp.ClientResponse,
|
||||
file_path: str,
|
||||
model_name: str,
|
||||
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
||||
relative_path: str,
|
||||
async def track_download_progress(response: aiohttp.ClientResponse,
|
||||
file_path: str,
|
||||
model_name: str,
|
||||
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
||||
relative_path: str,
|
||||
interval: float = 1.0) -> DownloadModelStatus:
|
||||
try:
|
||||
total_size = int(response.headers.get('Content-Length', 0))
|
||||
@ -156,12 +161,12 @@ async def track_download_progress(response: aiohttp.ClientResponse,
|
||||
break
|
||||
f.write(chunk)
|
||||
downloaded += len(chunk)
|
||||
|
||||
|
||||
if time.time() - last_update_time >= interval:
|
||||
await update_progress()
|
||||
|
||||
await update_progress()
|
||||
|
||||
|
||||
logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}")
|
||||
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False)
|
||||
await progress_callback(relative_path, status)
|
||||
@ -172,15 +177,17 @@ async def track_download_progress(response: aiohttp.ClientResponse,
|
||||
logging.error(traceback.format_exc())
|
||||
return await handle_download_error(e, model_name, progress_callback, relative_path)
|
||||
|
||||
async def handle_download_error(e: Exception,
|
||||
model_name: str,
|
||||
progress_callback: Callable[[str, DownloadModelStatus], Any],
|
||||
|
||||
async def handle_download_error(e: Exception,
|
||||
model_name: str,
|
||||
progress_callback: Callable[[str, DownloadModelStatus], Any],
|
||||
relative_path: str) -> DownloadModelStatus:
|
||||
error_message = f"Error downloading {model_name}: {str(e)}"
|
||||
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
||||
await progress_callback(relative_path, status)
|
||||
return status
|
||||
|
||||
|
||||
def validate_model_subdirectory(model_subdirectory: str) -> bool:
|
||||
"""
|
||||
Validate that the model subdirectory is safe to install into.
|
||||
@ -204,7 +211,8 @@ def validate_model_subdirectory(model_subdirectory: str) -> bool:
|
||||
|
||||
return True
|
||||
|
||||
def validate_filename(filename: str)-> bool:
|
||||
|
||||
def validate_filename(filename: str) -> bool:
|
||||
"""
|
||||
Validate a filename to ensure it's safe and doesn't contain any path traversal attempts.
|
||||
|
||||
|
||||
@ -24,6 +24,7 @@ from .. import model_management
|
||||
from ..cli_args import args
|
||||
|
||||
from ..cmd import folder_paths, latent_preview
|
||||
from ..component_model.tensor_types import RGBImage
|
||||
from ..execution_context import current_execution_context
|
||||
from ..images import open_image
|
||||
from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS, KNOWN_CONTROLNETS, KNOWN_DIFF_CONTROLNETS, KNOWN_VAES, KNOWN_APPROX_VAES, get_huggingface_repo_list, KNOWN_CLIP_MODELS, KNOWN_UNET_MODELS
|
||||
@ -718,8 +719,8 @@ class VAELoader:
|
||||
sd_["vae_scale"] = torch.tensor(1.5305)
|
||||
sd_["vae_shift"] = torch.tensor(0.0609)
|
||||
elif name == "taef1":
|
||||
sd["vae_scale"] = torch.tensor(0.3611)
|
||||
sd["vae_shift"] = torch.tensor(0.1159)
|
||||
sd_["vae_scale"] = torch.tensor(0.3611)
|
||||
sd_["vae_shift"] = torch.tensor(0.1159)
|
||||
return sd_
|
||||
|
||||
@classmethod
|
||||
@ -1492,8 +1493,9 @@ class SaveImage:
|
||||
filename_prefix += self.prefix_append
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
|
||||
results = list()
|
||||
image: RGBImage
|
||||
for (batch_number, image) in enumerate(images):
|
||||
i = 255. * image.cpu().numpy()
|
||||
i = 255. * image.float().cpu().numpy()
|
||||
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
|
||||
metadata = None
|
||||
if not args.disable_metadata:
|
||||
|
||||
62
comfy/web/extensions/javascript/language.js
vendored
Normal file
62
comfy/web/extensions/javascript/language.js
vendored
Normal file
@ -0,0 +1,62 @@
|
||||
import { app } from "../../scripts/app.js";
|
||||
import { api } from "../../scripts/api.js";
|
||||
import { ComfyWidgets } from "../../scripts/widgets.js";
|
||||
|
||||
const tokenPreviewWidgetName = "__tokens";
|
||||
|
||||
class TokenProgressHandler {
|
||||
constructor() {
|
||||
this.nodeOutputs = {};
|
||||
this.initEventListeners();
|
||||
}
|
||||
|
||||
initEventListeners() {
|
||||
api.addEventListener("executing", ({ detail }) => {
|
||||
if (!detail) {
|
||||
return;
|
||||
}
|
||||
const nodeId = detail;
|
||||
if (!this.nodeOutputs[nodeId]) {
|
||||
this.nodeOutputs[nodeId] = {};
|
||||
}
|
||||
this.nodeOutputs[nodeId].tokens = null;
|
||||
});
|
||||
|
||||
api.addEventListener("progress", ({ detail }) => {
|
||||
const nodeId = detail.node;
|
||||
if (!this.nodeOutputs[nodeId]) {
|
||||
this.nodeOutputs[nodeId] = {};
|
||||
}
|
||||
if (detail.output && detail.output.next_token) {
|
||||
if (!this.nodeOutputs[nodeId].tokens) {
|
||||
this.nodeOutputs[nodeId].tokens = "";
|
||||
}
|
||||
this.nodeOutputs[nodeId].tokens += detail.output.next_token;
|
||||
this.updateTokenWidget(nodeId, this.nodeOutputs[nodeId].tokens);
|
||||
}
|
||||
app.graph.setDirtyCanvas(true, false);
|
||||
});
|
||||
}
|
||||
|
||||
updateTokenWidget(nodeId, tokens) {
|
||||
const node = app.graph.getNodeById(nodeId);
|
||||
if (node && node.widgets) {
|
||||
let widget = node.widgets.find((w) => w.name === tokenPreviewWidgetName);
|
||||
|
||||
if (!widget) {
|
||||
widget = ComfyWidgets["STRING"](node, tokenPreviewWidgetName, ["STRING", { multiline: true }], app).widget;
|
||||
widget.inputEl.readOnly = true;
|
||||
widget.inputEl.style.opacity = 0.7;
|
||||
}
|
||||
widget.value = tokens;
|
||||
app.graph.setDirtyCanvas(true, false);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
app.registerExtension({
|
||||
name: "Comfy.TokenProgress",
|
||||
setup() {
|
||||
this.tokenProgressHandler = new TokenProgressHandler();
|
||||
},
|
||||
});
|
||||
@ -70,7 +70,7 @@ class ImageApplyColorMap(CustomNode):
|
||||
cv2_colormap = getattr(cv2, colormap)
|
||||
colored_image = cv2.applyColorMap(normalized_depth_uint8, cv2_colormap) # pylint: disable=no-member
|
||||
colored_image_rgb = cv2.cvtColor(colored_image, cv2.COLOR_BGR2RGB) # pylint: disable=no-member
|
||||
rgb_tensor = torch.tensor(colored_image_rgb) * 1.0 / 255.0
|
||||
rgb_tensor = torch.tensor(colored_image_rgb).float() * 1.0 / 255.0
|
||||
colored_images.append(rgb_tensor)
|
||||
|
||||
return torch.stack(colored_images),
|
||||
|
||||
@ -101,7 +101,7 @@ class SaveAnimatedWEBP:
|
||||
results = list()
|
||||
pil_images = []
|
||||
for image in images:
|
||||
i = 255. * image.cpu().numpy()
|
||||
i = 255. * image.float().cpu().numpy()
|
||||
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
|
||||
pil_images.append(img)
|
||||
|
||||
@ -162,7 +162,7 @@ class SaveAnimatedPNG:
|
||||
results = list()
|
||||
pil_images = []
|
||||
for image in images:
|
||||
i = 255. * image.cpu().numpy()
|
||||
i = 255. * image.float().cpu().numpy()
|
||||
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
|
||||
pil_images.append(img)
|
||||
|
||||
|
||||
@ -197,7 +197,7 @@ class HashImage(CustomNode):
|
||||
|
||||
def execute(self, images: Sequence[Tensor]) -> ValidatedNodeResult:
|
||||
def process_image(image: Tensor) -> str:
|
||||
image_as_numpy_array: np.ndarray = 255. * image.cpu().numpy()
|
||||
image_as_numpy_array: np.ndarray = 255. * image.float().cpu().numpy()
|
||||
image_as_numpy_array = np.ascontiguousarray(np.clip(image_as_numpy_array, 0, 255).astype(np.uint8))
|
||||
data = image_as_numpy_array.data
|
||||
try:
|
||||
|
||||
@ -31,7 +31,7 @@ class SkeletonizeThin:
|
||||
|
||||
results = []
|
||||
for img in binary:
|
||||
img_np = img.squeeze().cpu().numpy()
|
||||
img_np = img.squeeze().float().cpu().numpy()
|
||||
|
||||
if use_skeletonize:
|
||||
result = skeletonize(img_np)
|
||||
|
||||
@ -62,3 +62,4 @@ networkx>=2.6.3
|
||||
joblib
|
||||
jaxtyping
|
||||
spandrel_extra_arches
|
||||
ml_dtypes
|
||||
@ -1,16 +1,21 @@
|
||||
import pytest
|
||||
import aiohttp
|
||||
from aiohttp import ClientResponse
|
||||
import itertools
|
||||
import os
|
||||
import os
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
from model_filemanager import download_model, validate_model_subdirectory, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
from aiohttp import ClientResponse
|
||||
|
||||
from comfy.model_filemanager import download_model, validate_model_subdirectory, track_download_progress, \
|
||||
create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename
|
||||
|
||||
|
||||
class AsyncIteratorMock:
|
||||
"""
|
||||
A mock class that simulates an asynchronous iterator.
|
||||
This is used to mimic the behavior of aiohttp's content iterator.
|
||||
"""
|
||||
|
||||
def __init__(self, seq):
|
||||
# Convert the input sequence into an iterator
|
||||
self.iter = iter(seq)
|
||||
@ -27,11 +32,13 @@ class AsyncIteratorMock:
|
||||
# This is the asynchronous equivalent of StopIteration
|
||||
raise StopAsyncIteration
|
||||
|
||||
|
||||
class ContentMock:
|
||||
"""
|
||||
A mock class that simulates the content attribute of an aiohttp ClientResponse.
|
||||
This class provides the iter_chunked method which returns an async iterator of chunks.
|
||||
"""
|
||||
|
||||
def __init__(self, chunks):
|
||||
# Store the chunks that will be returned by the iterator
|
||||
self.chunks = chunks
|
||||
@ -41,6 +48,7 @@ class ContentMock:
|
||||
# For simplicity in testing, we ignore chunk_size and just return our predefined chunks
|
||||
return AsyncIteratorMock(self.chunks)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_model_success():
|
||||
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
||||
@ -59,10 +67,10 @@ async def test_download_model_success():
|
||||
mock_open.return_value.__enter__.return_value = mock_file
|
||||
time_values = itertools.count(0, 0.1)
|
||||
|
||||
with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'checkpoints/model.sft')), \
|
||||
patch('model_filemanager.check_file_exists', return_value=None), \
|
||||
patch('builtins.open', mock_open), \
|
||||
patch('time.time', side_effect=time_values): # Simulate time passing
|
||||
with patch('comfy.model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'checkpoints/model.sft')), \
|
||||
patch('comfy.model_filemanager.check_file_exists', return_value=None), \
|
||||
patch('builtins.open', mock_open), \
|
||||
patch('time.time', side_effect=time_values): # Simulate time passing
|
||||
|
||||
result = await download_model(
|
||||
mock_make_request,
|
||||
@ -80,7 +88,7 @@ async def test_download_model_success():
|
||||
|
||||
# Check progress callback calls
|
||||
assert mock_progress_callback.call_count >= 3 # At least start, one progress update, and completion
|
||||
|
||||
|
||||
# Check initial call
|
||||
mock_progress_callback.assert_any_call(
|
||||
'checkpoints/model.sft',
|
||||
@ -101,6 +109,7 @@ async def test_download_model_success():
|
||||
# Verify request was made
|
||||
mock_make_request.assert_called_once_with('http://example.com/model.sft')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_model_url_request_failure():
|
||||
# Mock dependencies
|
||||
@ -110,9 +119,9 @@ async def test_download_model_url_request_failure():
|
||||
mock_progress_callback = AsyncMock()
|
||||
|
||||
# Mock the create_model_path function
|
||||
with patch('model_filemanager.create_model_path', return_value=('/mock/path/model.safetensors', 'mock/path/model.safetensors')):
|
||||
with patch('comfy.model_filemanager.create_model_path', return_value=('/mock/path/model.safetensors', 'mock/path/model.safetensors')):
|
||||
# Mock the check_file_exists function to return None (file doesn't exist)
|
||||
with patch('model_filemanager.check_file_exists', return_value=None):
|
||||
with patch('comfy.model_filemanager.check_file_exists', return_value=None):
|
||||
# Call the function
|
||||
result = await download_model(
|
||||
mock_get,
|
||||
@ -151,13 +160,12 @@ async def test_download_model_url_request_failure():
|
||||
# Verify that the get method was called with the correct URL
|
||||
mock_get.assert_called_once_with('http://example.com/model.safetensors')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_model_invalid_model_subdirectory():
|
||||
|
||||
mock_make_request = AsyncMock()
|
||||
mock_progress_callback = AsyncMock()
|
||||
|
||||
|
||||
result = await download_model(
|
||||
mock_make_request,
|
||||
'model.sft',
|
||||
@ -177,12 +185,12 @@ async def test_download_model_invalid_model_subdirectory():
|
||||
def test_create_model_path(tmp_path, monkeypatch):
|
||||
mock_models_dir = tmp_path / "models"
|
||||
monkeypatch.setattr('folder_paths.models_dir', str(mock_models_dir))
|
||||
|
||||
|
||||
model_name = "test_model.sft"
|
||||
model_directory = "test_dir"
|
||||
|
||||
|
||||
file_path, relative_path = create_model_path(model_name, model_directory, mock_models_dir)
|
||||
|
||||
|
||||
assert file_path == str(mock_models_dir / model_directory / model_name)
|
||||
assert relative_path == f"{model_directory}/{model_name}"
|
||||
assert os.path.exists(os.path.dirname(file_path))
|
||||
@ -192,32 +200,34 @@ def test_create_model_path(tmp_path, monkeypatch):
|
||||
async def test_check_file_exists_when_file_exists(tmp_path):
|
||||
file_path = tmp_path / "existing_model.sft"
|
||||
file_path.touch() # Create an empty file
|
||||
|
||||
|
||||
mock_callback = AsyncMock()
|
||||
|
||||
|
||||
result = await check_file_exists(str(file_path), "existing_model.sft", mock_callback, "test/existing_model.sft")
|
||||
|
||||
|
||||
assert result is not None
|
||||
assert result.status == "completed"
|
||||
assert result.message == "existing_model.sft already exists"
|
||||
assert result.already_existed is True
|
||||
|
||||
|
||||
mock_callback.assert_called_once_with(
|
||||
"test/existing_model.sft",
|
||||
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "existing_model.sft already exists", already_existed=True)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_file_exists_when_file_does_not_exist(tmp_path):
|
||||
file_path = tmp_path / "non_existing_model.sft"
|
||||
|
||||
|
||||
mock_callback = AsyncMock()
|
||||
|
||||
|
||||
result = await check_file_exists(str(file_path), "non_existing_model.sft", mock_callback, "test/non_existing_model.sft")
|
||||
|
||||
|
||||
assert result is None
|
||||
mock_callback.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_track_download_progress_no_content_length():
|
||||
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
||||
@ -240,6 +250,7 @@ async def test_track_download_progress_no_content_length():
|
||||
DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.sft", already_existed=False)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_track_download_progress_interval():
|
||||
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
||||
@ -254,7 +265,7 @@ async def test_track_download_progress_interval():
|
||||
mock_time.side_effect = [i * 0.5 for i in range(30)] # This should be enough for 10 chunks
|
||||
|
||||
with patch('builtins.open', mock_open), \
|
||||
patch('time.time', mock_time):
|
||||
patch('time.time', mock_time):
|
||||
await track_download_progress(
|
||||
mock_response, '/mock/path/model.sft', 'model.sft',
|
||||
mock_callback, 'models/model.sft', interval=1.0
|
||||
@ -279,31 +290,39 @@ async def test_track_download_progress_interval():
|
||||
assert last_call[0][1].status == "completed"
|
||||
assert last_call[0][1].progress_percentage == 100
|
||||
|
||||
|
||||
def test_valid_subdirectory():
|
||||
assert validate_model_subdirectory("valid-model123") is True
|
||||
|
||||
|
||||
def test_subdirectory_too_long():
|
||||
assert validate_model_subdirectory("a" * 51) is False
|
||||
|
||||
|
||||
def test_subdirectory_with_double_dots():
|
||||
assert validate_model_subdirectory("model/../unsafe") is False
|
||||
|
||||
|
||||
def test_subdirectory_with_slash():
|
||||
assert validate_model_subdirectory("model/unsafe") is False
|
||||
|
||||
|
||||
def test_subdirectory_with_special_characters():
|
||||
assert validate_model_subdirectory("model@unsafe") is False
|
||||
|
||||
|
||||
def test_subdirectory_with_underscore_and_dash():
|
||||
assert validate_model_subdirectory("valid_model-name") is True
|
||||
|
||||
|
||||
def test_empty_subdirectory():
|
||||
assert validate_model_subdirectory("") is False
|
||||
|
||||
|
||||
@pytest.mark.parametrize("filename, expected", [
|
||||
("valid_model.safetensors", True),
|
||||
("valid_model.sft", True),
|
||||
("valid model.safetensors", True), # Test with space
|
||||
("valid model.safetensors", True), # Test with space
|
||||
("UPPERCASE_MODEL.SAFETENSORS", True),
|
||||
("model_with.multiple.dots.pt", False),
|
||||
("", False), # Empty string
|
||||
|
||||
Loading…
Reference in New Issue
Block a user