fix pylint issues

This commit is contained in:
doctorpangloss 2025-07-14 17:44:43 -07:00
parent 499f9be5fa
commit c086c5e005
26 changed files with 62 additions and 57 deletions

View File

@ -82,7 +82,7 @@ async def __execute_prompt(
try: try:
prompt_mut = make_mutable(prompt) prompt_mut = make_mutable(prompt)
from ..cmd.execution import validate_prompt from ..cmd.execution import validate_prompt
validation_tuple = validate_prompt(prompt_mut) validation_tuple = await validate_prompt(prompt_id, prompt_mut)
if not validation_tuple.valid: if not validation_tuple.valid:
if validation_tuple.node_errors is not None and len(validation_tuple.node_errors) > 0: if validation_tuple.node_errors is not None and len(validation_tuple.node_errors) > 0:
validation_error_dict = validation_tuple.node_errors validation_error_dict = validation_tuple.node_errors

View File

@ -388,7 +388,7 @@ def format_value(x) -> FormattedValue:
return str(x.__class__) return str(x.__class__)
def execute(server: ExecutorToClientProgress, dynprompt: DynamicPrompt, caches, _node_id: str, extra_data: dict, executed, prompt_id, execution_list, pending_subgraph_results) -> RecursiveExecutionTuple: async def execute(server: ExecutorToClientProgress, dynprompt: DynamicPrompt, caches, _node_id: str, extra_data: dict, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes) -> RecursiveExecutionTuple:
""" """
:param server: :param server:
@ -403,7 +403,7 @@ def execute(server: ExecutorToClientProgress, dynprompt: DynamicPrompt, caches,
:return: :return:
""" """
with context_execute_node(_node_id): with context_execute_node(_node_id):
return _execute(server, dynprompt, caches, _node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results) return _execute(server, dynprompt, caches, _node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes)
async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes) -> RecursiveExecutionTuple: async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes) -> RecursiveExecutionTuple:

View File

@ -859,8 +859,8 @@ class PromptServer(ExecutorToClientProgress):
return web.Response(status=400, reason="no prompt was specified") return web.Response(status=400, reason="no prompt was specified")
content_digest = digest(prompt_dict) content_digest = digest(prompt_dict)
task_id = str(uuid.uuid4())
valid = execution.validate_prompt(prompt_dict) valid = await execution.validate_prompt(task_id, prompt_dict)
if not valid[0]: if not valid[0]:
return web.Response(status=400, content_type="application/json", body=json.dumps(valid[1])) return web.Response(status=400, content_type="application/json", body=json.dumps(valid[1]))
@ -872,7 +872,6 @@ class PromptServer(ExecutorToClientProgress):
completed: Future[TaskInvocation | dict] = self.loop.create_future() completed: Future[TaskInvocation | dict] = self.loop.create_future()
# todo: actually implement idempotency keys # todo: actually implement idempotency keys
# we would need some kind of more durable, distributed task queue # we would need some kind of more durable, distributed task queue
task_id = str(uuid.uuid4())
item = QueueItem(queue_tuple=(number, task_id, prompt_dict, {}, valid[2]), completed=completed) item = QueueItem(queue_tuple=(number, task_id, prompt_dict, {}, valid[2]), completed=completed)
try: try:

View File

@ -110,7 +110,7 @@ class DistributedPromptQueue(AbstractPromptQueue, AsyncAbstractPromptQueue):
async def _callee_do_work_item(self, request: dict) -> dict: async def _callee_do_work_item(self, request: dict) -> dict:
assert self._is_callee assert self._is_callee
request_obj = RpcRequest.from_dict(request) request_obj = RpcRequest.from_dict(request)
item = request_obj.as_queue_tuple().queue_tuple item = (await request_obj.as_queue_tuple()).queue_tuple
item_with_completer = QueueItem(item, self._loop.create_future()) item_with_completer = QueueItem(item, self._loop.create_future())
self._callee_local_in_progress[item_with_completer.prompt_id] = item_with_completer self._callee_local_in_progress[item_with_completer.prompt_id] = item_with_completer
# todo: check if we have the local model content needed to execute this request and if not, reject it # todo: check if we have the local model content needed to execute this request and if not, reject it

