mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 06:40:48 +08:00
Merges new frontend
- fixes bfloat16 on cpu to numpy issues - extensions should go into comfy/web/extensions/javascript
This commit is contained in:
parent
940fe6e3af
commit
f1a096b3e1
@ -172,14 +172,12 @@ class PromptServer(ExecutorToClientProgress):
|
|||||||
|
|
||||||
@routes.get("/extensions")
|
@routes.get("/extensions")
|
||||||
async def get_extensions(request):
|
async def get_extensions(request):
|
||||||
files = glob.glob(os.path.join(
|
files = glob.glob(os.path.join(glob.escape(self.web_root), 'extensions/**/*.js'), recursive=True)
|
||||||
glob.escape(self.web_root), 'extensions/**/*.js'), recursive=True)
|
|
||||||
extensions = list(map(lambda f: "/" + os.path.relpath(f, self.web_root).replace("\\", "/"), files))
|
extensions = list(map(lambda f: "/" + os.path.relpath(f, self.web_root).replace("\\", "/"), files))
|
||||||
|
|
||||||
for name, dir in self.nodes.EXTENSION_WEB_DIRS.items():
|
for name, dir in self.nodes.EXTENSION_WEB_DIRS.items():
|
||||||
files = glob.glob(os.path.join(glob.escape(dir), '**/*.js'), recursive=True)
|
files = glob.glob(os.path.join(glob.escape(dir), '**/*.js'), recursive=True)
|
||||||
extensions.extend(list(map(lambda f: "/extensions/" + quote(
|
extensions.extend(list(map(lambda f: "/extensions/" + quote(name) + "/" + os.path.relpath(f, dir).replace("\\", "/"), files)))
|
||||||
name) + "/" + os.path.relpath(f, dir).replace("\\", "/"), files)))
|
|
||||||
|
|
||||||
return web.json_response(extensions)
|
return web.json_response(extensions)
|
||||||
|
|
||||||
@ -444,8 +442,9 @@ class PromptServer(ExecutorToClientProgress):
|
|||||||
info = {}
|
info = {}
|
||||||
info['input'] = obj_class.INPUT_TYPES()
|
info['input'] = obj_class.INPUT_TYPES()
|
||||||
info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()}
|
info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()}
|
||||||
info['output'] = obj_class.RETURN_TYPES
|
_return_types = ["STRING" if isinstance(rt, list) and rt == [] else rt for rt in obj_class.RETURN_TYPES]
|
||||||
info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
|
info['output'] = _return_types
|
||||||
|
info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(_return_types)
|
||||||
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
|
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
|
||||||
info['name'] = node_class
|
info['name'] = node_class
|
||||||
info['display_name'] = self.nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in self.nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class
|
info['display_name'] = self.nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in self.nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class
|
||||||
|
|||||||
@ -397,6 +397,8 @@ KNOWN_DIFF_CONTROLNETS: Final[KnownDownloadables] = KnownDownloadables([
|
|||||||
KNOWN_APPROX_VAES: Final[KnownDownloadables] = KnownDownloadables([
|
KNOWN_APPROX_VAES: Final[KnownDownloadables] = KnownDownloadables([
|
||||||
HuggingFile("madebyollin/taesd", "taesd_decoder.safetensors"),
|
HuggingFile("madebyollin/taesd", "taesd_decoder.safetensors"),
|
||||||
HuggingFile("madebyollin/taesdxl", "taesdxl_decoder.safetensors"),
|
HuggingFile("madebyollin/taesdxl", "taesdxl_decoder.safetensors"),
|
||||||
|
HuggingFile("madebyollin/taef1", "diffusion_pytorch_model.safetensors", save_with_filename="taef1_decoder.safetensors"),
|
||||||
|
HuggingFile("madebyollin/taesd3", "diffusion_pytorch_model.safetensors", save_with_filename="taesd3_decoder.safetensors"),
|
||||||
], folder_name="vae_approx")
|
], folder_name="vae_approx")
|
||||||
|
|
||||||
KNOWN_VAES: Final[KnownDownloadables] = KnownDownloadables([
|
KNOWN_VAES: Final[KnownDownloadables] = KnownDownloadables([
|
||||||
|
|||||||
@ -1,14 +1,17 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
import aiohttp
|
|
||||||
import os
|
|
||||||
import traceback
|
|
||||||
import logging
|
import logging
|
||||||
from folder_paths import models_dir
|
import os
|
||||||
import re
|
import re
|
||||||
from typing import Callable, Any, Optional, Awaitable, Dict
|
|
||||||
from enum import Enum
|
|
||||||
import time
|
import time
|
||||||
|
import traceback
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Callable, Any, Optional, Awaitable, Dict
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
from ..cmd.folder_paths import models_dir
|
||||||
|
|
||||||
|
|
||||||
class DownloadStatusType(Enum):
|
class DownloadStatusType(Enum):
|
||||||
@ -17,6 +20,7 @@ class DownloadStatusType(Enum):
|
|||||||
COMPLETED = "completed"
|
COMPLETED = "completed"
|
||||||
ERROR = "error"
|
ERROR = "error"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DownloadModelStatus():
|
class DownloadModelStatus():
|
||||||
status: str
|
status: str
|
||||||
@ -29,7 +33,7 @@ class DownloadModelStatus():
|
|||||||
self.progress_percentage = progress_percentage
|
self.progress_percentage = progress_percentage
|
||||||
self.message = message
|
self.message = message
|
||||||
self.already_existed = already_existed
|
self.already_existed = already_existed
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"status": self.status,
|
"status": self.status,
|
||||||
@ -38,9 +42,10 @@ class DownloadModelStatus():
|
|||||||
"already_existed": self.already_existed
|
"already_existed": self.already_existed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
|
async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
|
||||||
model_name: str,
|
model_name: str,
|
||||||
model_url: str,
|
model_url: str,
|
||||||
model_sub_directory: str,
|
model_sub_directory: str,
|
||||||
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
||||||
progress_interval: float = 1.0) -> DownloadModelStatus:
|
progress_interval: float = 1.0) -> DownloadModelStatus:
|
||||||
@ -65,17 +70,17 @@ async def download_model(model_download_request: Callable[[str], Awaitable[aioht
|
|||||||
"""
|
"""
|
||||||
if not validate_model_subdirectory(model_sub_directory):
|
if not validate_model_subdirectory(model_sub_directory):
|
||||||
return DownloadModelStatus(
|
return DownloadModelStatus(
|
||||||
DownloadStatusType.ERROR,
|
DownloadStatusType.ERROR,
|
||||||
0,
|
0,
|
||||||
"Invalid model subdirectory",
|
"Invalid model subdirectory",
|
||||||
False
|
False
|
||||||
)
|
)
|
||||||
|
|
||||||
if not validate_filename(model_name):
|
if not validate_filename(model_name):
|
||||||
return DownloadModelStatus(
|
return DownloadModelStatus(
|
||||||
DownloadStatusType.ERROR,
|
DownloadStatusType.ERROR,
|
||||||
0,
|
0,
|
||||||
"Invalid model name",
|
"Invalid model name",
|
||||||
False
|
False
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -101,7 +106,7 @@ async def download_model(model_download_request: Callable[[str], Awaitable[aioht
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Error in downloading model: {e}")
|
logging.error(f"Error in downloading model: {e}")
|
||||||
return await handle_download_error(e, model_name, progress_callback, relative_path)
|
return await handle_download_error(e, model_name, progress_callback, relative_path)
|
||||||
|
|
||||||
|
|
||||||
def create_model_path(model_name: str, model_directory: str, models_base_dir: str) -> tuple[str, str]:
|
def create_model_path(model_name: str, model_directory: str, models_base_dir: str) -> tuple[str, str]:
|
||||||
full_model_dir = os.path.join(models_base_dir, model_directory)
|
full_model_dir = os.path.join(models_base_dir, model_directory)
|
||||||
@ -114,13 +119,13 @@ def create_model_path(model_name: str, model_directory: str, models_base_dir: st
|
|||||||
if os.path.commonprefix([abs_file_path, abs_base_dir]) != abs_base_dir:
|
if os.path.commonprefix([abs_file_path, abs_base_dir]) != abs_base_dir:
|
||||||
raise Exception(f"Invalid model directory: {model_directory}/{model_name}")
|
raise Exception(f"Invalid model directory: {model_directory}/{model_name}")
|
||||||
|
|
||||||
|
|
||||||
relative_path = '/'.join([model_directory, model_name])
|
relative_path = '/'.join([model_directory, model_name])
|
||||||
return file_path, relative_path
|
return file_path, relative_path
|
||||||
|
|
||||||
async def check_file_exists(file_path: str,
|
|
||||||
model_name: str,
|
async def check_file_exists(file_path: str,
|
||||||
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
model_name: str,
|
||||||
|
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
||||||
relative_path: str) -> Optional[DownloadModelStatus]:
|
relative_path: str) -> Optional[DownloadModelStatus]:
|
||||||
if os.path.exists(file_path):
|
if os.path.exists(file_path):
|
||||||
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True)
|
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True)
|
||||||
@ -129,11 +134,11 @@ async def check_file_exists(file_path: str,
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def track_download_progress(response: aiohttp.ClientResponse,
|
async def track_download_progress(response: aiohttp.ClientResponse,
|
||||||
file_path: str,
|
file_path: str,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
|
||||||
relative_path: str,
|
relative_path: str,
|
||||||
interval: float = 1.0) -> DownloadModelStatus:
|
interval: float = 1.0) -> DownloadModelStatus:
|
||||||
try:
|
try:
|
||||||
total_size = int(response.headers.get('Content-Length', 0))
|
total_size = int(response.headers.get('Content-Length', 0))
|
||||||
@ -156,12 +161,12 @@ async def track_download_progress(response: aiohttp.ClientResponse,
|
|||||||
break
|
break
|
||||||
f.write(chunk)
|
f.write(chunk)
|
||||||
downloaded += len(chunk)
|
downloaded += len(chunk)
|
||||||
|
|
||||||
if time.time() - last_update_time >= interval:
|
if time.time() - last_update_time >= interval:
|
||||||
await update_progress()
|
await update_progress()
|
||||||
|
|
||||||
await update_progress()
|
await update_progress()
|
||||||
|
|
||||||
logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}")
|
logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}")
|
||||||
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False)
|
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False)
|
||||||
await progress_callback(relative_path, status)
|
await progress_callback(relative_path, status)
|
||||||
@ -172,15 +177,17 @@ async def track_download_progress(response: aiohttp.ClientResponse,
|
|||||||
logging.error(traceback.format_exc())
|
logging.error(traceback.format_exc())
|
||||||
return await handle_download_error(e, model_name, progress_callback, relative_path)
|
return await handle_download_error(e, model_name, progress_callback, relative_path)
|
||||||
|
|
||||||
async def handle_download_error(e: Exception,
|
|
||||||
model_name: str,
|
async def handle_download_error(e: Exception,
|
||||||
progress_callback: Callable[[str, DownloadModelStatus], Any],
|
model_name: str,
|
||||||
|
progress_callback: Callable[[str, DownloadModelStatus], Any],
|
||||||
relative_path: str) -> DownloadModelStatus:
|
relative_path: str) -> DownloadModelStatus:
|
||||||
error_message = f"Error downloading {model_name}: {str(e)}"
|
error_message = f"Error downloading {model_name}: {str(e)}"
|
||||||
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
|
||||||
await progress_callback(relative_path, status)
|
await progress_callback(relative_path, status)
|
||||||
return status
|
return status
|
||||||
|
|
||||||
|
|
||||||
def validate_model_subdirectory(model_subdirectory: str) -> bool:
|
def validate_model_subdirectory(model_subdirectory: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Validate that the model subdirectory is safe to install into.
|
Validate that the model subdirectory is safe to install into.
|
||||||
@ -204,7 +211,8 @@ def validate_model_subdirectory(model_subdirectory: str) -> bool:
|
|||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def validate_filename(filename: str)-> bool:
|
|
||||||
|
def validate_filename(filename: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Validate a filename to ensure it's safe and doesn't contain any path traversal attempts.
|
Validate a filename to ensure it's safe and doesn't contain any path traversal attempts.
|
||||||
|
|
||||||
|
|||||||
@ -24,6 +24,7 @@ from .. import model_management
|
|||||||
from ..cli_args import args
|
from ..cli_args import args
|
||||||
|
|
||||||
from ..cmd import folder_paths, latent_preview
|
from ..cmd import folder_paths, latent_preview
|
||||||
|
from ..component_model.tensor_types import RGBImage
|
||||||
from ..execution_context import current_execution_context
|
from ..execution_context import current_execution_context
|
||||||
from ..images import open_image
|
from ..images import open_image
|
||||||
from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS, KNOWN_CONTROLNETS, KNOWN_DIFF_CONTROLNETS, KNOWN_VAES, KNOWN_APPROX_VAES, get_huggingface_repo_list, KNOWN_CLIP_MODELS, KNOWN_UNET_MODELS
|
from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS, KNOWN_CONTROLNETS, KNOWN_DIFF_CONTROLNETS, KNOWN_VAES, KNOWN_APPROX_VAES, get_huggingface_repo_list, KNOWN_CLIP_MODELS, KNOWN_UNET_MODELS
|
||||||
@ -718,8 +719,8 @@ class VAELoader:
|
|||||||
sd_["vae_scale"] = torch.tensor(1.5305)
|
sd_["vae_scale"] = torch.tensor(1.5305)
|
||||||
sd_["vae_shift"] = torch.tensor(0.0609)
|
sd_["vae_shift"] = torch.tensor(0.0609)
|
||||||
elif name == "taef1":
|
elif name == "taef1":
|
||||||
sd["vae_scale"] = torch.tensor(0.3611)
|
sd_["vae_scale"] = torch.tensor(0.3611)
|
||||||
sd["vae_shift"] = torch.tensor(0.1159)
|
sd_["vae_shift"] = torch.tensor(0.1159)
|
||||||
return sd_
|
return sd_
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -1492,8 +1493,9 @@ class SaveImage:
|
|||||||
filename_prefix += self.prefix_append
|
filename_prefix += self.prefix_append
|
||||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
|
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
|
||||||
results = list()
|
results = list()
|
||||||
|
image: RGBImage
|
||||||
for (batch_number, image) in enumerate(images):
|
for (batch_number, image) in enumerate(images):
|
||||||
i = 255. * image.cpu().numpy()
|
i = 255. * image.float().cpu().numpy()
|
||||||
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
|
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
|
||||||
metadata = None
|
metadata = None
|
||||||
if not args.disable_metadata:
|
if not args.disable_metadata:
|
||||||
|
|||||||
62
comfy/web/extensions/javascript/language.js
vendored
Normal file
62
comfy/web/extensions/javascript/language.js
vendored
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
import { app } from "../../scripts/app.js";
|
||||||
|
import { api } from "../../scripts/api.js";
|
||||||
|
import { ComfyWidgets } from "../../scripts/widgets.js";
|
||||||
|
|
||||||
|
const tokenPreviewWidgetName = "__tokens";
|
||||||
|
|
||||||
|
class TokenProgressHandler {
|
||||||
|
constructor() {
|
||||||
|
this.nodeOutputs = {};
|
||||||
|
this.initEventListeners();
|
||||||
|
}
|
||||||
|
|
||||||
|
initEventListeners() {
|
||||||
|
api.addEventListener("executing", ({ detail }) => {
|
||||||
|
if (!detail) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const nodeId = detail;
|
||||||
|
if (!this.nodeOutputs[nodeId]) {
|
||||||
|
this.nodeOutputs[nodeId] = {};
|
||||||
|
}
|
||||||
|
this.nodeOutputs[nodeId].tokens = null;
|
||||||
|
});
|
||||||
|
|
||||||
|
api.addEventListener("progress", ({ detail }) => {
|
||||||
|
const nodeId = detail.node;
|
||||||
|
if (!this.nodeOutputs[nodeId]) {
|
||||||
|
this.nodeOutputs[nodeId] = {};
|
||||||
|
}
|
||||||
|
if (detail.output && detail.output.next_token) {
|
||||||
|
if (!this.nodeOutputs[nodeId].tokens) {
|
||||||
|
this.nodeOutputs[nodeId].tokens = "";
|
||||||
|
}
|
||||||
|
this.nodeOutputs[nodeId].tokens += detail.output.next_token;
|
||||||
|
this.updateTokenWidget(nodeId, this.nodeOutputs[nodeId].tokens);
|
||||||
|
}
|
||||||
|
app.graph.setDirtyCanvas(true, false);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
updateTokenWidget(nodeId, tokens) {
|
||||||
|
const node = app.graph.getNodeById(nodeId);
|
||||||
|
if (node && node.widgets) {
|
||||||
|
let widget = node.widgets.find((w) => w.name === tokenPreviewWidgetName);
|
||||||
|
|
||||||
|
if (!widget) {
|
||||||
|
widget = ComfyWidgets["STRING"](node, tokenPreviewWidgetName, ["STRING", { multiline: true }], app).widget;
|
||||||
|
widget.inputEl.readOnly = true;
|
||||||
|
widget.inputEl.style.opacity = 0.7;
|
||||||
|
}
|
||||||
|
widget.value = tokens;
|
||||||
|
app.graph.setDirtyCanvas(true, false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
app.registerExtension({
|
||||||
|
name: "Comfy.TokenProgress",
|
||||||
|
setup() {
|
||||||
|
this.tokenProgressHandler = new TokenProgressHandler();
|
||||||
|
},
|
||||||
|
});
|
||||||
@ -70,7 +70,7 @@ class ImageApplyColorMap(CustomNode):
|
|||||||
cv2_colormap = getattr(cv2, colormap)
|
cv2_colormap = getattr(cv2, colormap)
|
||||||
colored_image = cv2.applyColorMap(normalized_depth_uint8, cv2_colormap) # pylint: disable=no-member
|
colored_image = cv2.applyColorMap(normalized_depth_uint8, cv2_colormap) # pylint: disable=no-member
|
||||||
colored_image_rgb = cv2.cvtColor(colored_image, cv2.COLOR_BGR2RGB) # pylint: disable=no-member
|
colored_image_rgb = cv2.cvtColor(colored_image, cv2.COLOR_BGR2RGB) # pylint: disable=no-member
|
||||||
rgb_tensor = torch.tensor(colored_image_rgb) * 1.0 / 255.0
|
rgb_tensor = torch.tensor(colored_image_rgb).float() * 1.0 / 255.0
|
||||||
colored_images.append(rgb_tensor)
|
colored_images.append(rgb_tensor)
|
||||||
|
|
||||||
return torch.stack(colored_images),
|
return torch.stack(colored_images),
|
||||||
|
|||||||
@ -101,7 +101,7 @@ class SaveAnimatedWEBP:
|
|||||||
results = list()
|
results = list()
|
||||||
pil_images = []
|
pil_images = []
|
||||||
for image in images:
|
for image in images:
|
||||||
i = 255. * image.cpu().numpy()
|
i = 255. * image.float().cpu().numpy()
|
||||||
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
|
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
|
||||||
pil_images.append(img)
|
pil_images.append(img)
|
||||||
|
|
||||||
@ -162,7 +162,7 @@ class SaveAnimatedPNG:
|
|||||||
results = list()
|
results = list()
|
||||||
pil_images = []
|
pil_images = []
|
||||||
for image in images:
|
for image in images:
|
||||||
i = 255. * image.cpu().numpy()
|
i = 255. * image.float().cpu().numpy()
|
||||||
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
|
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
|
||||||
pil_images.append(img)
|
pil_images.append(img)
|
||||||
|
|
||||||
|
|||||||
@ -197,7 +197,7 @@ class HashImage(CustomNode):
|
|||||||
|
|
||||||
def execute(self, images: Sequence[Tensor]) -> ValidatedNodeResult:
|
def execute(self, images: Sequence[Tensor]) -> ValidatedNodeResult:
|
||||||
def process_image(image: Tensor) -> str:
|
def process_image(image: Tensor) -> str:
|
||||||
image_as_numpy_array: np.ndarray = 255. * image.cpu().numpy()
|
image_as_numpy_array: np.ndarray = 255. * image.float().cpu().numpy()
|
||||||
image_as_numpy_array = np.ascontiguousarray(np.clip(image_as_numpy_array, 0, 255).astype(np.uint8))
|
image_as_numpy_array = np.ascontiguousarray(np.clip(image_as_numpy_array, 0, 255).astype(np.uint8))
|
||||||
data = image_as_numpy_array.data
|
data = image_as_numpy_array.data
|
||||||
try:
|
try:
|
||||||
|
|||||||
@ -31,7 +31,7 @@ class SkeletonizeThin:
|
|||||||
|
|
||||||
results = []
|
results = []
|
||||||
for img in binary:
|
for img in binary:
|
||||||
img_np = img.squeeze().cpu().numpy()
|
img_np = img.squeeze().float().cpu().numpy()
|
||||||
|
|
||||||
if use_skeletonize:
|
if use_skeletonize:
|
||||||
result = skeletonize(img_np)
|
result = skeletonize(img_np)
|
||||||
|
|||||||
@ -62,3 +62,4 @@ networkx>=2.6.3
|
|||||||
joblib
|
joblib
|
||||||
jaxtyping
|
jaxtyping
|
||||||
spandrel_extra_arches
|
spandrel_extra_arches
|
||||||
|
ml_dtypes
|
||||||
@ -1,16 +1,21 @@
|
|||||||
import pytest
|
|
||||||
import aiohttp
|
|
||||||
from aiohttp import ClientResponse
|
|
||||||
import itertools
|
import itertools
|
||||||
import os
|
import os
|
||||||
from unittest.mock import AsyncMock, patch, MagicMock
|
from unittest.mock import AsyncMock, patch, MagicMock
|
||||||
from model_filemanager import download_model, validate_model_subdirectory, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename
|
|
||||||
|
import aiohttp
|
||||||
|
import pytest
|
||||||
|
from aiohttp import ClientResponse
|
||||||
|
|
||||||
|
from comfy.model_filemanager import download_model, validate_model_subdirectory, track_download_progress, \
|
||||||
|
create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename
|
||||||
|
|
||||||
|
|
||||||
class AsyncIteratorMock:
|
class AsyncIteratorMock:
|
||||||
"""
|
"""
|
||||||
A mock class that simulates an asynchronous iterator.
|
A mock class that simulates an asynchronous iterator.
|
||||||
This is used to mimic the behavior of aiohttp's content iterator.
|
This is used to mimic the behavior of aiohttp's content iterator.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, seq):
|
def __init__(self, seq):
|
||||||
# Convert the input sequence into an iterator
|
# Convert the input sequence into an iterator
|
||||||
self.iter = iter(seq)
|
self.iter = iter(seq)
|
||||||
@ -27,11 +32,13 @@ class AsyncIteratorMock:
|
|||||||
# This is the asynchronous equivalent of StopIteration
|
# This is the asynchronous equivalent of StopIteration
|
||||||
raise StopAsyncIteration
|
raise StopAsyncIteration
|
||||||
|
|
||||||
|
|
||||||
class ContentMock:
|
class ContentMock:
|
||||||
"""
|
"""
|
||||||
A mock class that simulates the content attribute of an aiohttp ClientResponse.
|
A mock class that simulates the content attribute of an aiohttp ClientResponse.
|
||||||
This class provides the iter_chunked method which returns an async iterator of chunks.
|
This class provides the iter_chunked method which returns an async iterator of chunks.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, chunks):
|
def __init__(self, chunks):
|
||||||
# Store the chunks that will be returned by the iterator
|
# Store the chunks that will be returned by the iterator
|
||||||
self.chunks = chunks
|
self.chunks = chunks
|
||||||
@ -41,6 +48,7 @@ class ContentMock:
|
|||||||
# For simplicity in testing, we ignore chunk_size and just return our predefined chunks
|
# For simplicity in testing, we ignore chunk_size and just return our predefined chunks
|
||||||
return AsyncIteratorMock(self.chunks)
|
return AsyncIteratorMock(self.chunks)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_download_model_success():
|
async def test_download_model_success():
|
||||||
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
||||||
@ -59,10 +67,10 @@ async def test_download_model_success():
|
|||||||
mock_open.return_value.__enter__.return_value = mock_file
|
mock_open.return_value.__enter__.return_value = mock_file
|
||||||
time_values = itertools.count(0, 0.1)
|
time_values = itertools.count(0, 0.1)
|
||||||
|
|
||||||
with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'checkpoints/model.sft')), \
|
with patch('comfy.model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'checkpoints/model.sft')), \
|
||||||
patch('model_filemanager.check_file_exists', return_value=None), \
|
patch('comfy.model_filemanager.check_file_exists', return_value=None), \
|
||||||
patch('builtins.open', mock_open), \
|
patch('builtins.open', mock_open), \
|
||||||
patch('time.time', side_effect=time_values): # Simulate time passing
|
patch('time.time', side_effect=time_values): # Simulate time passing
|
||||||
|
|
||||||
result = await download_model(
|
result = await download_model(
|
||||||
mock_make_request,
|
mock_make_request,
|
||||||
@ -80,7 +88,7 @@ async def test_download_model_success():
|
|||||||
|
|
||||||
# Check progress callback calls
|
# Check progress callback calls
|
||||||
assert mock_progress_callback.call_count >= 3 # At least start, one progress update, and completion
|
assert mock_progress_callback.call_count >= 3 # At least start, one progress update, and completion
|
||||||
|
|
||||||
# Check initial call
|
# Check initial call
|
||||||
mock_progress_callback.assert_any_call(
|
mock_progress_callback.assert_any_call(
|
||||||
'checkpoints/model.sft',
|
'checkpoints/model.sft',
|
||||||
@ -101,6 +109,7 @@ async def test_download_model_success():
|
|||||||
# Verify request was made
|
# Verify request was made
|
||||||
mock_make_request.assert_called_once_with('http://example.com/model.sft')
|
mock_make_request.assert_called_once_with('http://example.com/model.sft')
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_download_model_url_request_failure():
|
async def test_download_model_url_request_failure():
|
||||||
# Mock dependencies
|
# Mock dependencies
|
||||||
@ -110,9 +119,9 @@ async def test_download_model_url_request_failure():
|
|||||||
mock_progress_callback = AsyncMock()
|
mock_progress_callback = AsyncMock()
|
||||||
|
|
||||||
# Mock the create_model_path function
|
# Mock the create_model_path function
|
||||||
with patch('model_filemanager.create_model_path', return_value=('/mock/path/model.safetensors', 'mock/path/model.safetensors')):
|
with patch('comfy.model_filemanager.create_model_path', return_value=('/mock/path/model.safetensors', 'mock/path/model.safetensors')):
|
||||||
# Mock the check_file_exists function to return None (file doesn't exist)
|
# Mock the check_file_exists function to return None (file doesn't exist)
|
||||||
with patch('model_filemanager.check_file_exists', return_value=None):
|
with patch('comfy.model_filemanager.check_file_exists', return_value=None):
|
||||||
# Call the function
|
# Call the function
|
||||||
result = await download_model(
|
result = await download_model(
|
||||||
mock_get,
|
mock_get,
|
||||||
@ -151,13 +160,12 @@ async def test_download_model_url_request_failure():
|
|||||||
# Verify that the get method was called with the correct URL
|
# Verify that the get method was called with the correct URL
|
||||||
mock_get.assert_called_once_with('http://example.com/model.safetensors')
|
mock_get.assert_called_once_with('http://example.com/model.safetensors')
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_download_model_invalid_model_subdirectory():
|
async def test_download_model_invalid_model_subdirectory():
|
||||||
|
|
||||||
mock_make_request = AsyncMock()
|
mock_make_request = AsyncMock()
|
||||||
mock_progress_callback = AsyncMock()
|
mock_progress_callback = AsyncMock()
|
||||||
|
|
||||||
|
|
||||||
result = await download_model(
|
result = await download_model(
|
||||||
mock_make_request,
|
mock_make_request,
|
||||||
'model.sft',
|
'model.sft',
|
||||||
@ -177,12 +185,12 @@ async def test_download_model_invalid_model_subdirectory():
|
|||||||
def test_create_model_path(tmp_path, monkeypatch):
|
def test_create_model_path(tmp_path, monkeypatch):
|
||||||
mock_models_dir = tmp_path / "models"
|
mock_models_dir = tmp_path / "models"
|
||||||
monkeypatch.setattr('folder_paths.models_dir', str(mock_models_dir))
|
monkeypatch.setattr('folder_paths.models_dir', str(mock_models_dir))
|
||||||
|
|
||||||
model_name = "test_model.sft"
|
model_name = "test_model.sft"
|
||||||
model_directory = "test_dir"
|
model_directory = "test_dir"
|
||||||
|
|
||||||
file_path, relative_path = create_model_path(model_name, model_directory, mock_models_dir)
|
file_path, relative_path = create_model_path(model_name, model_directory, mock_models_dir)
|
||||||
|
|
||||||
assert file_path == str(mock_models_dir / model_directory / model_name)
|
assert file_path == str(mock_models_dir / model_directory / model_name)
|
||||||
assert relative_path == f"{model_directory}/{model_name}"
|
assert relative_path == f"{model_directory}/{model_name}"
|
||||||
assert os.path.exists(os.path.dirname(file_path))
|
assert os.path.exists(os.path.dirname(file_path))
|
||||||
@ -192,32 +200,34 @@ def test_create_model_path(tmp_path, monkeypatch):
|
|||||||
async def test_check_file_exists_when_file_exists(tmp_path):
|
async def test_check_file_exists_when_file_exists(tmp_path):
|
||||||
file_path = tmp_path / "existing_model.sft"
|
file_path = tmp_path / "existing_model.sft"
|
||||||
file_path.touch() # Create an empty file
|
file_path.touch() # Create an empty file
|
||||||
|
|
||||||
mock_callback = AsyncMock()
|
mock_callback = AsyncMock()
|
||||||
|
|
||||||
result = await check_file_exists(str(file_path), "existing_model.sft", mock_callback, "test/existing_model.sft")
|
result = await check_file_exists(str(file_path), "existing_model.sft", mock_callback, "test/existing_model.sft")
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result.status == "completed"
|
assert result.status == "completed"
|
||||||
assert result.message == "existing_model.sft already exists"
|
assert result.message == "existing_model.sft already exists"
|
||||||
assert result.already_existed is True
|
assert result.already_existed is True
|
||||||
|
|
||||||
mock_callback.assert_called_once_with(
|
mock_callback.assert_called_once_with(
|
||||||
"test/existing_model.sft",
|
"test/existing_model.sft",
|
||||||
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "existing_model.sft already exists", already_existed=True)
|
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "existing_model.sft already exists", already_existed=True)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_check_file_exists_when_file_does_not_exist(tmp_path):
|
async def test_check_file_exists_when_file_does_not_exist(tmp_path):
|
||||||
file_path = tmp_path / "non_existing_model.sft"
|
file_path = tmp_path / "non_existing_model.sft"
|
||||||
|
|
||||||
mock_callback = AsyncMock()
|
mock_callback = AsyncMock()
|
||||||
|
|
||||||
result = await check_file_exists(str(file_path), "non_existing_model.sft", mock_callback, "test/non_existing_model.sft")
|
result = await check_file_exists(str(file_path), "non_existing_model.sft", mock_callback, "test/non_existing_model.sft")
|
||||||
|
|
||||||
assert result is None
|
assert result is None
|
||||||
mock_callback.assert_not_called()
|
mock_callback.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_track_download_progress_no_content_length():
|
async def test_track_download_progress_no_content_length():
|
||||||
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
||||||
@ -240,6 +250,7 @@ async def test_track_download_progress_no_content_length():
|
|||||||
DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.sft", already_existed=False)
|
DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.sft", already_existed=False)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_track_download_progress_interval():
|
async def test_track_download_progress_interval():
|
||||||
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
|
||||||
@ -254,7 +265,7 @@ async def test_track_download_progress_interval():
|
|||||||
mock_time.side_effect = [i * 0.5 for i in range(30)] # This should be enough for 10 chunks
|
mock_time.side_effect = [i * 0.5 for i in range(30)] # This should be enough for 10 chunks
|
||||||
|
|
||||||
with patch('builtins.open', mock_open), \
|
with patch('builtins.open', mock_open), \
|
||||||
patch('time.time', mock_time):
|
patch('time.time', mock_time):
|
||||||
await track_download_progress(
|
await track_download_progress(
|
||||||
mock_response, '/mock/path/model.sft', 'model.sft',
|
mock_response, '/mock/path/model.sft', 'model.sft',
|
||||||
mock_callback, 'models/model.sft', interval=1.0
|
mock_callback, 'models/model.sft', interval=1.0
|
||||||
@ -279,31 +290,39 @@ async def test_track_download_progress_interval():
|
|||||||
assert last_call[0][1].status == "completed"
|
assert last_call[0][1].status == "completed"
|
||||||
assert last_call[0][1].progress_percentage == 100
|
assert last_call[0][1].progress_percentage == 100
|
||||||
|
|
||||||
|
|
||||||
def test_valid_subdirectory():
|
def test_valid_subdirectory():
|
||||||
assert validate_model_subdirectory("valid-model123") is True
|
assert validate_model_subdirectory("valid-model123") is True
|
||||||
|
|
||||||
|
|
||||||
def test_subdirectory_too_long():
|
def test_subdirectory_too_long():
|
||||||
assert validate_model_subdirectory("a" * 51) is False
|
assert validate_model_subdirectory("a" * 51) is False
|
||||||
|
|
||||||
|
|
||||||
def test_subdirectory_with_double_dots():
|
def test_subdirectory_with_double_dots():
|
||||||
assert validate_model_subdirectory("model/../unsafe") is False
|
assert validate_model_subdirectory("model/../unsafe") is False
|
||||||
|
|
||||||
|
|
||||||
def test_subdirectory_with_slash():
|
def test_subdirectory_with_slash():
|
||||||
assert validate_model_subdirectory("model/unsafe") is False
|
assert validate_model_subdirectory("model/unsafe") is False
|
||||||
|
|
||||||
|
|
||||||
def test_subdirectory_with_special_characters():
|
def test_subdirectory_with_special_characters():
|
||||||
assert validate_model_subdirectory("model@unsafe") is False
|
assert validate_model_subdirectory("model@unsafe") is False
|
||||||
|
|
||||||
|
|
||||||
def test_subdirectory_with_underscore_and_dash():
|
def test_subdirectory_with_underscore_and_dash():
|
||||||
assert validate_model_subdirectory("valid_model-name") is True
|
assert validate_model_subdirectory("valid_model-name") is True
|
||||||
|
|
||||||
|
|
||||||
def test_empty_subdirectory():
|
def test_empty_subdirectory():
|
||||||
assert validate_model_subdirectory("") is False
|
assert validate_model_subdirectory("") is False
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("filename, expected", [
|
@pytest.mark.parametrize("filename, expected", [
|
||||||
("valid_model.safetensors", True),
|
("valid_model.safetensors", True),
|
||||||
("valid_model.sft", True),
|
("valid_model.sft", True),
|
||||||
("valid model.safetensors", True), # Test with space
|
("valid model.safetensors", True), # Test with space
|
||||||
("UPPERCASE_MODEL.SAFETENSORS", True),
|
("UPPERCASE_MODEL.SAFETENSORS", True),
|
||||||
("model_with.multiple.dots.pt", False),
|
("model_with.multiple.dots.pt", False),
|
||||||
("", False), # Empty string
|
("", False), # Empty string
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user