Merges new frontend

- fixes bfloat16 on cpu to numpy issues
 - extensions should go into comfy/web/extensions/javascript
This commit is contained in:
doctorpangloss 2024-08-16 15:46:11 -07:00
parent 940fe6e3af
commit f1a096b3e1
12 changed files with 162 additions and 69 deletions

View File

@ -172,14 +172,12 @@ class PromptServer(ExecutorToClientProgress):
@routes.get("/extensions") @routes.get("/extensions")
async def get_extensions(request): async def get_extensions(request):
files = glob.glob(os.path.join( files = glob.glob(os.path.join(glob.escape(self.web_root), 'extensions/**/*.js'), recursive=True)
glob.escape(self.web_root), 'extensions/**/*.js'), recursive=True)
extensions = list(map(lambda f: "/" + os.path.relpath(f, self.web_root).replace("\\", "/"), files)) extensions = list(map(lambda f: "/" + os.path.relpath(f, self.web_root).replace("\\", "/"), files))
for name, dir in self.nodes.EXTENSION_WEB_DIRS.items(): for name, dir in self.nodes.EXTENSION_WEB_DIRS.items():
files = glob.glob(os.path.join(glob.escape(dir), '**/*.js'), recursive=True) files = glob.glob(os.path.join(glob.escape(dir), '**/*.js'), recursive=True)
extensions.extend(list(map(lambda f: "/extensions/" + quote( extensions.extend(list(map(lambda f: "/extensions/" + quote(name) + "/" + os.path.relpath(f, dir).replace("\\", "/"), files)))
name) + "/" + os.path.relpath(f, dir).replace("\\", "/"), files)))
return web.json_response(extensions) return web.json_response(extensions)
@ -444,8 +442,9 @@ class PromptServer(ExecutorToClientProgress):
info = {} info = {}
info['input'] = obj_class.INPUT_TYPES() info['input'] = obj_class.INPUT_TYPES()
info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()} info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()}
info['output'] = obj_class.RETURN_TYPES _return_types = ["STRING" if isinstance(rt, list) and rt == [] else rt for rt in obj_class.RETURN_TYPES]
info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES) info['output'] = _return_types
info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(_return_types)
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output'] info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
info['name'] = node_class info['name'] = node_class
info['display_name'] = self.nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in self.nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class info['display_name'] = self.nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in self.nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class

View File

