From c086c5e005c2a0c59979a59124c18535650636f1 Mon Sep 17 00:00:00 2001
From: doctorpangloss <@hiddenswitch.com>
Date: Mon, 14 Jul 2025 17:44:43 -0700
Subject: [PATCH] fix pylint issues

---
 comfy/client/embedded_comfy_client.py         |  2 +-
 comfy/cmd/execution.py                        |  4 +--
 comfy/cmd/server.py                           |  5 ++--
 comfy/distributed/distributed_prompt_queue.py |  2 +-
 comfy/distributed/distributed_types.py        |  4 +--
 comfy/k_diffusion/sampling.py                 |  4 +--
 comfy/ldm/ace/vae/music_dcae_pipeline.py      |  9 ++++--
 comfy/ldm/ace/vae/music_log_mel.py            |  2 +-
 comfy/supported_models.py                     |  2 +-
 comfy_api/input_impl/video_types.py           |  2 +-
 comfy_api_nodes/apinode_utils.py              |  2 +-
 comfy_api_nodes/apis/client.py                |  2 +-
 comfy_api_nodes/apis/request_logger.py        |  2 +-
 comfy_api_nodes/nodes_bfl.py                  |  2 +-
 comfy_api_nodes/nodes_gemini.py               |  6 ++--
 comfy_api_nodes/nodes_ideogram.py             |  2 +-
 comfy_api_nodes/nodes_luma.py                 |  2 +-
 comfy_api_nodes/nodes_minimax.py              |  2 +-
 comfy_api_nodes/nodes_openai.py               |  4 +--
 comfy_api_nodes/nodes_recraft.py              |  4 +--
 comfy_api_nodes/nodes_rodin.py                |  2 +-
 comfy_api_nodes/nodes_tripo.py                |  2 +-
 comfy_config/config_parser.py                 |  6 ++--
 comfy_execution/graph.py                      |  2 +-
 tests/distributed/test_distributed_queue.py   | 13 ++++----
 tests/unit/test_validation.py                 | 30 ++++++++++---------
 26 files changed, 62 insertions(+), 57 deletions(-)

diff --git a/comfy/client/embedded_comfy_client.py b/comfy/client/embedded_comfy_client.py
index e6e2fa63a..d6aaaf33f 100644
--- a/comfy/client/embedded_comfy_client.py
+++ b/comfy/client/embedded_comfy_client.py
@@ -82,7 +82,7 @@ async def __execute_prompt(
         try:
             prompt_mut = make_mutable(prompt)
             from ..cmd.execution import validate_prompt
-            validation_tuple = validate_prompt(prompt_mut)
+            validation_tuple = await validate_prompt(prompt_id, prompt_mut)
             if not validation_tuple.valid:
                 if validation_tuple.node_errors is not None and len(validation_tuple.node_errors) > 0:
                     validation_error_dict = validation_tuple.node_errors
diff --git a/comfy/cmd/execution.py b/comfy/cmd/execution.py
index cf6d1c0c2..21a8696f5 100644
--- a/comfy/cmd/execution.py
+++ b/comfy/cmd/execution.py
@@ -388,7 +388,7 @@ def format_value(x) -> FormattedValue:
     return str(x.__class__)


-def execute(server: ExecutorToClientProgress, dynprompt: DynamicPrompt, caches, _node_id: str, extra_data: dict, executed, prompt_id, execution_list, pending_subgraph_results) -> RecursiveExecutionTuple:
+async def execute(server: ExecutorToClientProgress, dynprompt: DynamicPrompt, caches, _node_id: str, extra_data: dict, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes) -> RecursiveExecutionTuple:
     """

     :param server:
@@ -403,7 +403,7 @@ def execute(server: ExecutorToClientProgress, dynprompt: DynamicPrompt, caches,
     :return:
     """
     with context_execute_node(_node_id):
-        return _execute(server, dynprompt, caches, _node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results)
+        return await _execute(server, dynprompt, caches, _node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes)


 async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes) -> RecursiveExecutionTuple:
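Because execute() is now a coroutine, its body must await _execute() before the context_execute_node() scope exits; returning the bare call would hand callers an un-awaited coroutine object instead of a RecursiveExecutionTuple, which is why the hunk above reads "return await _execute(...)". A minimal, self-contained sketch of the pitfall, using only the standard library:

    import asyncio

    async def _inner():
        return 42

    async def wrong():
        return _inner()        # hands back a coroutine object, not 42

    async def right():
        return await _inner()  # hands back 42

    async def main():
        leaked = await wrong()
        assert asyncio.iscoroutine(leaked)
        leaked.close()         # silence the "never awaited" warning
        assert await right() == 42

    asyncio.run(main())
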
diff --git a/comfy/cmd/server.py b/comfy/cmd/server.py
index 0758ceccd..7f094163a 100644
--- a/comfy/cmd/server.py
+++ b/comfy/cmd/server.py
@@ -859,8 +859,8 @@ class PromptServer(ExecutorToClientProgress):
                 return web.Response(status=400, reason="no prompt was specified")

             content_digest = digest(prompt_dict)
-
-            valid = execution.validate_prompt(prompt_dict)
+            task_id = str(uuid.uuid4())
+            valid = await execution.validate_prompt(task_id, prompt_dict)

             if not valid[0]:
                 return web.Response(status=400, content_type="application/json", body=json.dumps(valid[1]))
@@ -872,7 +872,6 @@ class PromptServer(ExecutorToClientProgress):
             completed: Future[TaskInvocation | dict] = self.loop.create_future()
             # todo: actually implement idempotency keys
             # we would need some kind of more durable, distributed task queue
-            task_id = str(uuid.uuid4())
             item = QueueItem(queue_tuple=(number, task_id, prompt_dict, {}, valid[2]), completed=completed)

             try:
diff --git a/comfy/distributed/distributed_prompt_queue.py b/comfy/distributed/distributed_prompt_queue.py
index d55fe7526..6947bb534 100644
--- a/comfy/distributed/distributed_prompt_queue.py
+++ b/comfy/distributed/distributed_prompt_queue.py
@@ -110,7 +110,7 @@ class DistributedPromptQueue(AbstractPromptQueue, AsyncAbstractPromptQueue):
     async def _callee_do_work_item(self, request: dict) -> dict:
         assert self._is_callee
         request_obj = RpcRequest.from_dict(request)
-        item = request_obj.as_queue_tuple().queue_tuple
+        item = (await request_obj.as_queue_tuple()).queue_tuple
         item_with_completer = QueueItem(item, self._loop.create_future())
         self._callee_local_in_progress[item_with_completer.prompt_id] = item_with_completer
         # todo: check if we have the local model content needed to execute this request and if not, reject it
diff --git a/comfy/distributed/distributed_types.py b/comfy/distributed/distributed_types.py
index 35c420d66..8c637827b 100644
--- a/comfy/distributed/distributed_types.py
+++ b/comfy/distributed/distributed_types.py
@@ -26,13 +26,13 @@ class DistributedBase:
 class RpcRequest(DistributedBase):
     prompt: dict | PromptDict

-    def as_queue_tuple(self) -> NamedQueueTuple:
+    async def as_queue_tuple(self) -> NamedQueueTuple:
         # this loads the nodes in this instance
         # should always be okay to call in an executor
         from ..cmd.execution import validate_prompt
         from ..component_model.make_mutable import make_mutable
         mutated_prompt_dict = make_mutable(self.prompt)
-        validation_tuple = validate_prompt(mutated_prompt_dict)
+        validation_tuple = await validate_prompt(self.prompt_id, mutated_prompt_dict)
         return NamedQueueTuple(queue_tuple=(0, self.prompt_id, mutated_prompt_dict, {}, validation_tuple[2]))

     @classmethod
diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py
index fab132e07..039f08f2d 100644
--- a/comfy/k_diffusion/sampling.py
+++ b/comfy/k_diffusion/sampling.py
@@ -1724,8 +1724,8 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
     x_pred = x  # x: current state, x_pred: predicted next state
     h = 0.0
-    tau_t = 0.0
-    noise = 0.0
+    tau_t = torch.tensor(0.0)
+    noise = torch.tensor(0.0)
     pred_list = []

     # Lower order near the end to improve stability
diff --git a/comfy/ldm/ace/vae/music_dcae_pipeline.py b/comfy/ldm/ace/vae/music_dcae_pipeline.py
index af81280eb..6357cd84e 100644
--- a/comfy/ldm/ace/vae/music_dcae_pipeline.py
+++ b/comfy/ldm/ace/vae/music_dcae_pipeline.py
@@ -1,9 +1,12 @@
 # Original from: https://github.com/ace-step/ACE-Step/blob/main/music_dcae/music_dcae_pipeline.py
-import torch
-from .autoencoder_dc import AutoencoderDC
 import logging
+
+import torch
+
+from .autoencoder_dc import AutoencoderDC
+
 try:
-    import torchaudio
+    import torchaudio  # pylint: disable=import-error
 except:
     logging.warning("torchaudio missing, ACE model will be broken")
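Note on the sampler hunk: torch.float is a dtype object, not a constructor, so calling torch.float(0.0) raises TypeError at runtime; the initializers above therefore use torch.tensor(0.0), a zero-dimensional tensor placeholder that the loop later overwrites.

The guarded import above keeps the VAE module importable when the optional torchaudio dependency is missing, and the inline disable confines the import-error suppression to that single line. A narrower variant of the same pattern (an editorial sketch, not part of the patch) also avoids the bare except:, which pylint separately flags as too broad:

    import logging

    try:
        import torchaudio  # pylint: disable=import-error
    except ImportError:
        torchaudio = None
        logging.warning("torchaudio missing, ACE model will be broken")
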
diff --git a/comfy/ldm/ace/vae/music_log_mel.py b/comfy/ldm/ace/vae/music_log_mel.py
index 50d6bff88..caa3c64d0 100755
--- a/comfy/ldm/ace/vae/music_log_mel.py
+++ b/comfy/ldm/ace/vae/music_log_mel.py
@@ -4,7 +4,7 @@ import torch.nn as nn
 from torch import Tensor
 import logging
 try:
-    from torchaudio.transforms import MelScale
+    from torchaudio.transforms import MelScale  # pylint: disable=import-error
 except:
     logging.warning("torchaudio missing, ACE model will be broken")

diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index 1f26df43d..2544ab97c 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -1310,7 +1310,7 @@ class Omnigen2(supported_models_base.BASE):
     def clip_target(self, state_dict={}):
         pref = self.text_encoder_key_prefix[0]
         hunyuan_detect = hunyuan_video.llama_detect(state_dict, "{}qwen25_3b.transformer.".format(pref))
-        return supported_models_base.ClipTarget(omnigen2.LuminaTokenizer, omnigen2.te(**hunyuan_detect))
+        return supported_models_base.ClipTarget(omnigen2.Omnigen2Tokenizer, omnigen2.te(**hunyuan_detect))


 models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2]
diff --git a/comfy_api/input_impl/video_types.py b/comfy_api/input_impl/video_types.py
index 9ae818f4e..c24f3e9a7 100644
--- a/comfy_api/input_impl/video_types.py
+++ b/comfy_api/input_impl/video_types.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 from av.container import InputContainer
-from av.subtitles.stream import SubtitleStream
+from av.subtitles.stream import SubtitleStream  # pylint: disable=no-name-in-module
 from fractions import Fraction
 from typing import Optional
 from comfy_api.input import AudioInput
diff --git a/comfy_api_nodes/apinode_utils.py b/comfy_api_nodes/apinode_utils.py
index 788e2803f..84c9360e2 100644
--- a/comfy_api_nodes/apinode_utils.py
+++ b/comfy_api_nodes/apinode_utils.py
@@ -16,7 +16,7 @@ from comfy_api_nodes.apis.client import (
     UploadRequest,
     UploadResponse,
 )
-from server import PromptServer
+from comfy.cmd.server import PromptServer


 import numpy as np
diff --git a/comfy_api_nodes/apis/client.py b/comfy_api_nodes/apis/client.py
index 2a4bac88b..2049de46b 100644
--- a/comfy_api_nodes/apis/client.py
+++ b/comfy_api_nodes/apis/client.py
@@ -103,7 +103,7 @@ from urllib.parse import urljoin, urlparse
 from pydantic import BaseModel, Field
 import uuid  # For generating unique operation IDs

-from server import PromptServer
+from comfy.cmd.server import PromptServer
 from comfy.cli_args import args
 from comfy import utils
 from . import request_logger
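The import rewrites in this and the following files all follow one pattern: bare top-level modules (server, folder_paths) become their packaged counterparts (comfy.cmd.server, comfy.cmd.folder_paths), which pylint can resolve without the repository root on sys.path. Code that must tolerate both layouts can fall back at import time; a hedged sketch (the legacy branch is an assumption, not something this patch adds):

    try:
        from comfy.cmd.server import PromptServer  # packaged layout used by this patch
    except ImportError:
        from server import PromptServer  # legacy flat layout, if present
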
diff --git a/comfy_api_nodes/apis/request_logger.py b/comfy_api_nodes/apis/request_logger.py
index 93517ede9..8ec10a23a 100644
--- a/comfy_api_nodes/apis/request_logger.py
+++ b/comfy_api_nodes/apis/request_logger.py
@@ -2,7 +2,7 @@ import os
 import datetime
 import json
 import logging
-import folder_paths
+from comfy.cmd import folder_paths

 # Get the logger instance
 logger = logging.getLogger(__name__)
diff --git a/comfy_api_nodes/nodes_bfl.py b/comfy_api_nodes/nodes_bfl.py
index d93fbd778..62b5fb9e4 100644
--- a/comfy_api_nodes/nodes_bfl.py
+++ b/comfy_api_nodes/nodes_bfl.py
@@ -32,7 +32,7 @@ import requests
 import torch
 import base64
 import time
-from server import PromptServer
+from comfy.cmd.server import PromptServer


 def convert_mask_to_image(mask: torch.Tensor):
diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py
index ae7b04846..98c22007c 100644
--- a/comfy_api_nodes/nodes_gemini.py
+++ b/comfy_api_nodes/nodes_gemini.py
@@ -9,9 +9,9 @@ from typing import Optional, Literal

 import torch

-import folder_paths
+from comfy.cmd import folder_paths
 from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
-from server import PromptServer
+from comfy.cmd.server import PromptServer
 from comfy_api_nodes.apis import (
     GeminiContent,
     GeminiGenerateContentRequest,
@@ -406,7 +406,7 @@ class GeminiInputFiles(ComfyNodeABC):

     def create_file_part(self, file_path: str) -> GeminiPart:
         mime_type = (
-            GeminiMimeType.pdf
+            GeminiMimeType.application_pdf
             if file_path.endswith(".pdf")
             else GeminiMimeType.text_plain
         )
diff --git a/comfy_api_nodes/nodes_ideogram.py b/comfy_api_nodes/nodes_ideogram.py
index b8487355f..dcaa43986 100644
--- a/comfy_api_nodes/nodes_ideogram.py
+++ b/comfy_api_nodes/nodes_ideogram.py
@@ -23,7 +23,7 @@ from comfy_api_nodes.apinode_utils import (
     bytesio_to_image_tensor,
     resize_mask_to_image,
 )
-from server import PromptServer
+from comfy.cmd.server import PromptServer

 V1_V1_RES_MAP = {
     "Auto":"AUTO",
diff --git a/comfy_api_nodes/nodes_luma.py b/comfy_api_nodes/nodes_luma.py
index 525dc38e6..2f8f84040 100644
--- a/comfy_api_nodes/nodes_luma.py
+++ b/comfy_api_nodes/nodes_luma.py
@@ -36,7 +36,7 @@ from comfy_api_nodes.apinode_utils import (
     process_image_response,
     validate_string,
 )
-from server import PromptServer
+from comfy.cmd.server import PromptServer

 import requests
 import torch
diff --git a/comfy_api_nodes/nodes_minimax.py b/comfy_api_nodes/nodes_minimax.py
index 9b46636db..870c0e0e2 100644
--- a/comfy_api_nodes/nodes_minimax.py
+++ b/comfy_api_nodes/nodes_minimax.py
@@ -24,7 +24,7 @@ from comfy_api_nodes.apinode_utils import (
     upload_images_to_comfyapi,
     validate_string,
 )
-from server import PromptServer
+from comfy.cmd.server import PromptServer


 I2V_AVERAGE_DURATION = 114
diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py
index be1d2de4a..1460e3bf7 100644
--- a/comfy_api_nodes/nodes_openai.py
+++ b/comfy_api_nodes/nodes_openai.py
@@ -11,8 +11,8 @@ import numpy as np
 import torch
 from PIL import Image
 from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
-from server import PromptServer
-import folder_paths
+from comfy.cmd.server import PromptServer
+from comfy.cmd import folder_paths


 from comfy_api_nodes.apis import (
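The nodes_gemini.py hunk above also renames the enum member from GeminiMimeType.pdf to GeminiMimeType.application_pdf, presumably so the member name matches the registered MIME type "application/pdf" rather than the bare extension. Where a project-specific enum is not required, the standard library derives the same strings; a small sketch independent of the API models:

    import mimetypes

    def guess_mime(file_path: str) -> str:
        # "application/pdf" for .pdf; fall back to plain text otherwise
        mime, _ = mimetypes.guess_type(file_path)
        return mime or "text/plain"

    assert guess_mime("doc.pdf") == "application/pdf"
    assert guess_mime("notes.txt") == "text/plain"
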
diff --git a/comfy_api_nodes/nodes_recraft.py b/comfy_api_nodes/nodes_recraft.py
index e369c4b7e..dea600b93 100644
--- a/comfy_api_nodes/nodes_recraft.py
+++ b/comfy_api_nodes/nodes_recraft.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 from inspect import cleandoc
 from typing import Optional
 from comfy.utils import ProgressBar
-from comfy_extras.nodes_images import SVG  # Added
+from comfy_extras.nodes.nodes_images import SVG  # Added
 from comfy.comfy_types.node_typing import IO
 from comfy_api_nodes.apis.recraft_api import (
     RecraftImageGenerationRequest,
@@ -30,7 +30,7 @@ from comfy_api_nodes.apinode_utils import (
     resize_mask_to_image,
     validate_string,
 )
-from server import PromptServer
+from comfy.cmd.server import PromptServer
 import torch

 from io import BytesIO
diff --git a/comfy_api_nodes/nodes_rodin.py b/comfy_api_nodes/nodes_rodin.py
index 67f90478c..f8f9bbbff 100644
--- a/comfy_api_nodes/nodes_rodin.py
+++ b/comfy_api_nodes/nodes_rodin.py
@@ -8,7 +8,7 @@ Rodin API docs: https://developer.hyper3d.ai/
 from __future__ import annotations
 from inspect import cleandoc
 from comfy.comfy_types.node_typing import IO
-import folder_paths as comfy_paths
+from comfy.cmd import folder_paths as comfy_paths
 import requests
 import os
 import datetime
diff --git a/comfy_api_nodes/nodes_tripo.py b/comfy_api_nodes/nodes_tripo.py
index 65f3b21f5..b6bbb50b8 100644
--- a/comfy_api_nodes/nodes_tripo.py
+++ b/comfy_api_nodes/nodes_tripo.py
@@ -1,5 +1,5 @@
 import os
-from folder_paths import get_output_directory
+from comfy.cmd.folder_paths import get_output_directory
 from comfy_api_nodes.mapper_utils import model_field_to_node_input
 from comfy.comfy_types.node_typing import IO
 from comfy_api_nodes.apis import (
diff --git a/comfy_config/config_parser.py b/comfy_config/config_parser.py
index 8da7bd901..a7470bcea 100644
--- a/comfy_config/config_parser.py
+++ b/comfy_config/config_parser.py
@@ -113,9 +113,9 @@ def extract_node_configuration(path) -> Optional[PyProjectConfig]:
     project_data = raw_settings.project
     tool_data = raw_settings.tool

-    comfy_data = tool_data.get("comfy", {}) if tool_data else {}
+    comfy_data = tool_data.get("comfy", {}) if tool_data else {}  # pylint: disable=no-member

-    dependencies = project_data.get("dependencies", [])
+    dependencies = project_data.get("dependencies", [])  # pylint: disable=no-member
     supported_comfyui_frontend_version = ""
     for dep in dependencies:
         if isinstance(dep, str) and dep.startswith("comfyui-frontend-package"):
@@ -124,7 +124,7 @@ def extract_node_configuration(path) -> Optional[PyProjectConfig]:

     supported_comfyui_version = comfy_data.get("requires-comfyui", "")

-    classifiers = project_data.get('classifiers', [])
+    classifiers = project_data.get('classifiers', [])  # pylint: disable=no-member
     supported_os = validate_and_extract_os_classifiers(classifiers)
     supported_accelerators = validate_and_extract_accelerator_classifiers(classifiers)
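In config_parser.py, project_data and tool_data come from a parsed pyproject model whose attribute types pylint cannot infer, so it reports no-member on .get() even though the call succeeds at runtime; the inline disables keep the check active everywhere else. An alternative that satisfies the checker without suppressions (a sketch, assuming the parsed values behave like mappings) is to normalize to a plain dict first:

    raw_project = {"dependencies": ["comfyui-frontend-package>=1.0"]}  # stand-in for the parsed table
    project_data = dict(raw_project or {})  # a plain dict, which pylint can type-check
    dependencies = project_data.get("dependencies", [])
    assert dependencies == ["comfyui-frontend-package>=1.0"]
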
diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py
index 2d291c247..7552264ba 100644
--- a/comfy_execution/graph.py
+++ b/comfy_execution/graph.py
@@ -244,7 +244,7 @@ class ExecutionList(TopologicalSort):
         # This will execute the asynchronous function earlier, reducing the overall time.
         def is_async(node_id):
             class_type = self.dynprompt.get_node(node_id)["class_type"]
-            class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
+            class_def = get_nodes().NODE_CLASS_MAPPINGS[class_type]
             return inspect.iscoroutinefunction(getattr(class_def, class_def.FUNCTION))

         for node_id in node_list:
diff --git a/tests/distributed/test_distributed_queue.py b/tests/distributed/test_distributed_queue.py
index 927acebf3..5aa28c722 100644
--- a/tests/distributed/test_distributed_queue.py
+++ b/tests/distributed/test_distributed_queue.py
@@ -23,12 +23,13 @@ from comfy.distributed.process_pool_executor import ProcessPoolExecutor
 from comfy.distributed.server_stub import ServerStub


-def create_test_prompt() -> QueueItem:
+async def create_test_prompt() -> QueueItem:
     from comfy.cmd.execution import validate_prompt

     prompt = make_mutable(sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1))
-    validation_tuple = validate_prompt(prompt)
     item_id = str(uuid.uuid4())
+
+    validation_tuple = await validate_prompt(item_id, prompt)
     queue_tuple: QueueTuple = (0, item_id, prompt, {}, validation_tuple[2])
     return QueueItem(queue_tuple, None)

@@ -55,7 +56,7 @@ async def test_basic_queue_worker(executor_factory: Callable[..., Executor]) ->
     # now submit some jobs
     distributed_queue = DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}")
     await distributed_queue.init()
-    queue_item = create_test_prompt()
+    queue_item = await create_test_prompt()
     res: TaskInvocation = await distributed_queue.put_async(queue_item)
     assert res.item_id == queue_item.prompt_id
     assert len(res.outputs) == 1
@@ -73,7 +74,7 @@ async def test_distributed_prompt_queues_same_process():
     from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
     async with DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri) as frontend:
         async with DistributedPromptQueue(ServerStub(), is_callee=True, is_caller=False, connection_uri=connection_uri) as worker:
-            test_prompt = create_test_prompt()
+            test_prompt = await create_test_prompt()
             test_prompt.completed = asyncio.Future()
             frontend.put(test_prompt)

@@ -224,8 +225,8 @@ async def test_two_workers_distinct_requests():
         await queue.init()

         # Submit two prompts
-        task1 = asyncio.create_task(queue.put_async(create_test_prompt()))
-        task2 = asyncio.create_task(queue.put_async(create_test_prompt()))
+        task1 = asyncio.create_task(queue.put_async(await create_test_prompt()))
+        task2 = asyncio.create_task(queue.put_async(await create_test_prompt()))

         # Wait for tasks to complete
         await asyncio.gather(task1, task2)
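With create_test_prompt() now a coroutine, every call site must await it, including arguments built inline, as in queue.put_async(await create_test_prompt()) above. A self-contained sketch of the same call shape, using only the standard library:

    import asyncio
    import uuid

    async def create_test_prompt() -> tuple[str, dict]:
        item_id = str(uuid.uuid4())
        return item_id, {"prompt": "test"}

    async def main():
        item_id, prompt = await create_test_prompt()
        assert len(item_id) == 36
        assert prompt["prompt"] == "test"

    asyncio.run(main())
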
diff --git a/tests/unit/test_validation.py b/tests/unit/test_validation.py
index a0fbf5d1a..77efef87e 100644
--- a/tests/unit/test_validation.py
+++ b/tests/unit/test_validation.py
@@ -8,6 +8,8 @@ from comfy.cli_args import args
 from comfy.cmd.execution import validate_prompt
 from comfy.nodes_context import nodes

+import uuid
+
 valid_prompt: Final[dict] = {
     "1": {
         "inputs": {
@@ -153,15 +155,15 @@ def disable_known_models():
         args.disable_known_models = original_value


-def test_validate_prompt_valid(mock_nodes):
+async def test_validate_prompt_valid(mock_nodes):
     prompt = valid_prompt
-    result = validate_prompt(prompt)
+    result = await validate_prompt(str(uuid.uuid4()), prompt)
     assert result.valid
     assert result.error is None
     assert set(result.good_output_node_ids) == {"7"}


-def test_validate_prompt_invalid_node(mock_nodes):
+async def test_validate_prompt_invalid_node(mock_nodes):
     prompt = {
         "1": {
             "inputs": {},
@@ -169,13 +171,13 @@ def test_validate_prompt_invalid_node(mock_nodes):
         },
     }

-    result = validate_prompt(prompt)
+    result = await validate_prompt(str(uuid.uuid4()), prompt)
     assert not result.valid
     assert result.error["type"] == "invalid_prompt"
     assert "NonExistentNode" in result.error["message"]


-def test_prompt_has_no_output(mock_nodes):
+async def test_prompt_has_no_output(mock_nodes):
     prompt = {
         "1": {
             "inputs": {},
@@ -183,12 +185,12 @@ def test_prompt_has_no_output(mock_nodes):
         },
     }

-    result = validate_prompt(prompt)
+    result = await validate_prompt(str(uuid.uuid4()), prompt)
     assert not result.valid
     assert result.error["type"] == "prompt_no_outputs"


-def test_validate_prompt_invalid_input_type(mock_nodes):
+async def test_validate_prompt_invalid_input_type(mock_nodes):
     prompt = valid_prompt.copy()
     prompt["1"] = {
         "inputs": {
@@ -197,7 +199,7 @@ def test_validate_prompt_invalid_input_type(mock_nodes):
             "ckpt_name": 12345,  # Invalid type
         },
         "class_type": "CheckpointLoaderSimple",
     }

-    result = validate_prompt(prompt)
+    result = await validate_prompt(str(uuid.uuid4()), prompt)
     assert not result.valid
     assert result.error["type"] == "prompt_outputs_failed_validation"
     assert result.node_errors["1"]["errors"][0]["type"] == "value_not_in_list"
@@ -212,7 +214,7 @@
     ("C:\\Windows\\Temp\\model.safetensors", "C:/Windows/Temp/model.safetensors"),
     ("/home/user/models/model.safetensors", "/home/user/models/model.safetensors"),
 ])
-def test_validate_prompt_path_variations(mock_nodes, disable_known_models, ckpt_name, known_model):
+async def test_validate_prompt_path_variations(mock_nodes, disable_known_models, ckpt_name, known_model):
     token = known_models.set([known_model])

     try:
@@ -224,23 +226,23 @@ def test_validate_prompt_path_variations(mock_nodes, disable_known_models, ckpt_
             "class_type": "CheckpointLoaderSimple",
         }

-        result = validate_prompt(prompt)
+        result = await validate_prompt(str(uuid.uuid4()), prompt)
         assert result.valid, f"Failed for ckpt_name: {ckpt_name}, known_model: {known_model}"
         assert result.error is None, f"Error for ckpt_name: {ckpt_name}, known_model: {known_model}"
     finally:
         known_models.reset(token)


-def test_validate_prompt_default_models(mock_nodes, disable_known_models):
+async def test_validate_prompt_default_models(mock_nodes, disable_known_models):
     prompt = valid_prompt.copy()
     prompt["1"]["inputs"]["ckpt_name"] = "model1.safetensors"

-    result = validate_prompt(prompt)
+    result = await validate_prompt(str(uuid.uuid4()), prompt)
     assert result.valid, "Failed for default model list"
     assert result.error is None, "Error for default model list"


-def test_validate_prompt_no_outputs(mock_nodes):
+async def test_validate_prompt_no_outputs(mock_nodes):
     prompt = {
         "1": {
             "inputs": {
@@ -250,6 +252,6 @@ def test_validate_prompt_no_outputs(mock_nodes):
         },
     }

-    result = validate_prompt(prompt)
+    result = await validate_prompt(str(uuid.uuid4()), prompt)
     assert not result.valid
     assert result.error["type"] == "prompt_no_outputs"
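The validation tests above all become async def and pass a fresh uuid as the prompt id on every call. Collecting bare async test functions requires an async-capable pytest plugin such as pytest-asyncio or anyio; this patch does not show that configuration, so its presence is an assumption. A minimal sketch of an equivalent test under pytest-asyncio:

    import uuid

    import pytest

    @pytest.mark.asyncio  # redundant when asyncio_mode = "auto" is configured
    async def test_prompt_ids_are_unique():
        assert str(uuid.uuid4()) != str(uuid.uuid4())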