View File

@ -26,13 +26,13 @@ class DistributedBase:
class RpcRequest(DistributedBase): class RpcRequest(DistributedBase):
prompt: dict | PromptDict prompt: dict | PromptDict
def as_queue_tuple(self) -> NamedQueueTuple: async def as_queue_tuple(self) -> NamedQueueTuple:
# this loads the nodes in this instance # this loads the nodes in this instance
# should always be okay to call in an executor # should always be okay to call in an executor
from ..cmd.execution import validate_prompt from ..cmd.execution import validate_prompt
from ..component_model.make_mutable import make_mutable from ..component_model.make_mutable import make_mutable
mutated_prompt_dict = make_mutable(self.prompt) mutated_prompt_dict = make_mutable(self.prompt)
validation_tuple = validate_prompt(mutated_prompt_dict) validation_tuple = await validate_prompt(self.prompt_id, mutated_prompt_dict)
return NamedQueueTuple(queue_tuple=(0, self.prompt_id, mutated_prompt_dict, {}, validation_tuple[2])) return NamedQueueTuple(queue_tuple=(0, self.prompt_id, mutated_prompt_dict, {}, validation_tuple[2]))
@classmethod @classmethod

View File

@ -1724,8 +1724,8 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
x_pred = x # x: current state, x_pred: predicted next state x_pred = x # x: current state, x_pred: predicted next state
h = 0.0 h = 0.0
tau_t = 0.0 tau_t = torch.float(0.0)
noise = 0.0 noise = torch.float(0.0)
pred_list = [] pred_list = []
# Lower order near the end to improve stability # Lower order near the end to improve stability

View File

@ -1,9 +1,12 @@
# Original from: https://github.com/ace-step/ACE-Step/blob/main/music_dcae/music_dcae_pipeline.py # Original from: https://github.com/ace-step/ACE-Step/blob/main/music_dcae/music_dcae_pipeline.py
import torch
from .autoencoder_dc import AutoencoderDC
import logging import logging
import torch
from .autoencoder_dc import AutoencoderDC
try: try:
import torchaudio import torchaudio # pylint: disable=import-error
except: except:
logging.warning("torchaudio missing, ACE model will be broken") logging.warning("torchaudio missing, ACE model will be broken")

View File

@ -4,7 +4,7 @@ import torch.nn as nn
from torch import Tensor from torch import Tensor
import logging import logging
try: try:
from torchaudio.transforms import MelScale from torchaudio.transforms import MelScale # pylint: disable=import-error
except: except:
logging.warning("torchaudio missing, ACE model will be broken") logging.warning("torchaudio missing, ACE model will be broken")

View File

@ -1310,7 +1310,7 @@ class Omnigen2(supported_models_base.BASE):
def clip_target(self, state_dict={}): def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0] pref = self.text_encoder_key_prefix[0]
hunyuan_detect = hunyuan_video.llama_detect(state_dict, "{}qwen25_3b.transformer.".format(pref)) hunyuan_detect = hunyuan_video.llama_detect(state_dict, "{}qwen25_3b.transformer.".format(pref))
return supported_models_base.ClipTarget(omnigen2.LuminaTokenizer, omnigen2.te(**hunyuan_detect)) return supported_models_base.ClipTarget(omnigen2.Omnigen2Tokenizer, omnigen2.te(**hunyuan_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2] models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2]

View File

@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from av.container import InputContainer from av.container import InputContainer
from av.subtitles.stream import SubtitleStream from av.subtitles.stream import SubtitleStream # pylint: disable=no-name-in-module
from fractions import Fraction from fractions import Fraction
from typing import Optional from typing import Optional
from comfy_api.input import AudioInput from comfy_api.input import AudioInput

View File