@ -397,6 +397,8 @@ KNOWN_DIFF_CONTROLNETS: Final[KnownDownloadables] = KnownDownloadables([
KNOWN_APPROX_VAES: Final[KnownDownloadables] = KnownDownloadables([ KNOWN_APPROX_VAES: Final[KnownDownloadables] = KnownDownloadables([
HuggingFile("madebyollin/taesd", "taesd_decoder.safetensors"), HuggingFile("madebyollin/taesd", "taesd_decoder.safetensors"),
HuggingFile("madebyollin/taesdxl", "taesdxl_decoder.safetensors"), HuggingFile("madebyollin/taesdxl", "taesdxl_decoder.safetensors"),
HuggingFile("madebyollin/taef1", "diffusion_pytorch_model.safetensors", save_with_filename="taef1_decoder.safetensors"),
HuggingFile("madebyollin/taesd3", "diffusion_pytorch_model.safetensors", save_with_filename="taesd3_decoder.safetensors"),
], folder_name="vae_approx") ], folder_name="vae_approx")
KNOWN_VAES: Final[KnownDownloadables] = KnownDownloadables([ KNOWN_VAES: Final[KnownDownloadables] = KnownDownloadables([

View File

@ -1,14 +1,17 @@
from __future__ import annotations from __future__ import annotations
import aiohttp
import os
import traceback
import logging import logging
from folder_paths import models_dir import os
import re import re
from typing import Callable, Any, Optional, Awaitable, Dict
from enum import Enum
import time import time
import traceback
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum
from typing import Callable, Any, Optional, Awaitable, Dict
import aiohttp
from ..cmd.folder_paths import models_dir
class DownloadStatusType(Enum): class DownloadStatusType(Enum):
@ -17,6 +20,7 @@ class DownloadStatusType(Enum):
COMPLETED = "completed" COMPLETED = "completed"
ERROR = "error" ERROR = "error"
@dataclass @dataclass
class DownloadModelStatus(): class DownloadModelStatus():
status: str status: str
@ -29,7 +33,7 @@ class DownloadModelStatus():
self.progress_percentage = progress_percentage self.progress_percentage = progress_percentage
self.message = message self.message = message
self.already_existed = already_existed self.already_existed = already_existed
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
return { return {
"status": self.status, "status": self.status,
@ -38,9 +42,10 @@ class DownloadModelStatus():
"already_existed": self.already_existed "already_existed": self.already_existed
} }
async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]], async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
model_name: str, model_name: str,
model_url: str, model_url: str,
model_sub_directory: str, model_sub_directory: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]], progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
progress_interval: float = 1.0) -> DownloadModelStatus: progress_interval: float = 1.0) -> DownloadModelStatus:
@ -65,17 +70,17 @@ async def download_model(model_download_request: Callable[[str], Awaitable[aioht
""" """
if not validate_model_subdirectory(model_sub_directory): if not validate_model_subdirectory(model_sub_directory):
return DownloadModelStatus( return DownloadModelStatus(
DownloadStatusType.ERROR, DownloadStatusType.ERROR,
0, 0,
"Invalid model subdirectory", "Invalid model subdirectory",
False False
) )
if not validate_filename(model_name): if not validate_filename(model_name):
return DownloadModelStatus( return DownloadModelStatus(
DownloadStatusType.ERROR, DownloadStatusType.ERROR,
0, 0,
"Invalid model name", "Invalid model name",
False False
) )
@ -101,7 +106,7 @@ async def download_model(model_download_request: Callable[[str], Awaitable[aioht
except Exception as e: except Exception as e:
logging.error(f"Error in downloading model: {e}") logging.error(f"Error in downloading model: {e}")
return await handle_download_error(e, model_name, progress_callback, relative_path) return await handle_download_error(e, model_name, progress_callback, relative_path)
def create_model_path(model_name: str, model_directory: str, models_base_dir: str) -> tuple[str, str]: def create_model_path(model_name: str, model_directory: str, models_base_dir: str) -> tuple[str, str]:
full_model_dir = os.path.join(models_base_dir, model_directory) full_model_dir = os.path.join(models_base_dir, model_directory)
@ -114,13 +119,13 @@ def create_model_path(model_name: str, model_directory: str, models_base_dir: st
if os.path.commonprefix([abs_file_path, abs_base_dir]) != abs_base_dir: if os.path.commonprefix([abs_file_path, abs_base_dir]) != abs_base_dir:
raise Exception(f"Invalid model directory: {model_directory}/{model_name}") raise Exception(f"Invalid model directory: {model_directory}/{model_name}")
relative_path = '/'.join([model_directory, model_name]) relative_path = '/'.join([model_directory, model_name])
return file_path, relative_path return file_path, relative_path
async def check_file_exists(file_path: str,
model_name: str, async def check_file_exists(file_path: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]], model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
relative_path: str) -> Optional[DownloadModelStatus]: relative_path: str) -> Optional[DownloadModelStatus]:
if os.path.exists(file_path): if os.path.exists(file_path):
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True) status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True)
@ -129,11 +134,11 @@ async def check_file_exists(file_path: str,
return None return None
async def track_download_progress(response: aiohttp.ClientResponse, async def track_download_progress(response: aiohttp.ClientResponse,
file_path: str, file_path: str,
model_name: str, model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]], progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
relative_path: str, relative_path: str,
interval: float = 1.0) -> DownloadModelStatus: interval: float = 1.0) -> DownloadModelStatus:
try: try:
total_size = int(response.headers.get('Content-Length', 0)) total_size = int(response.headers.get('Content-Length', 0))
@ -156,12 +161,12 @@ async def track_download_progress(response: aiohttp.ClientResponse,
break break
f.write(chunk) f.write(chunk)
downloaded += len(chunk) downloaded += len(chunk)
if time.time() - last_update_time >= interval: if time.time() - last_update_time >= interval:
await update_progress() await update_progress()
await update_progress() await update_progress()
logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}") logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}")
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False) status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False)
await progress_callback(relative_path, status) await progress_callback(relative_path, status)
@ -172,15 +177,17 @@ async def track_download_progress(response: aiohttp.ClientResponse,
logging.error(traceback.format_exc()) logging.error(traceback.format_exc())
return await handle_download_error(e, model_name, progress_callback, relative_path) return await handle_download_error(e, model_name, progress_callback, relative_path)
async def handle_download_error(e: Exception,
model_name: str, async def handle_download_error(e: Exception,
progress_callback: Callable[[str, DownloadModelStatus], Any], model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Any],
relative_path: str) -> DownloadModelStatus: relative_path: str) -> DownloadModelStatus:
error_message = f"Error downloading {model_name}: {str(e)}" error_message = f"Error downloading {model_name}: {str(e)}"
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False) status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
await progress_callback(relative_path, status) await progress_callback(relative_path, status)
return status return status
def validate_model_subdirectory(model_subdirectory: str) -> bool: def validate_model_subdirectory(model_subdirectory: str) -> bool:
""" """
Validate that the model subdirectory is safe to install into. Validate that the model subdirectory is safe to install into.
@ -204,7 +211,8 @@ def validate_model_subdirectory(model_subdirectory: str) -> bool:
return True return True
def validate_filename(filename: str)-> bool:
def validate_filename(filename: str) -> bool:
""" """
Validate a filename to ensure it's safe and doesn't contain any path traversal attempts. Validate a filename to ensure it's safe and doesn't contain any path traversal attempts.

