mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 14:50:49 +08:00
Improve tracing to propagate to backend workers correctly when using the API. Fix distributed tests.
This commit is contained in:
parent
75b63fce91
commit
f8fcfa6f08
@ -7,14 +7,19 @@ from asyncio import AbstractEventLoop
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Optional
|
||||
|
||||
from opentelemetry import context
|
||||
from opentelemetry.trace import Span, Status, StatusCode
|
||||
|
||||
from ..api.components.schema.prompt import PromptDict
|
||||
from ..cli_args_types import Configuration
|
||||
from ..component_model.make_mutable import make_mutable
|
||||
from ..cmd.main_pre import tracer
|
||||
from ..component_model.executor_types import ExecutorToClientProgress
|
||||
from ..component_model.make_mutable import make_mutable
|
||||
from ..distributed.server_stub import ServerStub
|
||||
|
||||
_server_stub_instance = ServerStub()
|
||||
|
||||
|
||||
class EmbeddedComfyClient:
|
||||
"""
|
||||
Embedded client for comfy executing prompts as a library.
|
||||
@ -110,30 +115,36 @@ class EmbeddedComfyClient:
|
||||
|
||||
await self._loop.run_in_executor(self._executor, create_executor_in_thread)
|
||||
|
||||
@tracer.start_as_current_span("Queue Prompt")
|
||||
async def queue_prompt(self,
|
||||
prompt: PromptDict | dict,
|
||||
prompt_id: Optional[str] = None,
|
||||
client_id: Optional[str] = None) -> dict:
|
||||
prompt_id = prompt_id or str(uuid.uuid4())
|
||||
client_id = client_id or self._progress_handler.client_id or None
|
||||
span_context = context.get_current()
|
||||
|
||||
def execute_prompt() -> dict:
|
||||
from ..cmd.execution import PromptExecutor, validate_prompt
|
||||
prompt_mut = make_mutable(prompt)
|
||||
validation_tuple = validate_prompt(prompt_mut)
|
||||
spam: Span
|
||||
with tracer.start_as_current_span("Execute Prompt", context=span_context) as span:
|
||||
from ..cmd.execution import PromptExecutor, validate_prompt
|
||||
prompt_mut = make_mutable(prompt)
|
||||
validation_tuple = validate_prompt(prompt_mut)
|
||||
|
||||
prompt_executor: PromptExecutor = self._prompt_executor
|
||||
prompt_executor: PromptExecutor = self._prompt_executor
|
||||
|
||||
if client_id is None:
|
||||
prompt_executor.server = _server_stub_instance
|
||||
else:
|
||||
prompt_executor.server = self._progress_handler
|
||||
if client_id is None:
|
||||
prompt_executor.server = _server_stub_instance
|
||||
else:
|
||||
prompt_executor.server = self._progress_handler
|
||||
|
||||
prompt_executor.execute(prompt_mut, prompt_id, {"client_id": client_id},
|
||||
execute_outputs=validation_tuple[2])
|
||||
if prompt_executor.success:
|
||||
return prompt_executor.outputs_ui
|
||||
else:
|
||||
raise RuntimeError("\n".join(event for (event, data) in self._prompt_executor.status_messages))
|
||||
prompt_executor.execute(prompt_mut, prompt_id, {"client_id": client_id},
|
||||
execute_outputs=validation_tuple[2])
|
||||
if prompt_executor.success:
|
||||
return prompt_executor.outputs_ui
|
||||
else:
|
||||
span.set_status(Status(StatusCode.ERROR))
|
||||
error = RuntimeError("\n".join(event for (event, data) in self._prompt_executor.status_messages))
|
||||
span.record_exception(error)
|
||||
|
||||
return await self._loop.run_in_executor(self._executor, execute_prompt)
|
||||
|
||||
@ -9,22 +9,26 @@ import threading
|
||||
import traceback
|
||||
import typing
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing_extensions import TypedDict
|
||||
import torch
|
||||
import lazy_object_proxy
|
||||
|
||||
import lazy_object_proxy
|
||||
import torch
|
||||
from opentelemetry.trace import get_current_span, StatusCode, Status
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from .main_pre import tracer
|
||||
from .. import interruption
|
||||
from ..component_model.abstract_prompt_queue import AbstractPromptQueue
|
||||
from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
|
||||
from ..component_model.executor_types import ExecutorToClientProgress
|
||||
from .. import model_management
|
||||
from ..nodes.package_typing import ExportedNodes
|
||||
from ..component_model.abstract_prompt_queue import AbstractPromptQueue
|
||||
from ..component_model.executor_types import ExecutorToClientProgress
|
||||
from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
|
||||
from ..nodes.package import import_all_nodes_in_workspace
|
||||
from ..nodes.package_typing import ExportedNodes
|
||||
|
||||
# ideally this would be passed in from main, but the way this is authored, we can't easily pass nodes down to the
|
||||
# various functions that are declared here. It should have been a context in the first place.
|
||||
nodes: ExportedNodes = lazy_object_proxy.Proxy(import_all_nodes_in_workspace)
|
||||
|
||||
|
||||
def get_input_data(inputs, class_def, unique_id, outputs=None, prompt=None, extra_data=None):
|
||||
if extra_data is None:
|
||||
extra_data = {}
|
||||
@ -139,6 +143,7 @@ def format_value(x):
|
||||
return str(x)
|
||||
|
||||
|
||||
@tracer.start_as_current_span("Recursive Execute")
|
||||
def recursive_execute(server: ExecutorToClientProgress,
|
||||
prompt,
|
||||
outputs,
|
||||
@ -324,6 +329,10 @@ class PromptExecutor:
|
||||
self.server.send_sync(event, data, self.server.client_id)
|
||||
|
||||
def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex):
|
||||
current_span = get_current_span()
|
||||
current_span.set_status(Status(StatusCode.ERROR))
|
||||
current_span.record_exception(ex)
|
||||
|
||||
node_id = error["node_id"]
|
||||
class_type = prompt[node_id]["class_type"]
|
||||
|
||||
@ -430,7 +439,6 @@ class PromptExecutor:
|
||||
# This call shouldn't raise anything if there's an error deep in
|
||||
# the actual SD code, instead it will report the node where the
|
||||
# error was raised
|
||||
# todo: if we're using a distributed queue, we must wrap the server instance to correctly communicate back to the client via the exchange
|
||||
self.success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id,
|
||||
extra_data, executed, prompt_id, self.outputs_ui,
|
||||
self.object_storage)
|
||||
@ -689,6 +697,7 @@ class ValidationErrorDict(TypedDict):
|
||||
ValidationTuple = typing.Tuple[bool, Optional[ValidationErrorDict], typing.List[str], Union[dict, list]]
|
||||
|
||||
|
||||
@tracer.start_as_current_span("Validate Prompt")
|
||||
def validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple:
|
||||
outputs = set()
|
||||
for x in prompt:
|
||||
|
||||
@ -13,17 +13,16 @@ import uuid
|
||||
from asyncio import Future, AbstractEventLoop
|
||||
from enum import Enum
|
||||
from io import BytesIO
|
||||
from posixpath import join as urljoin
|
||||
from typing import List, Optional, Dict
|
||||
from urllib.parse import quote, urlencode
|
||||
from posixpath import join as urljoin
|
||||
|
||||
from can_ada import URL, parse as urlparse
|
||||
|
||||
import aiofiles
|
||||
import aiohttp
|
||||
from PIL import Image
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
from aiohttp import web
|
||||
from can_ada import URL, parse as urlparse
|
||||
from pkg_resources import resource_filename
|
||||
from typing_extensions import NamedTuple
|
||||
|
||||
@ -36,12 +35,13 @@ from ..cli_args import args
|
||||
from ..client.client_types import Output, FileOutput
|
||||
from ..cmd import execution
|
||||
from ..cmd import folder_paths
|
||||
from ..component_model.abstract_prompt_queue import AbstractPromptQueue, AsyncAbstractPromptQueue
|
||||
from ..component_model.executor_types import ExecutorToClientProgress
|
||||
from ..component_model.file_output_path import file_output_path
|
||||
from ..component_model.queue_types import QueueItem, HistoryEntry, BinaryEventTypes, TaskInvocation
|
||||
from ..digest import digest
|
||||
from ..nodes.package_typing import ExportedNodes
|
||||
from ..images import open_image
|
||||
from ..nodes.package_typing import ExportedNodes
|
||||
|
||||
|
||||
class HeuristicPath(NamedTuple):
|
||||
@ -95,7 +95,7 @@ class PromptServer(ExecutorToClientProgress):
|
||||
self.user_manager = UserManager()
|
||||
# todo: this is probably read by custom nodes elsewhere
|
||||
self.supports: List[str] = ["custom_nodes_from_web"]
|
||||
self.prompt_queue: execution.AbstractPromptQueue | None = None
|
||||
self.prompt_queue: AbstractPromptQueue | AsyncAbstractPromptQueue | None = None
|
||||
self.loop: AbstractEventLoop = loop
|
||||
self.messages: asyncio.Queue = asyncio.Queue()
|
||||
self.number: int = 0
|
||||
@ -599,22 +599,27 @@ class PromptServer(ExecutorToClientProgress):
|
||||
return web.Response(status=400, content_type="application/json", body=json.dumps(valid[1]))
|
||||
|
||||
# convert a valid prompt to the queue tuple this expects
|
||||
completed: Future[TaskInvocation | dict] = self.loop.create_future()
|
||||
number = self.number
|
||||
self.number += 1
|
||||
self.prompt_queue.put(
|
||||
QueueItem(queue_tuple=(number, str(uuid.uuid4()), prompt_dict, {}, valid[2]),
|
||||
completed=completed))
|
||||
|
||||
try:
|
||||
await completed
|
||||
except Exception as ex:
|
||||
return web.Response(body=str(ex), status=503)
|
||||
# expect a single image
|
||||
result: TaskInvocation | dict = completed.result()
|
||||
completed: Future[TaskInvocation | dict] = self.loop.create_future()
|
||||
item = QueueItem(queue_tuple=(number, str(uuid.uuid4()), prompt_dict, {}, valid[2]), completed=completed)
|
||||
|
||||
if hasattr(self.prompt_queue, "put_async") or isinstance(self.prompt_queue, AsyncAbstractPromptQueue):
|
||||
# this enables span propagation seamlessly
|
||||
result = await self.prompt_queue.put_async(item)
|
||||
if result is None:
|
||||
return web.Response(body="the queue is shutting down", status=503)
|
||||
else:
|
||||
try:
|
||||
self.prompt_queue.put(item)
|
||||
await completed
|
||||
except Exception as ex:
|
||||
return web.Response(body=str(ex), status=503)
|
||||
# expect a single image
|
||||
result: TaskInvocation | dict = completed.result()
|
||||
outputs_dict: Dict[str, Output] = result.outputs if isinstance(result, TaskInvocation) else result
|
||||
# find images and read them
|
||||
|
||||
output_images: List[FileOutput] = []
|
||||
for node_id, node in outputs_dict.items():
|
||||
images: List[FileOutput] = []
|
||||
|
||||
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import typing
|
||||
from abc import ABCMeta, abstractmethod
|
||||
|
||||
from .queue_types import QueueTuple, HistoryEntry, QueueItem, Flags, ExecutionStatus
|
||||
from .queue_types import QueueTuple, HistoryEntry, QueueItem, Flags, ExecutionStatus, TaskInvocation
|
||||
|
||||
|
||||
class AbstractPromptQueue(metaclass=ABCMeta):
|
||||
@ -118,3 +118,13 @@ class AbstractPromptQueue(metaclass=ABCMeta):
|
||||
:return:
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class AsyncAbstractPromptQueue(AbstractPromptQueue):
|
||||
@abstractmethod
|
||||
async def put_async(self, queue_item) -> TaskInvocation | None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_async(self, timeout: float | None = None) -> typing.Optional[typing.Tuple[QueueTuple, str]]:
|
||||
pass
|
||||
@ -18,13 +18,14 @@ from .distributed_types import RpcRequest, RpcReply
|
||||
from .history import History
|
||||
from .server_stub import ServerStub
|
||||
from ..auth.permissions import jwt_decode
|
||||
from ..cmd.main_pre import tracer
|
||||
from ..cmd.server import PromptServer
|
||||
from ..component_model.abstract_prompt_queue import AbstractPromptQueue
|
||||
from ..component_model.abstract_prompt_queue import AsyncAbstractPromptQueue
|
||||
from ..component_model.executor_types import ExecutorToClientProgress, SendSyncEvent, SendSyncData
|
||||
from ..component_model.queue_types import Flags, HistoryEntry, QueueTuple, QueueItem, ExecutionStatus, TaskInvocation
|
||||
|
||||
|
||||
class DistributedPromptQueue(AbstractPromptQueue):
|
||||
class DistributedPromptQueue(AsyncAbstractPromptQueue):
|
||||
"""
|
||||
A distributed prompt queue for the ComfyUI web client and single-threaded worker.
|
||||
"""
|
||||
@ -39,12 +40,13 @@ class DistributedPromptQueue(AbstractPromptQueue):
|
||||
async def progress(self, event: SendSyncEvent, data: SendSyncData, sid: Optional[str]) -> None:
|
||||
self._caller_server.send_sync(event, data, sid=sid)
|
||||
|
||||
async def put_async(self, queue_item: QueueItem):
|
||||
@tracer.start_as_current_span("Put Async")
|
||||
async def put_async(self, queue_item: QueueItem) -> TaskInvocation | None:
|
||||
assert self._is_caller
|
||||
assert self._rpc is not None
|
||||
|
||||
if self._closing:
|
||||
return
|
||||
return None
|
||||
self._caller_local_in_progress[queue_item.prompt_id] = queue_item
|
||||
if self._caller_server is not None:
|
||||
self._caller_server.queue_updated()
|
||||
@ -88,9 +90,7 @@ class DistributedPromptQueue(AbstractPromptQueue):
|
||||
# if we have a completer, propoagate the exception to it
|
||||
if queue_item.completed is not None:
|
||||
queue_item.completed.set_exception(e)
|
||||
else:
|
||||
# otherwise, this should raise in the event loop, which I suppose isn't handled
|
||||
raise e
|
||||
raise e
|
||||
finally:
|
||||
self._caller_local_in_progress.pop(queue_item.prompt_id)
|
||||
if self._caller_server is not None:
|
||||
|
||||
@ -12,6 +12,7 @@ from aiormq import AMQPConnectionError
|
||||
from .distributed_progress import DistributedExecutorToClientProgress
|
||||
from .distributed_types import RpcRequest, RpcReply
|
||||
from ..client.embedded_comfy_client import EmbeddedComfyClient
|
||||
from ..cmd.main_pre import tracer
|
||||
from ..component_model.queue_types import ExecutionStatus
|
||||
|
||||
|
||||
@ -32,6 +33,7 @@ class DistributedPromptWorker:
|
||||
self._loop = loop or asyncio.get_event_loop()
|
||||
self._embedded_comfy_client = embedded_comfy_client
|
||||
|
||||
@tracer.start_as_current_span("Do Work Item")
|
||||
async def _do_work_item(self, request: dict) -> dict:
|
||||
await self.on_will_complete_work_item(request)
|
||||
try:
|
||||
@ -55,6 +57,7 @@ class DistributedPromptWorker:
|
||||
await self.on_did_complete_work_item(request_obj, reply)
|
||||
return asdict(reply)
|
||||
|
||||
@tracer.start_as_current_span("Initialize Prompt Worker")
|
||||
async def init(self):
|
||||
await self._exit_stack.__aenter__()
|
||||
try:
|
||||
|
||||
@ -9,8 +9,11 @@ import types
|
||||
from functools import reduce
|
||||
from importlib.metadata import entry_points
|
||||
|
||||
from opentelemetry.trace import Span, Status, StatusCode
|
||||
from pkg_resources import resource_filename
|
||||
|
||||
from .package_typing import ExportedNodes
|
||||
from ..cmd.main_pre import tracer
|
||||
|
||||
_comfy_nodes: ExportedNodes = ExportedNodes()
|
||||
|
||||
@ -46,29 +49,33 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType,
|
||||
else:
|
||||
# Iterate through all the submodules
|
||||
for _, name, is_pkg in pkgutil.iter_modules(module.__path__):
|
||||
full_name = module.__name__ + "." + name
|
||||
time_before = time.perf_counter()
|
||||
success = True
|
||||
|
||||
if full_name.endswith(".disabled"):
|
||||
continue
|
||||
try:
|
||||
submodule = importlib.import_module(full_name)
|
||||
# Recursively call the function if it's a package
|
||||
exported_nodes.update(
|
||||
_import_and_enumerate_nodes_in_module(submodule, print_import_times=print_import_times,
|
||||
depth=depth - 1))
|
||||
except KeyboardInterrupt as interrupted:
|
||||
raise interrupted
|
||||
except Exception as x:
|
||||
if isinstance(x, AttributeError):
|
||||
potential_path_error: AttributeError = x
|
||||
if potential_path_error.name == '__path__':
|
||||
continue
|
||||
logging.error(f"{full_name} import failed", exc_info=x)
|
||||
success = False
|
||||
exceptions.append(x)
|
||||
timings.append((time.perf_counter() - time_before, full_name, success))
|
||||
span: Span
|
||||
with tracer.start_as_current_span("Load Node") as span:
|
||||
full_name = module.__name__ + "." + name
|
||||
time_before = time.perf_counter()
|
||||
success = True
|
||||
span.set_attribute("full_name", full_name)
|
||||
if full_name.endswith(".disabled"):
|
||||
continue
|
||||
try:
|
||||
submodule = importlib.import_module(full_name)
|
||||
# Recursively call the function if it's a package
|
||||
new_nodes = _import_and_enumerate_nodes_in_module(submodule, print_import_times=print_import_times, depth=depth - 1)
|
||||
span.set_attribute("new_nodes.length", len(new_nodes))
|
||||
exported_nodes.update(new_nodes)
|
||||
except KeyboardInterrupt as interrupted:
|
||||
raise interrupted
|
||||
except Exception as x:
|
||||
if isinstance(x, AttributeError):
|
||||
potential_path_error: AttributeError = x
|
||||
if potential_path_error.name == '__path__':
|
||||
continue
|
||||
logging.error(f"{full_name} import failed", exc_info=x)
|
||||
success = False
|
||||
exceptions.append(x)
|
||||
span.set_status(Status(StatusCode.ERROR))
|
||||
span.record_exception(x)
|
||||
timings.append((time.perf_counter() - time_before, full_name, success))
|
||||
|
||||
if print_import_times and len(timings) > 0 or any(not success for (_, _, success) in timings):
|
||||
for (duration, module_name, success) in sorted(timings):
|
||||
@ -81,6 +88,7 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType,
|
||||
return exported_nodes
|
||||
|
||||
|
||||
@tracer.start_as_current_span("Import All Nodes In Workspace")
|
||||
def import_all_nodes_in_workspace(vanilla_custom_nodes=True, raise_on_failure=False) -> ExportedNodes:
|
||||
# now actually import the nodes, to improve control of node loading order
|
||||
from comfy_extras import nodes as comfy_extras_nodes
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import os.path
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
|
||||
import torch
|
||||
@ -17,7 +18,7 @@ from .component_model.executor_types import ExecutorToClientProgress
|
||||
from .component_model.queue_types import BinaryEventTypes
|
||||
|
||||
PROGRESS_BAR_ENABLED = True
|
||||
PROGRESS_BAR_HOOK = None
|
||||
_progress_bar_hook = threading.local()
|
||||
|
||||
|
||||
|
||||
@ -479,7 +480,7 @@ def hijack_progress(server: ExecutorToClientProgress):
|
||||
if preview_image is not None:
|
||||
server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id)
|
||||
|
||||
set_progress_bar_global_hook(hook)
|
||||
_progress_bar_hook.hook = hook
|
||||
|
||||
|
||||
def set_progress_bar_enabled(enabled):
|
||||
@ -491,11 +492,6 @@ def get_progress_bar_enabled() -> bool:
|
||||
return PROGRESS_BAR_ENABLED
|
||||
|
||||
|
||||
def set_progress_bar_global_hook(function):
|
||||
global PROGRESS_BAR_HOOK
|
||||
PROGRESS_BAR_HOOK = function
|
||||
|
||||
|
||||
class _DisabledProgressBar:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
@ -509,10 +505,10 @@ class _DisabledProgressBar:
|
||||
|
||||
class ProgressBar:
|
||||
def __init__(self, total: float):
|
||||
global PROGRESS_BAR_HOOK
|
||||
global _progress_bar_hook
|
||||
self.total: float = total
|
||||
self.current: float = 0.0
|
||||
self.hook = PROGRESS_BAR_HOOK
|
||||
self.hook = _progress_bar_hook.hook if hasattr(_progress_bar_hook, "hook") else None
|
||||
|
||||
def update_absolute(self, value, total=None, preview=None):
|
||||
if total is not None:
|
||||
|
||||
@ -64,7 +64,7 @@ def comfy_background_server(args_pytest):
|
||||
if not success:
|
||||
raise Exception("Failed to start background server")
|
||||
yield
|
||||
p.kill()
|
||||
p.terminate()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
|
||||
@ -1,27 +1,26 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
import subprocess
|
||||
import socket
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import List
|
||||
|
||||
import jwt
|
||||
import pytest
|
||||
import requests
|
||||
from testcontainers.rabbitmq import RabbitMqContainer
|
||||
|
||||
from comfy.client.aio_client import AsyncRemoteComfyClient
|
||||
from comfy.client.embedded_comfy_client import EmbeddedComfyClient
|
||||
from comfy.distributed.server_stub import ServerStub
|
||||
from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner
|
||||
from comfy.component_model.make_mutable import make_mutable
|
||||
from comfy.component_model.queue_types import QueueItem, QueueTuple, TaskInvocation, NamedQueueTuple, ExecutionStatus
|
||||
from comfy.distributed.distributed_prompt_worker import DistributedPromptWorker
|
||||
from testcontainers.rabbitmq import RabbitMqContainer
|
||||
from comfy.distributed.server_stub import ServerStub
|
||||
|
||||
# fixes issues with running the testcontainers rabbitmqcontainer on Windows
|
||||
os.environ["TC_HOST"] = "localhost"
|
||||
@ -55,8 +54,7 @@ async def test_basic_queue_worker() -> None:
|
||||
# this unfortunately does a bunch of initialization on the test thread
|
||||
from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
|
||||
# now submit some jobs
|
||||
distributed_queue = DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True,
|
||||
connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}")
|
||||
distributed_queue = DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}")
|
||||
await distributed_queue.init()
|
||||
queue_item = create_test_prompt()
|
||||
res: TaskInvocation = await distributed_queue.put_async(queue_item)
|
||||
@ -74,10 +72,8 @@ async def test_distributed_prompt_queues_same_process():
|
||||
connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"
|
||||
|
||||
from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
|
||||
async with DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True,
|
||||
connection_uri=connection_uri) as frontend:
|
||||
async with DistributedPromptQueue(ServerStub(), is_callee=True, is_caller=False,
|
||||
connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}") as worker:
|
||||
async with DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri) as frontend:
|
||||
async with DistributedPromptQueue(ServerStub(), is_callee=True, is_caller=False, connection_uri=connection_uri) as worker:
|
||||
test_prompt = create_test_prompt()
|
||||
test_prompt.completed = asyncio.Future()
|
||||
|
||||
@ -117,6 +113,7 @@ async def test_frontend_backend_workers():
|
||||
frontend_command = [
|
||||
"comfyui",
|
||||
"--listen=0.0.0.0",
|
||||
"--port=9001",
|
||||
"--cpu",
|
||||
"--distributed-queue-frontend",
|
||||
f"--distributed-queue-connection-uri={connection_uri}",
|
||||
@ -125,6 +122,7 @@ async def test_frontend_backend_workers():
|
||||
processes_to_close.append(subprocess.Popen(frontend_command, stdout=sys.stdout, stderr=sys.stderr))
|
||||
backend_command = [
|
||||
"comfyui-worker",
|
||||
"--port=9002",
|
||||
f"--distributed-queue-connection-uri={connection_uri}",
|
||||
]
|
||||
|
||||
@ -145,7 +143,7 @@ async def test_frontend_backend_workers():
|
||||
client = AsyncRemoteComfyClient(server_address=server_address)
|
||||
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
|
||||
png_image_bytes = await client.queue_prompt(prompt)
|
||||
assert len(png_image_bytes) > 1000
|
||||
assert len(png_image_bytes) > 1000, "expected an image, but got nothing"
|
||||
finally:
|
||||
for process in processes_to_close:
|
||||
process.terminate()
|
||||
|
||||
@ -24,8 +24,7 @@ async def test_cuda_memory_usage():
|
||||
assert memory_after_workflow > starting_memory, "Expected CUDA memory to increase after running the workflow"
|
||||
|
||||
ending_memory = torch.cuda.memory_allocated(device)
|
||||
assert abs(
|
||||
ending_memory - starting_memory) < 1e7, "Expected CUDA memory to return close to starting memory after cleanup"
|
||||
assert abs(ending_memory - starting_memory) < 1e7, "Expected CUDA memory to return close to starting memory after cleanup"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
Loading…
Reference in New Issue
Block a user