@ -16,7 +16,7 @@ from comfy_api_nodes.apis.client import (
UploadRequest, UploadRequest,
UploadResponse, UploadResponse,
) )
from server import PromptServer from comfy.cmd.server import PromptServer
import numpy as np import numpy as np

View File

@ -103,7 +103,7 @@ from urllib.parse import urljoin, urlparse
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
import uuid # For generating unique operation IDs import uuid # For generating unique operation IDs
from server import PromptServer from comfy.cmd.server import PromptServer
from comfy.cli_args import args from comfy.cli_args import args
from comfy import utils from comfy import utils
from . import request_logger from . import request_logger

View File

@ -2,7 +2,7 @@ import os
import datetime import datetime
import json import json
import logging import logging
import folder_paths from comfy.cmd import folder_paths
# Get the logger instance # Get the logger instance
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -32,7 +32,7 @@ import requests
import torch import torch
import base64 import base64
import time import time
from server import PromptServer from comfy.cmd.server import PromptServer
def convert_mask_to_image(mask: torch.Tensor): def convert_mask_to_image(mask: torch.Tensor):

View File

@ -9,9 +9,9 @@ from typing import Optional, Literal
import torch import torch
import folder_paths from comfy.cmd import folder_paths
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
from server import PromptServer from comfy.cmd.server import PromptServer
from comfy_api_nodes.apis import ( from comfy_api_nodes.apis import (
GeminiContent, GeminiContent,
GeminiGenerateContentRequest, GeminiGenerateContentRequest,
@ -406,7 +406,7 @@ class GeminiInputFiles(ComfyNodeABC):
def create_file_part(self, file_path: str) -> GeminiPart: def create_file_part(self, file_path: str) -> GeminiPart:
mime_type = ( mime_type = (
GeminiMimeType.pdf GeminiMimeType.application_pdf
if file_path.endswith(".pdf") if file_path.endswith(".pdf")
else GeminiMimeType.text_plain else GeminiMimeType.text_plain
) )

View File

@ -23,7 +23,7 @@ from comfy_api_nodes.apinode_utils import (
bytesio_to_image_tensor, bytesio_to_image_tensor,
resize_mask_to_image, resize_mask_to_image,
) )
from server import PromptServer from comfy.cmd.server import PromptServer
V1_V1_RES_MAP = { V1_V1_RES_MAP = {
"Auto":"AUTO", "Auto":"AUTO",

View File

@ -36,7 +36,7 @@ from comfy_api_nodes.apinode_utils import (
process_image_response, process_image_response,
validate_string, validate_string,
) )
from server import PromptServer from comfy.cmd.server import PromptServer
import requests import requests
import torch import torch

View File

@ -24,7 +24,7 @@ from comfy_api_nodes.apinode_utils import (
upload_images_to_comfyapi, upload_images_to_comfyapi,
validate_string, validate_string,
) )
from server import PromptServer from comfy.cmd.server import PromptServer
I2V_AVERAGE_DURATION = 114 I2V_AVERAGE_DURATION = 114

View File

@ -11,8 +11,8 @@ import numpy as np
import torch import torch
from PIL import Image from PIL import Image
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
from server import PromptServer from comfy.cmd.server import PromptServer
import folder_paths from comfy.cmd import folder_paths
from comfy_api_nodes.apis import ( from comfy_api_nodes.apis import (

View File

@ -2,7 +2,7 @@ from __future__ import annotations
from inspect import cleandoc from inspect import cleandoc
from typing import Optional from typing import Optional
from comfy.utils import ProgressBar from comfy.utils import ProgressBar
from comfy_extras.nodes_images import SVG # Added from comfy_extras.nodes.nodes_images import SVG # Added
from comfy.comfy_types.node_typing import IO from comfy.comfy_types.node_typing import IO
from comfy_api_nodes.apis.recraft_api import ( from comfy_api_nodes.apis.recraft_api import (
RecraftImageGenerationRequest, RecraftImageGenerationRequest,
@ -30,7 +30,7 @@ from comfy_api_nodes.apinode_utils import (
resize_mask_to_image, resize_mask_to_image,
validate_string, validate_string,
) )
from server import PromptServer from comfy.cmd.server import PromptServer
import torch import torch
from io import BytesIO from io import BytesIO

View File

@ -8,7 +8,7 @@ Rodin API docs: https://developer.hyper3d.ai/
from __future__ import annotations from __future__ import annotations
from inspect import cleandoc from inspect import cleandoc
from comfy.comfy_types.node_typing import IO from comfy.comfy_types.node_typing import IO
import folder_paths as comfy_paths from comfy.cmd import folder_paths as comfy_paths
import requests import requests
import os import os
import datetime import datetime

View File

@ -1,5 +1,5 @@
import os import os
from folder_paths import get_output_directory from comfy.cmd.folder_paths import get_output_directory
from comfy_api_nodes.mapper_utils import model_field_to_node_input from comfy_api_nodes.mapper_utils import model_field_to_node_input
from comfy.comfy_types.node_typing import IO from comfy.comfy_types.node_typing import IO
from comfy_api_nodes.apis import ( from comfy_api_nodes.apis import (

View File

@ -113,9 +113,9 @@ def extract_node_configuration(path) -> Optional[PyProjectConfig]:
project_data = raw_settings.project project_data = raw_settings.project
tool_data = raw_settings.tool tool_data = raw_settings.tool
comfy_data = tool_data.get("comfy", {}) if tool_data else {} comfy_data = tool_data.get("comfy", {}) if tool_data else {} # pylint: disable=no-member
dependencies = project_data.get("dependencies", []) dependencies = project_data.get("dependencies", []) # pylint: disable=no-member
supported_comfyui_frontend_version = "" supported_comfyui_frontend_version = ""
for dep in dependencies: for dep in dependencies:
if isinstance(dep, str) and dep.startswith("comfyui-frontend-package"): if isinstance(dep, str) and dep.startswith("comfyui-frontend-package"):
@ -124,7 +124,7 @@ def extract_node_configuration(path) -> Optional[PyProjectConfig]:
supported_comfyui_version = comfy_data.get("requires-comfyui", "") supported_comfyui_version = comfy_data.get("requires-comfyui", "")
classifiers = project_data.get('classifiers', []) classifiers = project_data.get('classifiers', []) # pylint: disable=no-member
supported_os = validate_and_extract_os_classifiers(classifiers) supported_os = validate_and_extract_os_classifiers(classifiers)
supported_accelerators = validate_and_extract_accelerator_classifiers(classifiers) supported_accelerators = validate_and_extract_accelerator_classifiers(classifiers)

View File

@ -244,7 +244,7 @@ class ExecutionList(TopologicalSort):
# This will execute the asynchronous function earlier, reducing the overall time. # This will execute the asynchronous function earlier, reducing the overall time.
def is_async(node_id): def is_async(node_id):
class_type = self.dynprompt.get_node(node_id)["class_type"] class_type = self.dynprompt.get_node(node_id)["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type] class_def = get_nodes().NODE_CLASS_MAPPINGS[class_type]
return inspect.iscoroutinefunction(getattr(class_def, class_def.FUNCTION)) return inspect.iscoroutinefunction(getattr(class_def, class_def.FUNCTION))
for node_id in node_list: for node_id in node_list:

View File

@ -23,12 +23,13 @@ from comfy.distributed.process_pool_executor import ProcessPoolExecutor
from comfy.distributed.server_stub import ServerStub from comfy.distributed.server_stub import ServerStub
def create_test_prompt() -> QueueItem: async def create_test_prompt() -> QueueItem:
from comfy.cmd.execution import validate_prompt from comfy.cmd.execution import validate_prompt
prompt = make_mutable(sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)) prompt = make_mutable(sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1))
validation_tuple = validate_prompt(prompt)
item_id = str(uuid.uuid4()) item_id = str(uuid.uuid4())
validation_tuple = await validate_prompt(item_id, prompt)
queue_tuple: QueueTuple = (0, item_id, prompt, {}, validation_tuple[2]) queue_tuple: QueueTuple = (0, item_id, prompt, {}, validation_tuple[2])
return QueueItem(queue_tuple, None) return QueueItem(queue_tuple, None)
@ -55,7 +56,7 @@ async def test_basic_queue_worker(executor_factory: Callable[..., Executor]) ->
# now submit some jobs # now submit some jobs
distributed_queue = DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}") distributed_queue = DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}")
await distributed_queue.init() await distributed_queue.init()
queue_item = create_test_prompt() queue_item = await create_test_prompt()
res: TaskInvocation = await distributed_queue.put_async(queue_item) res: TaskInvocation = await distributed_queue.put_async(queue_item)
assert res.item_id == queue_item.prompt_id assert res.item_id == queue_item.prompt_id
assert len(res.outputs) == 1 assert len(res.outputs) == 1
@ -73,7 +74,7 @@ async def test_distributed_prompt_queues_same_process():
from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
async with DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri) as frontend: async with DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri) as frontend:
async with DistributedPromptQueue(ServerStub(), is_callee=True, is_caller=False, connection_uri=connection_uri) as worker: async with DistributedPromptQueue(ServerStub(), is_callee=True, is_caller=False, connection_uri=connection_uri) as worker:
test_prompt = create_test_prompt() test_prompt = await create_test_prompt()
test_prompt.completed = asyncio.Future() test_prompt.completed = asyncio.Future()
frontend.put(test_prompt) frontend.put(test_prompt)
@ -224,8 +225,8 @@ async def test_two_workers_distinct_requests():
await queue.init() await queue.init()
# Submit two prompts # Submit two prompts
task1 = asyncio.create_task(queue.put_async(create_test_prompt())) task1 = asyncio.create_task(queue.put_async(await create_test_prompt()))
task2 = asyncio.create_task(queue.put_async(create_test_prompt())) task2 = asyncio.create_task(queue.put_async(await create_test_prompt()))
# Wait for tasks to complete # Wait for tasks to complete
await asyncio.gather(task1, task2) await asyncio.gather(task1, task2)

View File

@ -8,6 +8,8 @@ from comfy.cli_args import args
from comfy.cmd.execution import validate_prompt from comfy.cmd.execution import validate_prompt
from comfy.nodes_context import nodes from comfy.nodes_context import nodes
import uuid
valid_prompt: Final[dict] = { valid_prompt: Final[dict] = {
"1": { "1": {
"inputs": { "inputs": {
@ -153,15 +155,15 @@ def disable_known_models():
args.disable_known_models = original_value args.disable_known_models = original_value
def test_validate_prompt_valid(mock_nodes): async def test_validate_prompt_valid(mock_nodes):
prompt = valid_prompt prompt = valid_prompt
result = validate_prompt(prompt) result = await validate_prompt(str(uuid.uuid4()), prompt)
assert result.valid assert result.valid
assert result.error is None assert result.error is None
assert set(result.good_output_node_ids) == {"7"} assert set(result.good_output_node_ids) == {"7"}
def test_validate_prompt_invalid_node(mock_nodes): async def test_validate_prompt_invalid_node(mock_nodes):
prompt = { prompt = {
"1": { "1": {
"inputs": {}, "inputs": {},
@ -169,13 +171,13 @@ def test_validate_prompt_invalid_node(mock_nodes):
}, },
} }
result = validate_prompt(prompt) result = await validate_prompt(str(uuid.uuid4()), prompt)
assert not result.valid assert not result.valid
assert result.error["type"] == "invalid_prompt" assert result.error["type"] == "invalid_prompt"
assert "NonExistentNode" in result.error["message"] assert "NonExistentNode" in result.error["message"]
def test_prompt_has_no_output(mock_nodes): async def test_prompt_has_no_output(mock_nodes):
prompt = { prompt = {
"1": { "1": {
"inputs": {}, "inputs": {},
@ -183,12 +185,12 @@ def test_prompt_has_no_output(mock_nodes):
}, },
} }
result = validate_prompt(prompt) result = await validate_prompt(str(uuid.uuid4()), prompt)
assert not result.valid assert not result.valid
assert result.error["type"] == "prompt_no_outputs" assert result.error["type"] == "prompt_no_outputs"
def test_validate_prompt_invalid_input_type(mock_nodes): async def test_validate_prompt_invalid_input_type(mock_nodes):
prompt = valid_prompt.copy() prompt = valid_prompt.copy()
prompt["1"] = { prompt["1"] = {
"inputs": { "inputs": {
@ -197,7 +199,7 @@ def test_validate_prompt_invalid_input_type(mock_nodes):
"class_type": "CheckpointLoaderSimple", "class_type": "CheckpointLoaderSimple",
} }
result = validate_prompt(prompt) result = await validate_prompt(str(uuid.uuid4()), prompt)
assert not result.valid assert not result.valid
assert result.error["type"] == "prompt_outputs_failed_validation" assert result.error["type"] == "prompt_outputs_failed_validation"
assert result.node_errors["1"]["errors"][0]["type"] == "value_not_in_list" assert result.node_errors["1"]["errors"][0]["type"] == "value_not_in_list"
@ -212,7 +214,7 @@ def test_validate_prompt_invalid_input_type(mock_nodes):
("C:\\Windows\\Temp\\model.safetensors", "C:/Windows/Temp/model.safetensors"), ("C:\\Windows\\Temp\\model.safetensors", "C:/Windows/Temp/model.safetensors"),
("/home/user/models/model.safetensors", "/home/user/models/model.safetensors"), ("/home/user/models/model.safetensors", "/home/user/models/model.safetensors"),
]) ])
def test_validate_prompt_path_variations(mock_nodes, disable_known_models, ckpt_name, known_model): async def test_validate_prompt_path_variations(mock_nodes, disable_known_models, ckpt_name, known_model):
token = known_models.set([known_model]) token = known_models.set([known_model])
try: try:
@ -224,23 +226,23 @@ def test_validate_prompt_path_variations(mock_nodes, disable_known_models, ckpt_
"class_type": "CheckpointLoaderSimple", "class_type": "CheckpointLoaderSimple",
} }
result = validate_prompt(prompt) result = await validate_prompt(str(uuid.uuid4()), prompt)
assert result.valid, f"Failed for ckpt_name: {ckpt_name}, known_model: {known_model}" assert result.valid, f"Failed for ckpt_name: {ckpt_name}, known_model: {known_model}"
assert result.error is None, f"Error for ckpt_name: {ckpt_name}, known_model: {known_model}" assert result.error is None, f"Error for ckpt_name: {ckpt_name}, known_model: {known_model}"
finally: finally:
known_models.reset(token) known_models.reset(token)
def test_validate_prompt_default_models(mock_nodes, disable_known_models): async def test_validate_prompt_default_models(mock_nodes, disable_known_models):
prompt = valid_prompt.copy() prompt = valid_prompt.copy()
prompt["1"]["inputs"]["ckpt_name"] = "model1.safetensors" prompt["1"]["inputs"]["ckpt_name"] = "model1.safetensors"
result = validate_prompt(prompt) result = await validate_prompt(str(uuid.uuid4()), prompt)
assert result.valid, "Failed for default model list" assert result.valid, "Failed for default model list"
assert result.error is None, "Error for default model list" assert result.error is None, "Error for default model list"
def test_validate_prompt_no_outputs(mock_nodes): async def test_validate_prompt_no_outputs(mock_nodes):
prompt = { prompt = {
"1": { "1": {
"inputs": { "inputs": {
@ -250,6 +252,6 @@ def test_validate_prompt_no_outputs(mock_nodes):
}, },
} }
result = validate_prompt(prompt) result = await validate_prompt(str(uuid.uuid4()), prompt)
assert not result.valid assert not result.valid
assert result.error["type"] == "prompt_no_outputs" assert result.error["type"] == "prompt_no_outputs"