View File

@ -24,6 +24,7 @@ from .. import model_management
from ..cli_args import args from ..cli_args import args
from ..cmd import folder_paths, latent_preview from ..cmd import folder_paths, latent_preview
from ..component_model.tensor_types import RGBImage
from ..execution_context import current_execution_context from ..execution_context import current_execution_context
from ..images import open_image from ..images import open_image
from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS, KNOWN_CONTROLNETS, KNOWN_DIFF_CONTROLNETS, KNOWN_VAES, KNOWN_APPROX_VAES, get_huggingface_repo_list, KNOWN_CLIP_MODELS, KNOWN_UNET_MODELS from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS, KNOWN_CONTROLNETS, KNOWN_DIFF_CONTROLNETS, KNOWN_VAES, KNOWN_APPROX_VAES, get_huggingface_repo_list, KNOWN_CLIP_MODELS, KNOWN_UNET_MODELS
@ -718,8 +719,8 @@ class VAELoader:
sd_["vae_scale"] = torch.tensor(1.5305) sd_["vae_scale"] = torch.tensor(1.5305)
sd_["vae_shift"] = torch.tensor(0.0609) sd_["vae_shift"] = torch.tensor(0.0609)
elif name == "taef1": elif name == "taef1":
sd["vae_scale"] = torch.tensor(0.3611) sd_["vae_scale"] = torch.tensor(0.3611)
sd["vae_shift"] = torch.tensor(0.1159) sd_["vae_shift"] = torch.tensor(0.1159)
return sd_ return sd_
@classmethod @classmethod
@ -1492,8 +1493,9 @@ class SaveImage:
filename_prefix += self.prefix_append filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
results = list() results = list()
image: RGBImage
for (batch_number, image) in enumerate(images): for (batch_number, image) in enumerate(images):
i = 255. * image.cpu().numpy() i = 255. * image.float().cpu().numpy()
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
metadata = None metadata = None
if not args.disable_metadata: if not args.disable_metadata:

View File

