mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-09 13:50:49 +08:00
fix pylint issues
This commit is contained in:
parent
499f9be5fa
commit
c086c5e005
@ -82,7 +82,7 @@ async def __execute_prompt(
|
|||||||
try:
|
try:
|
||||||
prompt_mut = make_mutable(prompt)
|
prompt_mut = make_mutable(prompt)
|
||||||
from ..cmd.execution import validate_prompt
|
from ..cmd.execution import validate_prompt
|
||||||
validation_tuple = validate_prompt(prompt_mut)
|
validation_tuple = await validate_prompt(prompt_id, prompt_mut)
|
||||||
if not validation_tuple.valid:
|
if not validation_tuple.valid:
|
||||||
if validation_tuple.node_errors is not None and len(validation_tuple.node_errors) > 0:
|
if validation_tuple.node_errors is not None and len(validation_tuple.node_errors) > 0:
|
||||||
validation_error_dict = validation_tuple.node_errors
|
validation_error_dict = validation_tuple.node_errors
|
||||||
|
|||||||
@ -388,7 +388,7 @@ def format_value(x) -> FormattedValue:
|
|||||||
return str(x.__class__)
|
return str(x.__class__)
|
||||||
|
|
||||||
|
|
||||||
def execute(server: ExecutorToClientProgress, dynprompt: DynamicPrompt, caches, _node_id: str, extra_data: dict, executed, prompt_id, execution_list, pending_subgraph_results) -> RecursiveExecutionTuple:
|
async def execute(server: ExecutorToClientProgress, dynprompt: DynamicPrompt, caches, _node_id: str, extra_data: dict, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes) -> RecursiveExecutionTuple:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
:param server:
|
:param server:
|
||||||
@ -403,7 +403,7 @@ def execute(server: ExecutorToClientProgress, dynprompt: DynamicPrompt, caches,
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
with context_execute_node(_node_id):
|
with context_execute_node(_node_id):
|
||||||
return _execute(server, dynprompt, caches, _node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results)
|
return _execute(server, dynprompt, caches, _node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes)
|
||||||
|
|
||||||
|
|
||||||
async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes) -> RecursiveExecutionTuple:
|
async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes) -> RecursiveExecutionTuple:
|
||||||
|
|||||||
@ -859,8 +859,8 @@ class PromptServer(ExecutorToClientProgress):
|
|||||||
return web.Response(status=400, reason="no prompt was specified")
|
return web.Response(status=400, reason="no prompt was specified")
|
||||||
|
|
||||||
content_digest = digest(prompt_dict)
|
content_digest = digest(prompt_dict)
|
||||||
|
task_id = str(uuid.uuid4())
|
||||||
valid = execution.validate_prompt(prompt_dict)
|
valid = await execution.validate_prompt(task_id, prompt_dict)
|
||||||
if not valid[0]:
|
if not valid[0]:
|
||||||
return web.Response(status=400, content_type="application/json", body=json.dumps(valid[1]))
|
return web.Response(status=400, content_type="application/json", body=json.dumps(valid[1]))
|
||||||
|
|
||||||
@ -872,7 +872,6 @@ class PromptServer(ExecutorToClientProgress):
|
|||||||
completed: Future[TaskInvocation | dict] = self.loop.create_future()
|
completed: Future[TaskInvocation | dict] = self.loop.create_future()
|
||||||
# todo: actually implement idempotency keys
|
# todo: actually implement idempotency keys
|
||||||
# we would need some kind of more durable, distributed task queue
|
# we would need some kind of more durable, distributed task queue
|
||||||
task_id = str(uuid.uuid4())
|
|
||||||
item = QueueItem(queue_tuple=(number, task_id, prompt_dict, {}, valid[2]), completed=completed)
|
item = QueueItem(queue_tuple=(number, task_id, prompt_dict, {}, valid[2]), completed=completed)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@ -110,7 +110,7 @@ class DistributedPromptQueue(AbstractPromptQueue, AsyncAbstractPromptQueue):
|
|||||||
async def _callee_do_work_item(self, request: dict) -> dict:
|
async def _callee_do_work_item(self, request: dict) -> dict:
|
||||||
assert self._is_callee
|
assert self._is_callee
|
||||||
request_obj = RpcRequest.from_dict(request)
|
request_obj = RpcRequest.from_dict(request)
|
||||||
item = request_obj.as_queue_tuple().queue_tuple
|
item = (await request_obj.as_queue_tuple()).queue_tuple
|
||||||
item_with_completer = QueueItem(item, self._loop.create_future())
|
item_with_completer = QueueItem(item, self._loop.create_future())
|
||||||
self._callee_local_in_progress[item_with_completer.prompt_id] = item_with_completer
|
self._callee_local_in_progress[item_with_completer.prompt_id] = item_with_completer
|
||||||
# todo: check if we have the local model content needed to execute this request and if not, reject it
|
# todo: check if we have the local model content needed to execute this request and if not, reject it
|
||||||
|
|||||||
@ -26,13 +26,13 @@ class DistributedBase:
|
|||||||
class RpcRequest(DistributedBase):
|
class RpcRequest(DistributedBase):
|
||||||
prompt: dict | PromptDict
|
prompt: dict | PromptDict
|
||||||
|
|
||||||
def as_queue_tuple(self) -> NamedQueueTuple:
|
async def as_queue_tuple(self) -> NamedQueueTuple:
|
||||||
# this loads the nodes in this instance
|
# this loads the nodes in this instance
|
||||||
# should always be okay to call in an executor
|
# should always be okay to call in an executor
|
||||||
from ..cmd.execution import validate_prompt
|
from ..cmd.execution import validate_prompt
|
||||||
from ..component_model.make_mutable import make_mutable
|
from ..component_model.make_mutable import make_mutable
|
||||||
mutated_prompt_dict = make_mutable(self.prompt)
|
mutated_prompt_dict = make_mutable(self.prompt)
|
||||||
validation_tuple = validate_prompt(mutated_prompt_dict)
|
validation_tuple = await validate_prompt(self.prompt_id, mutated_prompt_dict)
|
||||||
return NamedQueueTuple(queue_tuple=(0, self.prompt_id, mutated_prompt_dict, {}, validation_tuple[2]))
|
return NamedQueueTuple(queue_tuple=(0, self.prompt_id, mutated_prompt_dict, {}, validation_tuple[2]))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@ -1724,8 +1724,8 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
|
|||||||
x_pred = x # x: current state, x_pred: predicted next state
|
x_pred = x # x: current state, x_pred: predicted next state
|
||||||
|
|
||||||
h = 0.0
|
h = 0.0
|
||||||
tau_t = 0.0
|
tau_t = torch.float(0.0)
|
||||||
noise = 0.0
|
noise = torch.float(0.0)
|
||||||
pred_list = []
|
pred_list = []
|
||||||
|
|
||||||
# Lower order near the end to improve stability
|
# Lower order near the end to improve stability
|
||||||
|
|||||||
@ -1,9 +1,12 @@
|
|||||||
# Original from: https://github.com/ace-step/ACE-Step/blob/main/music_dcae/music_dcae_pipeline.py
|
# Original from: https://github.com/ace-step/ACE-Step/blob/main/music_dcae/music_dcae_pipeline.py
|
||||||
import torch
|
|
||||||
from .autoencoder_dc import AutoencoderDC
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from .autoencoder_dc import AutoencoderDC
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import torchaudio
|
import torchaudio # pylint: disable=import-error
|
||||||
except:
|
except:
|
||||||
logging.warning("torchaudio missing, ACE model will be broken")
|
logging.warning("torchaudio missing, ACE model will be broken")
|
||||||
|
|
||||||
|
|||||||
@ -4,7 +4,7 @@ import torch.nn as nn
|
|||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
import logging
|
import logging
|
||||||
try:
|
try:
|
||||||
from torchaudio.transforms import MelScale
|
from torchaudio.transforms import MelScale # pylint: disable=import-error
|
||||||
except:
|
except:
|
||||||
logging.warning("torchaudio missing, ACE model will be broken")
|
logging.warning("torchaudio missing, ACE model will be broken")
|
||||||
|
|
||||||
|
|||||||
@ -1310,7 +1310,7 @@ class Omnigen2(supported_models_base.BASE):
|
|||||||
def clip_target(self, state_dict={}):
|
def clip_target(self, state_dict={}):
|
||||||
pref = self.text_encoder_key_prefix[0]
|
pref = self.text_encoder_key_prefix[0]
|
||||||
hunyuan_detect = hunyuan_video.llama_detect(state_dict, "{}qwen25_3b.transformer.".format(pref))
|
hunyuan_detect = hunyuan_video.llama_detect(state_dict, "{}qwen25_3b.transformer.".format(pref))
|
||||||
return supported_models_base.ClipTarget(omnigen2.LuminaTokenizer, omnigen2.te(**hunyuan_detect))
|
return supported_models_base.ClipTarget(omnigen2.Omnigen2Tokenizer, omnigen2.te(**hunyuan_detect))
|
||||||
|
|
||||||
|
|
||||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2]
|
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2]
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from av.container import InputContainer
|
from av.container import InputContainer
|
||||||
from av.subtitles.stream import SubtitleStream
|
from av.subtitles.stream import SubtitleStream # pylint: disable=no-name-in-module
|
||||||
from fractions import Fraction
|
from fractions import Fraction
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from comfy_api.input import AudioInput
|
from comfy_api.input import AudioInput
|
||||||
|
|||||||
@ -16,7 +16,7 @@ from comfy_api_nodes.apis.client import (
|
|||||||
UploadRequest,
|
UploadRequest,
|
||||||
UploadResponse,
|
UploadResponse,
|
||||||
)
|
)
|
||||||
from server import PromptServer
|
from comfy.cmd.server import PromptServer
|
||||||
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|||||||
@ -103,7 +103,7 @@ from urllib.parse import urljoin, urlparse
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
import uuid # For generating unique operation IDs
|
import uuid # For generating unique operation IDs
|
||||||
|
|
||||||
from server import PromptServer
|
from comfy.cmd.server import PromptServer
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
from comfy import utils
|
from comfy import utils
|
||||||
from . import request_logger
|
from . import request_logger
|
||||||
|
|||||||
@ -2,7 +2,7 @@ import os
|
|||||||
import datetime
|
import datetime
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import folder_paths
|
from comfy.cmd import folder_paths
|
||||||
|
|
||||||
# Get the logger instance
|
# Get the logger instance
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|||||||
@ -32,7 +32,7 @@ import requests
|
|||||||
import torch
|
import torch
|
||||||
import base64
|
import base64
|
||||||
import time
|
import time
|
||||||
from server import PromptServer
|
from comfy.cmd.server import PromptServer
|
||||||
|
|
||||||
|
|
||||||
def convert_mask_to_image(mask: torch.Tensor):
|
def convert_mask_to_image(mask: torch.Tensor):
|
||||||
|
|||||||
@ -9,9 +9,9 @@ from typing import Optional, Literal
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import folder_paths
|
from comfy.cmd import folder_paths
|
||||||
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
|
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
|
||||||
from server import PromptServer
|
from comfy.cmd.server import PromptServer
|
||||||
from comfy_api_nodes.apis import (
|
from comfy_api_nodes.apis import (
|
||||||
GeminiContent,
|
GeminiContent,
|
||||||
GeminiGenerateContentRequest,
|
GeminiGenerateContentRequest,
|
||||||
@ -406,7 +406,7 @@ class GeminiInputFiles(ComfyNodeABC):
|
|||||||
|
|
||||||
def create_file_part(self, file_path: str) -> GeminiPart:
|
def create_file_part(self, file_path: str) -> GeminiPart:
|
||||||
mime_type = (
|
mime_type = (
|
||||||
GeminiMimeType.pdf
|
GeminiMimeType.application_pdf
|
||||||
if file_path.endswith(".pdf")
|
if file_path.endswith(".pdf")
|
||||||
else GeminiMimeType.text_plain
|
else GeminiMimeType.text_plain
|
||||||
)
|
)
|
||||||
|
|||||||
@ -23,7 +23,7 @@ from comfy_api_nodes.apinode_utils import (
|
|||||||
bytesio_to_image_tensor,
|
bytesio_to_image_tensor,
|
||||||
resize_mask_to_image,
|
resize_mask_to_image,
|
||||||
)
|
)
|
||||||
from server import PromptServer
|
from comfy.cmd.server import PromptServer
|
||||||
|
|
||||||
V1_V1_RES_MAP = {
|
V1_V1_RES_MAP = {
|
||||||
"Auto":"AUTO",
|
"Auto":"AUTO",
|
||||||
|
|||||||
@ -36,7 +36,7 @@ from comfy_api_nodes.apinode_utils import (
|
|||||||
process_image_response,
|
process_image_response,
|
||||||
validate_string,
|
validate_string,
|
||||||
)
|
)
|
||||||
from server import PromptServer
|
from comfy.cmd.server import PromptServer
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
@ -24,7 +24,7 @@ from comfy_api_nodes.apinode_utils import (
|
|||||||
upload_images_to_comfyapi,
|
upload_images_to_comfyapi,
|
||||||
validate_string,
|
validate_string,
|
||||||
)
|
)
|
||||||
from server import PromptServer
|
from comfy.cmd.server import PromptServer
|
||||||
|
|
||||||
|
|
||||||
I2V_AVERAGE_DURATION = 114
|
I2V_AVERAGE_DURATION = 114
|
||||||
|
|||||||
@ -11,8 +11,8 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
|
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
|
||||||
from server import PromptServer
|
from comfy.cmd.server import PromptServer
|
||||||
import folder_paths
|
from comfy.cmd import folder_paths
|
||||||
|
|
||||||
|
|
||||||
from comfy_api_nodes.apis import (
|
from comfy_api_nodes.apis import (
|
||||||
|
|||||||
@ -2,7 +2,7 @@ from __future__ import annotations
|
|||||||
from inspect import cleandoc
|
from inspect import cleandoc
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from comfy.utils import ProgressBar
|
from comfy.utils import ProgressBar
|
||||||
from comfy_extras.nodes_images import SVG # Added
|
from comfy_extras.nodes.nodes_images import SVG # Added
|
||||||
from comfy.comfy_types.node_typing import IO
|
from comfy.comfy_types.node_typing import IO
|
||||||
from comfy_api_nodes.apis.recraft_api import (
|
from comfy_api_nodes.apis.recraft_api import (
|
||||||
RecraftImageGenerationRequest,
|
RecraftImageGenerationRequest,
|
||||||
@ -30,7 +30,7 @@ from comfy_api_nodes.apinode_utils import (
|
|||||||
resize_mask_to_image,
|
resize_mask_to_image,
|
||||||
validate_string,
|
validate_string,
|
||||||
)
|
)
|
||||||
from server import PromptServer
|
from comfy.cmd.server import PromptServer
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
|||||||
@ -8,7 +8,7 @@ Rodin API docs: https://developer.hyper3d.ai/
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from inspect import cleandoc
|
from inspect import cleandoc
|
||||||
from comfy.comfy_types.node_typing import IO
|
from comfy.comfy_types.node_typing import IO
|
||||||
import folder_paths as comfy_paths
|
from comfy.cmd import folder_paths as comfy_paths
|
||||||
import requests
|
import requests
|
||||||
import os
|
import os
|
||||||
import datetime
|
import datetime
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
from folder_paths import get_output_directory
|
from comfy.cmd.folder_paths import get_output_directory
|
||||||
from comfy_api_nodes.mapper_utils import model_field_to_node_input
|
from comfy_api_nodes.mapper_utils import model_field_to_node_input
|
||||||
from comfy.comfy_types.node_typing import IO
|
from comfy.comfy_types.node_typing import IO
|
||||||
from comfy_api_nodes.apis import (
|
from comfy_api_nodes.apis import (
|
||||||
|
|||||||
@ -113,9 +113,9 @@ def extract_node_configuration(path) -> Optional[PyProjectConfig]:
|
|||||||
project_data = raw_settings.project
|
project_data = raw_settings.project
|
||||||
|
|
||||||
tool_data = raw_settings.tool
|
tool_data = raw_settings.tool
|
||||||
comfy_data = tool_data.get("comfy", {}) if tool_data else {}
|
comfy_data = tool_data.get("comfy", {}) if tool_data else {} # pylint: disable=no-member
|
||||||
|
|
||||||
dependencies = project_data.get("dependencies", [])
|
dependencies = project_data.get("dependencies", []) # pylint: disable=no-member
|
||||||
supported_comfyui_frontend_version = ""
|
supported_comfyui_frontend_version = ""
|
||||||
for dep in dependencies:
|
for dep in dependencies:
|
||||||
if isinstance(dep, str) and dep.startswith("comfyui-frontend-package"):
|
if isinstance(dep, str) and dep.startswith("comfyui-frontend-package"):
|
||||||
@ -124,7 +124,7 @@ def extract_node_configuration(path) -> Optional[PyProjectConfig]:
|
|||||||
|
|
||||||
supported_comfyui_version = comfy_data.get("requires-comfyui", "")
|
supported_comfyui_version = comfy_data.get("requires-comfyui", "")
|
||||||
|
|
||||||
classifiers = project_data.get('classifiers', [])
|
classifiers = project_data.get('classifiers', []) # pylint: disable=no-member
|
||||||
supported_os = validate_and_extract_os_classifiers(classifiers)
|
supported_os = validate_and_extract_os_classifiers(classifiers)
|
||||||
supported_accelerators = validate_and_extract_accelerator_classifiers(classifiers)
|
supported_accelerators = validate_and_extract_accelerator_classifiers(classifiers)
|
||||||
|
|
||||||
|
|||||||
@ -244,7 +244,7 @@ class ExecutionList(TopologicalSort):
|
|||||||
# This will execute the asynchronous function earlier, reducing the overall time.
|
# This will execute the asynchronous function earlier, reducing the overall time.
|
||||||
def is_async(node_id):
|
def is_async(node_id):
|
||||||
class_type = self.dynprompt.get_node(node_id)["class_type"]
|
class_type = self.dynprompt.get_node(node_id)["class_type"]
|
||||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
class_def = get_nodes().NODE_CLASS_MAPPINGS[class_type]
|
||||||
return inspect.iscoroutinefunction(getattr(class_def, class_def.FUNCTION))
|
return inspect.iscoroutinefunction(getattr(class_def, class_def.FUNCTION))
|
||||||
|
|
||||||
for node_id in node_list:
|
for node_id in node_list:
|
||||||
|
|||||||
@ -23,12 +23,13 @@ from comfy.distributed.process_pool_executor import ProcessPoolExecutor
|
|||||||
from comfy.distributed.server_stub import ServerStub
|
from comfy.distributed.server_stub import ServerStub
|
||||||
|
|
||||||
|
|
||||||
def create_test_prompt() -> QueueItem:
|
async def create_test_prompt() -> QueueItem:
|
||||||
from comfy.cmd.execution import validate_prompt
|
from comfy.cmd.execution import validate_prompt
|
||||||
|
|
||||||
prompt = make_mutable(sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1))
|
prompt = make_mutable(sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1))
|
||||||
validation_tuple = validate_prompt(prompt)
|
|
||||||
item_id = str(uuid.uuid4())
|
item_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
validation_tuple = await validate_prompt(item_id, prompt)
|
||||||
queue_tuple: QueueTuple = (0, item_id, prompt, {}, validation_tuple[2])
|
queue_tuple: QueueTuple = (0, item_id, prompt, {}, validation_tuple[2])
|
||||||
return QueueItem(queue_tuple, None)
|
return QueueItem(queue_tuple, None)
|
||||||
|
|
||||||
@ -55,7 +56,7 @@ async def test_basic_queue_worker(executor_factory: Callable[..., Executor]) ->
|
|||||||
# now submit some jobs
|
# now submit some jobs
|
||||||
distributed_queue = DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}")
|
distributed_queue = DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}")
|
||||||
await distributed_queue.init()
|
await distributed_queue.init()
|
||||||
queue_item = create_test_prompt()
|
queue_item = await create_test_prompt()
|
||||||
res: TaskInvocation = await distributed_queue.put_async(queue_item)
|
res: TaskInvocation = await distributed_queue.put_async(queue_item)
|
||||||
assert res.item_id == queue_item.prompt_id
|
assert res.item_id == queue_item.prompt_id
|
||||||
assert len(res.outputs) == 1
|
assert len(res.outputs) == 1
|
||||||
@ -73,7 +74,7 @@ async def test_distributed_prompt_queues_same_process():
|
|||||||
from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
|
from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
|
||||||
async with DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri) as frontend:
|
async with DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri) as frontend:
|
||||||
async with DistributedPromptQueue(ServerStub(), is_callee=True, is_caller=False, connection_uri=connection_uri) as worker:
|
async with DistributedPromptQueue(ServerStub(), is_callee=True, is_caller=False, connection_uri=connection_uri) as worker:
|
||||||
test_prompt = create_test_prompt()
|
test_prompt = await create_test_prompt()
|
||||||
test_prompt.completed = asyncio.Future()
|
test_prompt.completed = asyncio.Future()
|
||||||
|
|
||||||
frontend.put(test_prompt)
|
frontend.put(test_prompt)
|
||||||
@ -224,8 +225,8 @@ async def test_two_workers_distinct_requests():
|
|||||||
await queue.init()
|
await queue.init()
|
||||||
|
|
||||||
# Submit two prompts
|
# Submit two prompts
|
||||||
task1 = asyncio.create_task(queue.put_async(create_test_prompt()))
|
task1 = asyncio.create_task(queue.put_async(await create_test_prompt()))
|
||||||
task2 = asyncio.create_task(queue.put_async(create_test_prompt()))
|
task2 = asyncio.create_task(queue.put_async(await create_test_prompt()))
|
||||||
|
|
||||||
# Wait for tasks to complete
|
# Wait for tasks to complete
|
||||||
await asyncio.gather(task1, task2)
|
await asyncio.gather(task1, task2)
|
||||||
|
|||||||
@ -8,6 +8,8 @@ from comfy.cli_args import args
|
|||||||
from comfy.cmd.execution import validate_prompt
|
from comfy.cmd.execution import validate_prompt
|
||||||
from comfy.nodes_context import nodes
|
from comfy.nodes_context import nodes
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
valid_prompt: Final[dict] = {
|
valid_prompt: Final[dict] = {
|
||||||
"1": {
|
"1": {
|
||||||
"inputs": {
|
"inputs": {
|
||||||
@ -153,15 +155,15 @@ def disable_known_models():
|
|||||||
args.disable_known_models = original_value
|
args.disable_known_models = original_value
|
||||||
|
|
||||||
|
|
||||||
def test_validate_prompt_valid(mock_nodes):
|
async def test_validate_prompt_valid(mock_nodes):
|
||||||
prompt = valid_prompt
|
prompt = valid_prompt
|
||||||
result = validate_prompt(prompt)
|
result = await validate_prompt(str(uuid.uuid4()), prompt)
|
||||||
assert result.valid
|
assert result.valid
|
||||||
assert result.error is None
|
assert result.error is None
|
||||||
assert set(result.good_output_node_ids) == {"7"}
|
assert set(result.good_output_node_ids) == {"7"}
|
||||||
|
|
||||||
|
|
||||||
def test_validate_prompt_invalid_node(mock_nodes):
|
async def test_validate_prompt_invalid_node(mock_nodes):
|
||||||
prompt = {
|
prompt = {
|
||||||
"1": {
|
"1": {
|
||||||
"inputs": {},
|
"inputs": {},
|
||||||
@ -169,13 +171,13 @@ def test_validate_prompt_invalid_node(mock_nodes):
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
result = validate_prompt(prompt)
|
result = await validate_prompt(str(uuid.uuid4()), prompt)
|
||||||
assert not result.valid
|
assert not result.valid
|
||||||
assert result.error["type"] == "invalid_prompt"
|
assert result.error["type"] == "invalid_prompt"
|
||||||
assert "NonExistentNode" in result.error["message"]
|
assert "NonExistentNode" in result.error["message"]
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_has_no_output(mock_nodes):
|
async def test_prompt_has_no_output(mock_nodes):
|
||||||
prompt = {
|
prompt = {
|
||||||
"1": {
|
"1": {
|
||||||
"inputs": {},
|
"inputs": {},
|
||||||
@ -183,12 +185,12 @@ def test_prompt_has_no_output(mock_nodes):
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
result = validate_prompt(prompt)
|
result = await validate_prompt(str(uuid.uuid4()), prompt)
|
||||||
assert not result.valid
|
assert not result.valid
|
||||||
assert result.error["type"] == "prompt_no_outputs"
|
assert result.error["type"] == "prompt_no_outputs"
|
||||||
|
|
||||||
|
|
||||||
def test_validate_prompt_invalid_input_type(mock_nodes):
|
async def test_validate_prompt_invalid_input_type(mock_nodes):
|
||||||
prompt = valid_prompt.copy()
|
prompt = valid_prompt.copy()
|
||||||
prompt["1"] = {
|
prompt["1"] = {
|
||||||
"inputs": {
|
"inputs": {
|
||||||
@ -197,7 +199,7 @@ def test_validate_prompt_invalid_input_type(mock_nodes):
|
|||||||
"class_type": "CheckpointLoaderSimple",
|
"class_type": "CheckpointLoaderSimple",
|
||||||
}
|
}
|
||||||
|
|
||||||
result = validate_prompt(prompt)
|
result = await validate_prompt(str(uuid.uuid4()), prompt)
|
||||||
assert not result.valid
|
assert not result.valid
|
||||||
assert result.error["type"] == "prompt_outputs_failed_validation"
|
assert result.error["type"] == "prompt_outputs_failed_validation"
|
||||||
assert result.node_errors["1"]["errors"][0]["type"] == "value_not_in_list"
|
assert result.node_errors["1"]["errors"][0]["type"] == "value_not_in_list"
|
||||||
@ -212,7 +214,7 @@ def test_validate_prompt_invalid_input_type(mock_nodes):
|
|||||||
("C:\\Windows\\Temp\\model.safetensors", "C:/Windows/Temp/model.safetensors"),
|
("C:\\Windows\\Temp\\model.safetensors", "C:/Windows/Temp/model.safetensors"),
|
||||||
("/home/user/models/model.safetensors", "/home/user/models/model.safetensors"),
|
("/home/user/models/model.safetensors", "/home/user/models/model.safetensors"),
|
||||||
])
|
])
|
||||||
def test_validate_prompt_path_variations(mock_nodes, disable_known_models, ckpt_name, known_model):
|
async def test_validate_prompt_path_variations(mock_nodes, disable_known_models, ckpt_name, known_model):
|
||||||
token = known_models.set([known_model])
|
token = known_models.set([known_model])
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -224,23 +226,23 @@ def test_validate_prompt_path_variations(mock_nodes, disable_known_models, ckpt_
|
|||||||
"class_type": "CheckpointLoaderSimple",
|
"class_type": "CheckpointLoaderSimple",
|
||||||
}
|
}
|
||||||
|
|
||||||
result = validate_prompt(prompt)
|
result = await validate_prompt(str(uuid.uuid4()), prompt)
|
||||||
assert result.valid, f"Failed for ckpt_name: {ckpt_name}, known_model: {known_model}"
|
assert result.valid, f"Failed for ckpt_name: {ckpt_name}, known_model: {known_model}"
|
||||||
assert result.error is None, f"Error for ckpt_name: {ckpt_name}, known_model: {known_model}"
|
assert result.error is None, f"Error for ckpt_name: {ckpt_name}, known_model: {known_model}"
|
||||||
finally:
|
finally:
|
||||||
known_models.reset(token)
|
known_models.reset(token)
|
||||||
|
|
||||||
|
|
||||||
def test_validate_prompt_default_models(mock_nodes, disable_known_models):
|
async def test_validate_prompt_default_models(mock_nodes, disable_known_models):
|
||||||
prompt = valid_prompt.copy()
|
prompt = valid_prompt.copy()
|
||||||
prompt["1"]["inputs"]["ckpt_name"] = "model1.safetensors"
|
prompt["1"]["inputs"]["ckpt_name"] = "model1.safetensors"
|
||||||
|
|
||||||
result = validate_prompt(prompt)
|
result = await validate_prompt(str(uuid.uuid4()), prompt)
|
||||||
assert result.valid, "Failed for default model list"
|
assert result.valid, "Failed for default model list"
|
||||||
assert result.error is None, "Error for default model list"
|
assert result.error is None, "Error for default model list"
|
||||||
|
|
||||||
|
|
||||||
def test_validate_prompt_no_outputs(mock_nodes):
|
async def test_validate_prompt_no_outputs(mock_nodes):
|
||||||
prompt = {
|
prompt = {
|
||||||
"1": {
|
"1": {
|
||||||
"inputs": {
|
"inputs": {
|
||||||
@ -250,6 +252,6 @@ def test_validate_prompt_no_outputs(mock_nodes):
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
result = validate_prompt(prompt)
|
result = await validate_prompt(str(uuid.uuid4()), prompt)
|
||||||
assert not result.valid
|
assert not result.valid
|
||||||
assert result.error["type"] == "prompt_no_outputs"
|
assert result.error["type"] == "prompt_no_outputs"
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user