@ -0,0 +1,62 @@
import { app } from "../../scripts/app.js";
import { api } from "../../scripts/api.js";
import { ComfyWidgets } from "../../scripts/widgets.js";
const tokenPreviewWidgetName = "__tokens";
class TokenProgressHandler {
constructor() {
this.nodeOutputs = {};
this.initEventListeners();
}
initEventListeners() {
api.addEventListener("executing", ({ detail }) => {
if (!detail) {
return;
}
const nodeId = detail;
if (!this.nodeOutputs[nodeId]) {
this.nodeOutputs[nodeId] = {};
}
this.nodeOutputs[nodeId].tokens = null;
});
api.addEventListener("progress", ({ detail }) => {
const nodeId = detail.node;
if (!this.nodeOutputs[nodeId]) {
this.nodeOutputs[nodeId] = {};
}
if (detail.output && detail.output.next_token) {
if (!this.nodeOutputs[nodeId].tokens) {
this.nodeOutputs[nodeId].tokens = "";
}
this.nodeOutputs[nodeId].tokens += detail.output.next_token;
this.updateTokenWidget(nodeId, this.nodeOutputs[nodeId].tokens);
}
app.graph.setDirtyCanvas(true, false);
});
}
updateTokenWidget(nodeId, tokens) {
const node = app.graph.getNodeById(nodeId);
if (node && node.widgets) {
let widget = node.widgets.find((w) => w.name === tokenPreviewWidgetName);
if (!widget) {
widget = ComfyWidgets["STRING"](node, tokenPreviewWidgetName, ["STRING", { multiline: true }], app).widget;
widget.inputEl.readOnly = true;
widget.inputEl.style.opacity = 0.7;
}
widget.value = tokens;
app.graph.setDirtyCanvas(true, false);
}
}
}
app.registerExtension({
name: "Comfy.TokenProgress",
setup() {
this.tokenProgressHandler = new TokenProgressHandler();
},
});

View File

@ -70,7 +70,7 @@ class ImageApplyColorMap(CustomNode):
cv2_colormap = getattr(cv2, colormap) cv2_colormap = getattr(cv2, colormap)
colored_image = cv2.applyColorMap(normalized_depth_uint8, cv2_colormap) # pylint: disable=no-member colored_image = cv2.applyColorMap(normalized_depth_uint8, cv2_colormap) # pylint: disable=no-member
colored_image_rgb = cv2.cvtColor(colored_image, cv2.COLOR_BGR2RGB) # pylint: disable=no-member colored_image_rgb = cv2.cvtColor(colored_image, cv2.COLOR_BGR2RGB) # pylint: disable=no-member
rgb_tensor = torch.tensor(colored_image_rgb) * 1.0 / 255.0 rgb_tensor = torch.tensor(colored_image_rgb).float() * 1.0 / 255.0
colored_images.append(rgb_tensor) colored_images.append(rgb_tensor)
return torch.stack(colored_images), return torch.stack(colored_images),

View File

@ -101,7 +101,7 @@ class SaveAnimatedWEBP:
results = list() results = list()
pil_images = [] pil_images = []
for image in images: for image in images:
i = 255. * image.cpu().numpy() i = 255. * image.float().cpu().numpy()
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
pil_images.append(img) pil_images.append(img)
@ -162,7 +162,7 @@ class SaveAnimatedPNG:
results = list() results = list()
pil_images = [] pil_images = []
for image in images: for image in images:
i = 255. * image.cpu().numpy() i = 255. * image.float().cpu().numpy()
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
pil_images.append(img) pil_images.append(img)

View File

@ -197,7 +197,7 @@ class HashImage(CustomNode):
def execute(self, images: Sequence[Tensor]) -> ValidatedNodeResult: def execute(self, images: Sequence[Tensor]) -> ValidatedNodeResult:
def process_image(image: Tensor) -> str: def process_image(image: Tensor) -> str:
image_as_numpy_array: np.ndarray = 255. * image.cpu().numpy() image_as_numpy_array: np.ndarray = 255. * image.float().cpu().numpy()
image_as_numpy_array = np.ascontiguousarray(np.clip(image_as_numpy_array, 0, 255).astype(np.uint8)) image_as_numpy_array = np.ascontiguousarray(np.clip(image_as_numpy_array, 0, 255).astype(np.uint8))
data = image_as_numpy_array.data data = image_as_numpy_array.data
try: try:

View File

@ -31,7 +31,7 @@ class SkeletonizeThin:
results = [] results = []
for img in binary: for img in binary:
img_np = img.squeeze().cpu().numpy() img_np = img.squeeze().float().cpu().numpy()
if use_skeletonize: if use_skeletonize:
result = skeletonize(img_np) result = skeletonize(img_np)

View File

@ -62,3 +62,4 @@ networkx>=2.6.3
joblib joblib
jaxtyping jaxtyping
spandrel_extra_arches spandrel_extra_arches
ml_dtypes

View File

@ -1,16 +1,21 @@
import pytest
import aiohttp
from aiohttp import ClientResponse
import itertools import itertools
import os import os
from unittest.mock import AsyncMock, patch, MagicMock from unittest.mock import AsyncMock, patch, MagicMock
from model_filemanager import download_model, validate_model_subdirectory, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename
import aiohttp
import pytest
from aiohttp import ClientResponse
from comfy.model_filemanager import download_model, validate_model_subdirectory, track_download_progress, \
create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename
class AsyncIteratorMock: class AsyncIteratorMock:
""" """
A mock class that simulates an asynchronous iterator. A mock class that simulates an asynchronous iterator.
This is used to mimic the behavior of aiohttp's content iterator. This is used to mimic the behavior of aiohttp's content iterator.
""" """
def __init__(self, seq): def __init__(self, seq):
# Convert the input sequence into an iterator # Convert the input sequence into an iterator
self.iter = iter(seq) self.iter = iter(seq)
@ -27,11 +32,13 @@ class AsyncIteratorMock:
# This is the asynchronous equivalent of StopIteration # This is the asynchronous equivalent of StopIteration
raise StopAsyncIteration raise StopAsyncIteration
class ContentMock: class ContentMock:
""" """
A mock class that simulates the content attribute of an aiohttp ClientResponse. A mock class that simulates the content attribute of an aiohttp ClientResponse.
This class provides the iter_chunked method which returns an async iterator of chunks. This class provides the iter_chunked method which returns an async iterator of chunks.
""" """
def __init__(self, chunks): def __init__(self, chunks):
# Store the chunks that will be returned by the iterator # Store the chunks that will be returned by the iterator
self.chunks = chunks self.chunks = chunks
@ -41,6 +48,7 @@ class ContentMock:
# For simplicity in testing, we ignore chunk_size and just return our predefined chunks # For simplicity in testing, we ignore chunk_size and just return our predefined chunks
return AsyncIteratorMock(self.chunks) return AsyncIteratorMock(self.chunks)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_download_model_success(): async def test_download_model_success():
mock_response = AsyncMock(spec=aiohttp.ClientResponse) mock_response = AsyncMock(spec=aiohttp.ClientResponse)
@ -59,10 +67,10 @@ async def test_download_model_success():
mock_open.return_value.__enter__.return_value = mock_file mock_open.return_value.__enter__.return_value = mock_file
time_values = itertools.count(0, 0.1) time_values = itertools.count(0, 0.1)
with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'checkpoints/model.sft')), \ with patch('comfy.model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'checkpoints/model.sft')), \
patch('model_filemanager.check_file_exists', return_value=None), \ patch('comfy.model_filemanager.check_file_exists', return_value=None), \
patch('builtins.open', mock_open), \ patch('builtins.open', mock_open), \
patch('time.time', side_effect=time_values): # Simulate time passing patch('time.time', side_effect=time_values): # Simulate time passing
result = await download_model( result = await download_model(
mock_make_request, mock_make_request,
@ -80,7 +88,7 @@ async def test_download_model_success():
# Check progress callback calls # Check progress callback calls
assert mock_progress_callback.call_count >= 3 # At least start, one progress update, and completion assert mock_progress_callback.call_count >= 3 # At least start, one progress update, and completion
# Check initial call # Check initial call
mock_progress_callback.assert_any_call( mock_progress_callback.assert_any_call(
'checkpoints/model.sft', 'checkpoints/model.sft',
@ -101,6 +109,7 @@ async def test_download_model_success():
# Verify request was made # Verify request was made
mock_make_request.assert_called_once_with('http://example.com/model.sft') mock_make_request.assert_called_once_with('http://example.com/model.sft')
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_download_model_url_request_failure(): async def test_download_model_url_request_failure():
# Mock dependencies # Mock dependencies
@ -110,9 +119,9 @@ async def test_download_model_url_request_failure():
mock_progress_callback = AsyncMock() mock_progress_callback = AsyncMock()
# Mock the create_model_path function # Mock the create_model_path function
with patch('model_filemanager.create_model_path', return_value=('/mock/path/model.safetensors', 'mock/path/model.safetensors')): with patch('comfy.model_filemanager.create_model_path', return_value=('/mock/path/model.safetensors', 'mock/path/model.safetensors')):
# Mock the check_file_exists function to return None (file doesn't exist) # Mock the check_file_exists function to return None (file doesn't exist)
with patch('model_filemanager.check_file_exists', return_value=None): with patch('comfy.model_filemanager.check_file_exists', return_value=None):
# Call the function # Call the function
result = await download_model( result = await download_model(
mock_get, mock_get,
@ -151,13 +160,12 @@ async def test_download_model_url_request_failure():
# Verify that the get method was called with the correct URL # Verify that the get method was called with the correct URL
mock_get.assert_called_once_with('http://example.com/model.safetensors') mock_get.assert_called_once_with('http://example.com/model.safetensors')
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_download_model_invalid_model_subdirectory(): async def test_download_model_invalid_model_subdirectory():
mock_make_request = AsyncMock() mock_make_request = AsyncMock()
mock_progress_callback = AsyncMock() mock_progress_callback = AsyncMock()
result = await download_model( result = await download_model(
mock_make_request, mock_make_request,
'model.sft', 'model.sft',
@ -177,12 +185,12 @@ async def test_download_model_invalid_model_subdirectory():
def test_create_model_path(tmp_path, monkeypatch): def test_create_model_path(tmp_path, monkeypatch):
mock_models_dir = tmp_path / "models" mock_models_dir = tmp_path / "models"
monkeypatch.setattr('folder_paths.models_dir', str(mock_models_dir)) monkeypatch.setattr('folder_paths.models_dir', str(mock_models_dir))
model_name = "test_model.sft" model_name = "test_model.sft"
model_directory = "test_dir" model_directory = "test_dir"
file_path, relative_path = create_model_path(model_name, model_directory, mock_models_dir) file_path, relative_path = create_model_path(model_name, model_directory, mock_models_dir)
assert file_path == str(mock_models_dir / model_directory / model_name) assert file_path == str(mock_models_dir / model_directory / model_name)
assert relative_path == f"{model_directory}/{model_name}" assert relative_path == f"{model_directory}/{model_name}"
assert os.path.exists(os.path.dirname(file_path)) assert os.path.exists(os.path.dirname(file_path))
@ -192,32 +200,34 @@ def test_create_model_path(tmp_path, monkeypatch):
async def test_check_file_exists_when_file_exists(tmp_path): async def test_check_file_exists_when_file_exists(tmp_path):
file_path = tmp_path / "existing_model.sft" file_path = tmp_path / "existing_model.sft"
file_path.touch() # Create an empty file file_path.touch() # Create an empty file
mock_callback = AsyncMock() mock_callback = AsyncMock()
result = await check_file_exists(str(file_path), "existing_model.sft", mock_callback, "test/existing_model.sft") result = await check_file_exists(str(file_path), "existing_model.sft", mock_callback, "test/existing_model.sft")
assert result is not None assert result is not None
assert result.status == "completed" assert result.status == "completed"
assert result.message == "existing_model.sft already exists" assert result.message == "existing_model.sft already exists"
assert result.already_existed is True assert result.already_existed is True
mock_callback.assert_called_once_with( mock_callback.assert_called_once_with(
"test/existing_model.sft", "test/existing_model.sft",
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "existing_model.sft already exists", already_existed=True) DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "existing_model.sft already exists", already_existed=True)
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_check_file_exists_when_file_does_not_exist(tmp_path): async def test_check_file_exists_when_file_does_not_exist(tmp_path):
file_path = tmp_path / "non_existing_model.sft" file_path = tmp_path / "non_existing_model.sft"
mock_callback = AsyncMock() mock_callback = AsyncMock()
result = await check_file_exists(str(file_path), "non_existing_model.sft", mock_callback, "test/non_existing_model.sft") result = await check_file_exists(str(file_path), "non_existing_model.sft", mock_callback, "test/non_existing_model.sft")
assert result is None assert result is None
mock_callback.assert_not_called() mock_callback.assert_not_called()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_track_download_progress_no_content_length(): async def test_track_download_progress_no_content_length():
mock_response = AsyncMock(spec=aiohttp.ClientResponse) mock_response = AsyncMock(spec=aiohttp.ClientResponse)
@ -240,6 +250,7 @@ async def test_track_download_progress_no_content_length():
DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.sft", already_existed=False) DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.sft", already_existed=False)
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_track_download_progress_interval(): async def test_track_download_progress_interval():
mock_response = AsyncMock(spec=aiohttp.ClientResponse) mock_response = AsyncMock(spec=aiohttp.ClientResponse)
@ -254,7 +265,7 @@ async def test_track_download_progress_interval():
mock_time.side_effect = [i * 0.5 for i in range(30)] # This should be enough for 10 chunks mock_time.side_effect = [i * 0.5 for i in range(30)] # This should be enough for 10 chunks
with patch('builtins.open', mock_open), \ with patch('builtins.open', mock_open), \
patch('time.time', mock_time): patch('time.time', mock_time):
await track_download_progress( await track_download_progress(
mock_response, '/mock/path/model.sft', 'model.sft', mock_response, '/mock/path/model.sft', 'model.sft',
mock_callback, 'models/model.sft', interval=1.0 mock_callback, 'models/model.sft', interval=1.0
@ -279,31 +290,39 @@ async def test_track_download_progress_interval():
assert last_call[0][1].status == "completed" assert last_call[0][1].status == "completed"
assert last_call[0][1].progress_percentage == 100 assert last_call[0][1].progress_percentage == 100
def test_valid_subdirectory(): def test_valid_subdirectory():
assert validate_model_subdirectory("valid-model123") is True assert validate_model_subdirectory("valid-model123") is True
def test_subdirectory_too_long(): def test_subdirectory_too_long():
assert validate_model_subdirectory("a" * 51) is False assert validate_model_subdirectory("a" * 51) is False
def test_subdirectory_with_double_dots(): def test_subdirectory_with_double_dots():
assert validate_model_subdirectory("model/../unsafe") is False assert validate_model_subdirectory("model/../unsafe") is False
def test_subdirectory_with_slash(): def test_subdirectory_with_slash():
assert validate_model_subdirectory("model/unsafe") is False assert validate_model_subdirectory("model/unsafe") is False
def test_subdirectory_with_special_characters(): def test_subdirectory_with_special_characters():
assert validate_model_subdirectory("model@unsafe") is False assert validate_model_subdirectory("model@unsafe") is False
def test_subdirectory_with_underscore_and_dash(): def test_subdirectory_with_underscore_and_dash():
assert validate_model_subdirectory("valid_model-name") is True assert validate_model_subdirectory("valid_model-name") is True
def test_empty_subdirectory(): def test_empty_subdirectory():
assert validate_model_subdirectory("") is False assert validate_model_subdirectory("") is False
@pytest.mark.parametrize("filename, expected", [ @pytest.mark.parametrize("filename, expected", [
("valid_model.safetensors", True), ("valid_model.safetensors", True),
("valid_model.sft", True), ("valid_model.sft", True),
("valid model.safetensors", True), # Test with space ("valid model.safetensors", True), # Test with space
("UPPERCASE_MODEL.SAFETENSORS", True), ("UPPERCASE_MODEL.SAFETENSORS", True),
("model_with.multiple.dots.pt", False), ("model_with.multiple.dots.pt", False),
("", False), # Empty string ("", False), # Empty string